Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)
[utils] refactor parallel layers checkpoint and bcast model on loading checkpoint (#1548)
* refactor parallel layer
* broadcast rank0 model after load ckpt
@@ -3,9 +3,9 @@ from itertools import chain
import torch
import torch.distributed as dist
from colossalai.communication.collective import scatter_object_list
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.constants import IS_TENSOR_PARALLEL

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
@@ -190,6 +190,15 @@ def save_checkpoint(file,
    torch.save(checkpoint, file, **kwargs)


def broadcast_model(model: torch.nn.Module):
    src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
    for p in model.parameters():
        if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
            group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
                ParallelMode.TENSOR)
            dist.broadcast(p, src_rank, group=group)


def load_checkpoint(
    file,
    model: torch.nn.Module,
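For context: the new broadcast_model helper makes every rank of a tensor-parallel group agree on the parameters that are replicated across the group (those without the IS_TENSOR_PARALLEL attribute), broadcasting them from the group's first rank after only that rank has loaded meaningful values. Below is a minimal, self-contained sketch of the same broadcast-after-load pattern using plain torch.distributed; the IS_TENSOR_PARALLEL attribute name is taken from the diff, while the module, backend choice, and helper names are illustrative assumptions, not ColossalAI API.

    # Minimal sketch of the broadcast-after-load pattern (illustrative, not ColossalAI API).
    # Launch with: torchrun --nproc_per_node=2 broadcast_after_load.py
    import torch
    import torch.distributed as dist
    import torch.nn as nn

    IS_TENSOR_PARALLEL = 'is_tensor_parallel'    # attribute name taken from the diff


    def broadcast_replicated_params(model: nn.Module, src_rank: int = 0):
        # Broadcast every parameter that is *not* tensor-parallel from src_rank,
        # so ranks that skipped the checkpoint load end up with rank 0's weights.
        for p in model.parameters():
            if not getattr(p, IS_TENSOR_PARALLEL, False):
                dist.broadcast(p.data, src=src_rank)


    def main():
        dist.init_process_group('gloo')    # CPU-friendly backend for the sketch
        model = nn.Linear(4, 4)
        if dist.get_rank() == 0:
            # Stand-in for "only rank 0 loads the real checkpoint".
            with torch.no_grad():
                for p in model.parameters():
                    p.fill_(1.0)
        broadcast_replicated_params(model)
        # Every rank now prints the same sum as rank 0.
        print(dist.get_rank(), model.weight.sum().item())


    if __name__ == '__main__':
        main()

The helper in the diff additionally skips parameters whose storage is empty and falls back to the tensor-parallel CPU group for parameters that are not on CUDA, since a CUDA-only backend such as NCCL cannot broadcast CPU tensors.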
@@ -225,6 +234,7 @@ def load_checkpoint(
    model_state = partition_pipeline_parallel_state_dict(model, model_state)
    try:
        model.load_state_dict(model_state, strict=strict)
        broadcast_model(model)
    except RuntimeError as e:
        error_msgs = str(e)
        if error_msgs.startswith("Error(s) in loading state_dict for "):
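The only change on the load path is the broadcast_model(model) call after a successful load_state_dict. The surrounding except clause relies on a stable PyTorch behavior: under strict loading, Module.load_state_dict raises a RuntimeError whose message begins with "Error(s) in loading state_dict for ". A small standalone sketch of that behavior (the model and keys are illustrative):

    # Sketch: the RuntimeError prefix that load_checkpoint's except clause matches.
    import torch
    import torch.nn as nn

    model = nn.Linear(2, 2)
    try:
        # Wrong keys under strict=True trigger the error the diff parses.
        model.load_state_dict({'bogus_key': torch.zeros(1)}, strict=True)
    except RuntimeError as e:
        error_msgs = str(e)
        assert error_msgs.startswith("Error(s) in loading state_dict for ")
        print(error_msgs.splitlines()[0])    # "Error(s) in loading state_dict for Linear:"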