modefied the pp build for ckpt adaptation (#803)

This commit is contained in:
LuGY
2022-04-24 12:23:16 +08:00
committed by GitHub
parent 8789850eea
commit c1e8d2001e
2 changed files with 6 additions and 3 deletions

View File

@@ -399,6 +399,8 @@ def initialize(model: nn.Module,
else:
scatter_gather = False
if use_interleaved:
if isinstance(model, nn.Sequential):
model = nn.ModuleList([model])
schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
gpc.config.model.num_chunks,
tensor_shape=tensor_shape,
@@ -434,7 +436,6 @@ def initialize(model: nn.Module,
accumulate_size=grad_accum_size,
gradient_handlers=gradient_handlers,
lr_scheduler=lr_scheduler)
engine = Engine(model=model,
optimizer=optimizer,
criterion=criterion,