[utils] refactor parallel layers checkpoint and bcast model on loading checkpoint (#1548)

* refactor parallel layer

* broadcast rank0 model after load ckpt
This commit is contained in:
ver217
2022-09-06 20:18:35 +08:00
committed by GitHub
parent 2bed096848
commit ae71036cd2
6 changed files with 131 additions and 94 deletions

View File

@@ -3,9 +3,9 @@ from itertools import chain
import torch
import torch.distributed as dist
from colossalai.communication.collective import scatter_object_list
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.constants import IS_TENSOR_PARALLEL
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
@@ -190,6 +190,15 @@ def save_checkpoint(file,
torch.save(checkpoint, file, **kwargs)
def broadcast_model(model: torch.nn.Module):
src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
for p in model.parameters():
if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
ParallelMode.TENSOR)
dist.broadcast(p, src_rank, group=group)
def load_checkpoint(
file,
model: torch.nn.Module,
@@ -225,6 +234,7 @@ def load_checkpoint(
model_state = partition_pipeline_parallel_state_dict(model, model_state)
try:
model.load_state_dict(model_state, strict=strict)
broadcast_model(model)
except RuntimeError as e:
error_msgs = str(e)
if error_msgs.startswith("Error(s) in loading state_dict for "):