[llama] polish training script and fix optim ckpt (#5368)

This commit is contained in:
Hongxin Liu
2024-02-06 11:52:17 +08:00
committed by GitHub
parent a5756a8720
commit eb4f2d90f9
2 changed files with 14 additions and 5 deletions

View File

@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
tp_group=self.tp_group,
use_zero=self.use_zero,
inplace=False,
device=torch.device("cuda"),
device=get_current_device(),
)
if self.pp_size == 1:
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
v = v.cuda()
v = v.to(get_current_device())
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)