[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -47,25 +47,27 @@ def get_default_parser():
         Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
     """
     parser = argparse.ArgumentParser()
-    parser.add_argument('--config', type=str, help='path to the config file')
-    parser.add_argument('--host', type=str, help='the master address for distributed training')
-    parser.add_argument('--port', type=int, help='the master port for distributed training')
-    parser.add_argument('--world_size', type=int, help='world size for distributed training')
-    parser.add_argument('--rank', type=int, help='rank for the default process group')
-    parser.add_argument('--local_rank', type=int, help='local rank on the node')
-    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
+    parser.add_argument("--config", type=str, help="path to the config file")
+    parser.add_argument("--host", type=str, help="the master address for distributed training")
+    parser.add_argument("--port", type=int, help="the master port for distributed training")
+    parser.add_argument("--world_size", type=int, help="world size for distributed training")
+    parser.add_argument("--rank", type=int, help="rank for the default process group")
+    parser.add_argument("--local_rank", type=int, help="local rank on the node")
+    parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
     return parser


-def launch(config: Union[str, Path, Config, Dict],
-           rank: int,
-           world_size: int,
-           host: str,
-           port: int,
-           backend: str = 'nccl',
-           local_rank: int = None,
-           seed: int = 1024,
-           verbose: bool = True):
+def launch(
+    config: Union[str, Path, Config, Dict],
+    rank: int,
+    world_size: int,
+    host: str,
+    port: int,
+    backend: str = "nccl",
+    local_rank: int = None,
+    seed: int = 1024,
+    verbose: bool = True,
+):
     """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
     arguments are not given. Then initialize and set distributed environment by calling global_context's functions.

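The hunk above is purely mechanical (quote style and signature re-wrapping); the call contract of launch() is unchanged. For orientation, a minimal single-process sketch of how this entry point is typically driven; the colossalai.launch re-export and the empty config dict are assumptions, not part of this diff:

import colossalai  # assumed to re-export launch(), as in the legacy tutorials

if __name__ == "__main__":
    # One process acting as rank 0 of a world of size 1; every keyword below
    # corresponds to a parameter of the reformatted signature above.
    colossalai.launch(
        config={},        # illustrative empty config
        rank=0,
        world_size=1,
        host="127.0.0.1",
        port=29500,
        backend="nccl",   # default value shown in the diff
        seed=1024,        # default value shown in the diff
        verbose=True,
    )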
@@ -88,8 +90,9 @@ def launch(config: Union[str, Path, Config, Dict],
     gpc.verbose = verbose

     # set config
-    assert isinstance(config, (Config, str, Path, dict)), \
-        f'expected argument config to be Config, str or Path, but got {type(config)}'
+    assert isinstance(
+        config, (Config, str, Path, dict)
+    ), f"expected argument config to be Config, str or Path, but got {type(config)}"
     if not isinstance(config, Config) and isinstance(config, dict):
         config = Config(config)
     if isinstance(config, (str, Path)):
@@ -115,18 +118,21 @@ def launch(config: Union[str, Path, Config, Dict],
     if verbose:
         logger = get_dist_logger()
         logger.info(
-            f'Distributed environment is initialized, '
-            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
-            f'tensor parallel size: {gpc.tensor_parallel_size}',
-            ranks=[0])
+            f"Distributed environment is initialized, "
+            f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, "
+            f"tensor parallel size: {gpc.tensor_parallel_size}",
+            ranks=[0],
+        )


-def launch_from_slurm(config: Union[str, Path, Config, Dict],
-                      host: str,
-                      port: int,
-                      backend: str = 'nccl',
-                      seed: int = 1024,
-                      verbose: bool = True):
+def launch_from_slurm(
+    config: Union[str, Path, Config, Dict],
+    host: str,
+    port: int,
+    backend: str = "nccl",
+    seed: int = 1024,
+    verbose: bool = True,
+):
     """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
     set by SLURM

@@ -139,29 +145,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
         verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     try:
-        rank = int(os.environ['SLURM_PROCID'])
-        world_size = int(os.environ['SLURM_NPROCS'])
+        rank = int(os.environ["SLURM_PROCID"])
+        world_size = int(os.environ["SLURM_NPROCS"])
     except KeyError as e:
         raise RuntimeError(
             f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
         )

-    launch(config=config,
-           rank=rank,
-           world_size=world_size,
-           host=host,
-           port=port,
-           backend=backend,
-           seed=seed,
-           verbose=verbose)
+    launch(
+        config=config,
+        rank=rank,
+        world_size=world_size,
+        host=host,
+        port=port,
+        backend=backend,
+        seed=seed,
+        verbose=verbose,
+    )


-def launch_from_openmpi(config: Union[str, Path, Config, Dict],
-                        host: str,
-                        port: int,
-                        backend: str = 'nccl',
-                        seed: int = 1024,
-                        verbose: bool = True):
+def launch_from_openmpi(
+    config: Union[str, Path, Config, Dict],
+    host: str,
+    port: int,
+    backend: str = "nccl",
+    seed: int = 1024,
+    verbose: bool = True,
+):
     """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
     set by OpenMPI

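For context on the SLURM wrapper above: rank and world size are read from SLURM_PROCID and SLURM_NPROCS inside launch_from_slurm(), so a user script only supplies the rendezvous address. A hedged sketch; the script name and the host/port values are illustrative:

# train_slurm.py -- hypothetical script, typically started with something like:
#   srun python train_slurm.py
import colossalai  # assumed to re-export launch_from_slurm()

if __name__ == "__main__":
    # SLURM_PROCID / SLURM_NPROCS are picked up inside launch_from_slurm(),
    # exactly as the hunk above shows; only host/port are passed explicitly.
    colossalai.launch_from_slurm(
        config={},      # illustrative empty config
        host="node-0",  # example address of the rank-0 node
        port=29500,
        backend="nccl",
    )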
@@ -174,29 +184,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
         verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     try:
-        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
-        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
-        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+        rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+        local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
+        world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
     except KeyError as e:
         raise RuntimeError(
             f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
         )

-    launch(config=config,
-           local_rank=local_rank,
-           rank=rank,
-           world_size=world_size,
-           host=host,
-           port=port,
-           backend=backend,
-           seed=seed,
-           verbose=verbose)
+    launch(
+        config=config,
+        local_rank=local_rank,
+        rank=rank,
+        world_size=world_size,
+        host=host,
+        port=port,
+        backend=backend,
+        seed=seed,
+        verbose=verbose,
+    )


-def launch_from_torch(config: Union[str, Path, Config, Dict],
-                      backend: str = 'nccl',
-                      seed: int = 1024,
-                      verbose: bool = True):
+def launch_from_torch(
+    config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
+):
     """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
     from the environment variables set by PyTorch

@@ -207,35 +218,39 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
         verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     try:
-        rank = int(os.environ['RANK'])
-        local_rank = int(os.environ['LOCAL_RANK'])
-        world_size = int(os.environ['WORLD_SIZE'])
-        host = os.environ['MASTER_ADDR']
-        port = int(os.environ['MASTER_PORT'])
+        rank = int(os.environ["RANK"])
+        local_rank = int(os.environ["LOCAL_RANK"])
+        world_size = int(os.environ["WORLD_SIZE"])
+        host = os.environ["MASTER_ADDR"]
+        port = int(os.environ["MASTER_PORT"])
     except KeyError as e:
         raise RuntimeError(
             f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
         )

-    launch(config=config,
-           local_rank=local_rank,
-           rank=rank,
-           world_size=world_size,
-           host=host,
-           port=port,
-           backend=backend,
-           seed=seed,
-           verbose=verbose)
+    launch(
+        config=config,
+        local_rank=local_rank,
+        rank=rank,
+        world_size=world_size,
+        host=host,
+        port=port,
+        backend=backend,
+        seed=seed,
+        verbose=verbose,
+    )


-def initialize(model: nn.Module,
-               optimizer: Optimizer,
-               criterion: Optional[_Loss] = None,
-               train_dataloader: Optional[Iterable] = None,
-               test_dataloader: Optional[Iterable] = None,
-               lr_scheduler: Optional[_LRScheduler] = None,
-               ophooks: Optional[List[BaseOpHook]] = None,
-               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
+def initialize(
+    model: nn.Module,
+    optimizer: Optimizer,
+    criterion: Optional[_Loss] = None,
+    train_dataloader: Optional[Iterable] = None,
+    test_dataloader: Optional[Iterable] = None,
+    lr_scheduler: Optional[_LRScheduler] = None,
+    ophooks: Optional[List[BaseOpHook]] = None,
+    verbose: bool = True,
+) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.

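The torch wrapper above reads RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, which torchrun exports for every worker, so no explicit rendezvous arguments are needed. A minimal sketch under that assumption (the script name is illustrative):

# train_torch.py -- hypothetical script, typically started with something like:
#   torchrun --nproc_per_node=8 train_torch.py
import colossalai  # assumed to re-export launch_from_torch()

if __name__ == "__main__":
    # RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT come from the
    # environment set by torchrun, as shown in the hunk above.
    colossalai.launch_from_torch(config={}, backend="nccl", seed=1024, verbose=True)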
@@ -267,30 +282,30 @@ def initialize(model: nn.Module,
             f"\n========== Your Config ========\n"
             f"{pprint.pformat(gpc.config)}\n"
             f"================================\n",
-            ranks=[0])
+            ranks=[0],
+        )

     # cudnn
-    cudnn_benchmark = config.get('cudnn_benchmark', False)
-    cudnn_deterministic = config.get('cudnn_deterministic', False)
+    cudnn_benchmark = config.get("cudnn_benchmark", False)
+    cudnn_deterministic = config.get("cudnn_deterministic", False)
     torch.backends.cudnn.benchmark = cudnn_benchmark
     torch.backends.cudnn.deterministic = cudnn_deterministic
     if verbose:
         logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

     # zero
-    use_zero = hasattr(gpc.config, 'zero')
+    use_zero = hasattr(gpc.config, "zero")
     if use_zero:
-        zero_cfg = gpc.config.get('zero', None)
+        zero_cfg = gpc.config.get("zero", None)
         if zero_cfg is not None:
             cfg_ = zero_cfg.copy()
         else:
             cfg_ = {}
-        optimizer_config = zero_cfg.get('optimizer_config', None)
-        model_config = zero_cfg.get('model_config', None)
-        model, optimizer = convert_to_zero_v2(model,
-                                              optimizer,
-                                              model_config=model_config,
-                                              optimizer_config=optimizer_config)
+        optimizer_config = zero_cfg.get("optimizer_config", None)
+        model_config = zero_cfg.get("model_config", None)
+        model, optimizer = convert_to_zero_v2(
+            model, optimizer, model_config=model_config, optimizer_config=optimizer_config
+        )

         logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
     else:
@@ -316,38 +331,38 @@ def initialize(model: nn.Module,
         logger.warning(
             "The parameters of models is not automatically synchronized.\n"
             "Please make sure that all parameters are the same in data parallel group.",
-            ranks=[0])
+            ranks=[0],
+        )

     # check amp and zero
-    fp16_cfg = gpc.config.get('fp16', None)
+    fp16_cfg = gpc.config.get("fp16", None)

     if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
         raise ConfigException(
-            "It is not allowed to set fp16 and zero configuration in your config file at the same time")
+            "It is not allowed to set fp16 and zero configuration in your config file at the same time"
+        )

     # clip grad norm
-    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
+    clip_grad_norm = gpc.config.get("clip_grad_norm", 0.0)

     # initialize amp
     amp_mode = None
     if fp16_cfg is not None and fp16_cfg.mode is not None:
         cfg_ = fp16_cfg.copy()
-        amp_mode = cfg_.pop('mode')
+        amp_mode = cfg_.pop("mode")
         if is_using_pp():
-            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
+            assert amp_mode == AMP_TYPE.NAIVE, "Pipeline only support NaiveAMP currently"
         if amp_mode == AMP_TYPE.NAIVE:
-            cfg_['clip_grad_norm'] = clip_grad_norm
-        model, optimizer, criterion = convert_to_amp(model=model,
-                                                     optimizer=optimizer,
-                                                     criterion=criterion,
-                                                     mode=amp_mode,
-                                                     amp_config=cfg_)
+            cfg_["clip_grad_norm"] = clip_grad_norm
+        model, optimizer, criterion = convert_to_amp(
+            model=model, optimizer=optimizer, criterion=criterion, mode=amp_mode, amp_config=cfg_
+        )

     # get torch ddp config
-    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
+    torch_ddp_cfg = gpc.config.get("torch_ddp", dict())

     # gradient handler
-    gradient_handler_cfg = gpc.config.get('gradient_handler', None)
+    gradient_handler_cfg = gpc.config.get("gradient_handler", None)
     if gradient_handler_cfg is None:
         # if gradient handler is not specified in the configuration file,
         # check in the following order
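All of the gpc.config lookups touched above and later in the diff (cudnn_benchmark, cudnn_deterministic, zero, fp16, clip_grad_norm, torch_ddp, gradient_handler, gradient_accumulation) are plain keys of the user config. A hedged sketch of a legacy-style Python config exercising a few of them; only the key names come from the diff, every value is an example:

# config.py -- illustrative legacy ColossalAI config file.
from colossalai.amp import AMP_TYPE  # assumed import path for the AMP enum

cudnn_benchmark = True
cudnn_deterministic = False

# fp16 and zero must not be set together, as the ConfigException above enforces.
fp16 = dict(mode=AMP_TYPE.NAIVE)

clip_grad_norm = 1.0
gradient_accumulation = 4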
@@ -355,54 +370,63 @@ def initialize(model: nn.Module,
         # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
         # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
         if isinstance(optimizer, ShardedOptimizerV2):
-            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
+            gradient_handler_cfg = [dict(type="ZeROGradientHandler")]
             if verbose:
                 logger.info(
                     "Training with zero is detected, ZeROGradientHandler is automatically "
                     "added even though not specified in the configuration",
-                    ranks=[0])
+                    ranks=[0],
+                )
         elif is_using_ddp() and MOE_CONTEXT.is_initialized:
-            gradient_handler_cfg = [dict(type='MoeGradientHandler')]
+            gradient_handler_cfg = [dict(type="MoeGradientHandler")]
             if verbose:
                 logger.info(
                     "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
                     "added even though not specified in the configuration",
-                    ranks=[0])
+                    ranks=[0],
+                )
         elif is_using_sequence():
-            model = DDP(model,
-                        process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
-                        device_ids=[torch.cuda.current_device()],
-                        **torch_ddp_cfg)
+            model = DDP(
+                model,
+                process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
+                device_ids=[torch.cuda.current_device()],
+                **torch_ddp_cfg,
+            )
             if verbose:
-                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
-                            ranks=[0])
+                logger.info(
+                    "Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism", ranks=[0]
+                )
         elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
-            model = DDP(model,
-                        process_group=gpc.get_group(ParallelMode.DATA),
-                        device_ids=[torch.cuda.current_device()],
-                        **torch_ddp_cfg)
+            model = DDP(
+                model,
+                process_group=gpc.get_group(ParallelMode.DATA),
+                device_ids=[torch.cuda.current_device()],
+                **torch_ddp_cfg,
+            )
             if verbose:
-                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
+                logger.info("Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism", ranks=[0])
         elif is_using_ddp():
-            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
+            gradient_handler_cfg = [dict(type="DataParallelGradientHandler")]
             if verbose:
                 logger.info(
                     "Data parallel training is detected when using pipeline parallel, "
                     "DataParallelGradientHandler is automatically "
                     "added even though not specified in the configuration",
-                    ranks=[0])
+                    ranks=[0],
+                )
         # add pipeline parallel gradient handler, if pipeline shared module is detected
         for param in model.parameters():
-            if getattr(param, 'pipeline_shared_module_pg', None) is not None:
+            if getattr(param, "pipeline_shared_module_pg", None) is not None:
                 if gradient_handler_cfg is None:
-                    gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
+                    gradient_handler_cfg = [dict(type="PipelineSharedModuleGradientHandler")]
                 else:
-                    gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
+                    gradient_handler_cfg.append(dict(type="PipelineSharedModuleGradientHandler"))
                 if verbose:
                     logger.info(
                         "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
                         "added even though not specified in the configuration",
-                        ranks=[0])
+                        ranks=[0],
+                    )
                 break
     else:
         if not isinstance(gradient_handler_cfg, list):
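Note that the auto-selection above only runs when gradient_handler is absent from the config; the key itself holds a list of dicts whose type strings match the handler names in this hunk. A minimal, hedged example of setting it explicitly in the same kind of config file:

# Explicit handler selection (illustrative); the type string is one of the
# handler names appearing in the hunk above.
gradient_handler = [dict(type="DataParallelGradientHandler")]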
@@ -418,7 +442,7 @@ def initialize(model: nn.Module,
     # initialize schedule for engine
     if is_using_pp():
         tensor_shape = get_tensor_shape()
-        use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
+        use_interleaved = hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks")
         if gpc.is_initialized(ParallelMode.PARALLEL_1D):
             scatter_gather = True
         else:
@@ -426,14 +450,16 @@ def initialize(model: nn.Module,
         if use_interleaved:
             if isinstance(model, nn.Sequential):
                 model = nn.ModuleList([model])
-            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                                   gpc.config.model.num_chunks,
-                                                   tensor_shape=tensor_shape,
-                                                   scatter_gather_tensors=scatter_gather)
+            schedule = InterleavedPipelineSchedule(
+                gpc.config.NUM_MICRO_BATCHES,
+                gpc.config.model.num_chunks,
+                tensor_shape=tensor_shape,
+                scatter_gather_tensors=scatter_gather,
+            )
         else:
-            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
-                                        tensor_shape=tensor_shape,
-                                        scatter_gather_tensors=scatter_gather)
+            schedule = PipelineSchedule(
+                gpc.config.NUM_MICRO_BATCHES, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather
+            )
     else:
         schedule = NonPipelineSchedule()

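The pipeline branch above expects NUM_MICRO_BATCHES at the top level of gpc.config and, for the interleaved schedule, a model entry carrying num_chunks. A hedged sketch of those two keys (values are examples only):

# Pipeline-related config keys read by the hunk above; values are examples.
NUM_MICRO_BATCHES = 4
model = dict(num_chunks=2)  # presence of num_chunks selects InterleavedPipelineSchedule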
@@ -443,7 +469,8 @@ def initialize(model: nn.Module,
         logger.warning(
             "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
             "to all-reduce the gradients after a training step.",
-            ranks=[0])
+            ranks=[0],
+        )
     else:
         gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

@@ -452,7 +479,7 @@ def initialize(model: nn.Module,
         optimizer = OptimizerWrapper(optim=optimizer)

     # gradient accumulation
-    grad_accum_size = gpc.config.get('gradient_accumulation', None)
+    grad_accum_size = gpc.config.get("gradient_accumulation", None)
     if grad_accum_size is not None:
         optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
             model=model,
@@ -460,13 +487,16 @@ def initialize(model: nn.Module,
             dataloader=train_dataloader,
             accumulate_size=grad_accum_size,
             gradient_handlers=gradient_handlers,
-            lr_scheduler=lr_scheduler)
-    engine = Engine(model=model,
-                    optimizer=optimizer,
-                    criterion=criterion,
-                    gradient_handlers=gradient_handlers,
-                    clip_grad_norm=clip_grad_norm,
-                    ophook_list=ophooks,
-                    schedule=schedule)
+            lr_scheduler=lr_scheduler,
+        )
+    engine = Engine(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        gradient_handlers=gradient_handlers,
+        clip_grad_norm=clip_grad_norm,
+        ophook_list=ophooks,
+        schedule=schedule,
+    )

     return engine, train_dataloader, test_dataloader, lr_scheduler
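Finally, a hedged sketch of how the reformatted initialize() is consumed. The 4-tuple mirrors the return annotation shown earlier; the engine methods in the loop follow the legacy ColossalAI engine tutorials and are assumptions as far as this diff is concerned, and launch_from_torch (or another launcher) plus a suitable config must have run first:

import torch
import torch.nn as nn
import colossalai  # assumed top-level API, as in the legacy tutorials

# Toy components so the sketch is self-contained; real code builds its own
# model, optimizer, criterion and dataloader.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
dataset = torch.utils.data.TensorDataset(torch.randn(32, 16), torch.randint(0, 4, (32,)))
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

# The 4-element return value mirrors Tuple[Engine, DataLoader, DataLoader, _LRScheduler].
engine, train_dataloader, _, _ = colossalai.initialize(
    model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader
)

engine.train()
for img, label in train_dataloader:
    engine.zero_grad()
    loss = engine.criterion(engine(img), label)
    engine.backward(loss)
    engine.step()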