You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

65 lines
1.9 KiB
Python

import os
import sys
import pathlib
# Make sibling/parent project modules importable before the project imports
# below (`utils`, and later `models`, `data_loader`, ...). Presumably these
# live in the directory one level above this script's folder — verify
# against the repository layout.
# NOTE: this must not be named `__dir__`: at module scope that name is the
# PEP 562 hook called by dir(module), and storing a Path there makes
# dir() on this module raise TypeError.
_THIS_FILE = pathlib.Path(os.path.abspath(__file__))
# Append the script's own directory (the original appended the file path
# itself, which has no effect on imports).
sys.path.append(str(_THIS_FILE.parent))
sys.path.append(str(_THIS_FILE.parent.parent))
import paddle
import paddle.distributed as dist
from utils import Config, ArgsParser
def init_args():
    """Parse the command line with the project's ``ArgsParser``.

    Returns:
        The namespace produced by ``ArgsParser().parse_args()``. The caller
        reads ``config_file``, ``opt`` and ``profiler_options`` from it.
    """
    return ArgsParser().parse_args()
def main(config, profiler_options):
    """Build every training component described by ``config`` and train.

    Args:
        config: Nested configuration mapping (the ``cfg`` attribute of a
            ``Config``). It is mutated in place: ``config["distributed"]``
            and ``config["arch"]["backbone"]["in_channels"]`` are written
            here before the model and trainer are built.
        profiler_options: Passed through unchanged to ``Trainer``.

    Raises:
        RuntimeError: if ``get_dataloader`` returns ``None`` for
            ``config["dataset"]["train"]``.
    """
    from models import build_model, build_loss
    from data_loader import get_dataloader
    from trainer import Trainer
    from post_processing import get_post_processing
    from utils import get_metric

    # Data-parallel mode is enabled whenever more than one CUDA device is
    # visible to this process.
    distributed = paddle.device.cuda.device_count() > 1
    if distributed:
        dist.init_parallel_env()
    config["distributed"] = distributed

    train_loader = get_dataloader(config["dataset"]["train"], config["distributed"])
    if train_loader is None:
        # Explicit check instead of `assert`, which `python -O` strips.
        raise RuntimeError(
            'could not build a training dataloader from config["dataset"]["train"]'
        )

    # Validation is optional. When present, it is always built with
    # distributed=False.
    validate_loader = (
        get_dataloader(config["dataset"]["validate"], False)
        if "validate" in config["dataset"]
        else None
    )

    criterion = build_loss(config["loss"])

    # The backbone's input channel count follows the training image mode:
    # 1 for "GRAY", 3 for anything else.
    config["arch"]["backbone"]["in_channels"] = (
        3 if config["dataset"]["train"]["dataset"]["args"]["img_mode"] != "GRAY" else 1
    )
    model = build_model(config["arch"])
    # NOTE: no `paddle.jit.to_static` conversion is applied to the model here.
    # A benchmark setup may insert one at this point; by default it is skipped.

    post_p = get_post_processing(config["post_processing"])
    metric = get_metric(config["metric"])
    trainer = Trainer(
        config=config,
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        post_process=post_p,
        metric_cls=metric,
        validate_loader=validate_loader,
        profiler_options=profiler_options,
    )
    trainer.train()
if __name__ == "__main__":
    args = init_args()
    # Explicit check instead of `assert` (stripped under `python -O`).
    # isfile() also rejects directories, which os.path.exists() accepted.
    if not os.path.isfile(args.config_file):
        raise FileNotFoundError(f"config file not found: {args.config_file}")
    config = Config(args.config_file)
    # Merge `args.opt` (presumably command-line overrides) into the loaded
    # configuration before training.
    config.merge_dict(args.opt)
    main(config.cfg, args.profiler_options)