This commit is contained in:
wangbluo 2024-11-19 11:52:51 +08:00
parent f4f3d52924
commit 945a67dd61
3 changed files with 12 additions and 7 deletions

View File

@ -82,7 +82,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async, state_dict)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)

View File

@ -45,11 +45,16 @@ class GeneralCheckpointIO(CheckpointIO):
model.load_state_dict(checkpoint, strict=strict)
def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
state_dict: dict = None,
):
state_dict = model.state_dict()
# TODO(FrankLeeeee): add support for gather_dtensor
if state_dict is None:
state_dict = model.state_dict()
if gather_dtensor:
pass
@ -64,7 +69,7 @@ class GeneralCheckpointIO(CheckpointIO):
else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
save_state_dict(model.state_dict(), checkpoint, use_safetensors)
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""

View File

@ -192,7 +192,7 @@ def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict()
exam_state_dict_with_origin()
exam_lazy_from_pretrained()
# exam_lazy_from_pretrained()
@pytest.mark.dist