Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[llama] polish training script and fix optim ckpt (#5368)
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils import get_current_device
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
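The new import drives the device changes in the hunks below: instead of hard-coding CUDA, the checkpoint path asks ColossalAI for the current accelerator. Below is a minimal sketch of that pattern, assuming only that get_current_device() returns the device of the active accelerator (e.g. cuda:<local_rank> on GPUs); the helper move_state_to_accelerator is hypothetical and only illustrative.

# Minimal sketch of the device-agnostic pattern adopted by this commit.
# Assumption: colossalai.utils.get_current_device() resolves the device of the
# active accelerator; move_state_to_accelerator is a made-up helper name.
import torch

from colossalai.utils import get_current_device


def move_state_to_accelerator(state: dict) -> dict:
    """Move optimizer-state tensors onto the current accelerator.

    Hard-coding torch.device("cuda") breaks on non-CUDA accelerators, which is
    why the hunks below swap it for get_current_device().
    """
    device = get_current_device()
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in state.items()}

On a CUDA-only cluster the two spellings behave the same; the payoff is that the same checkpoint code can run unchanged on other backends ColossalAI supports.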
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 tp_group=self.tp_group,
                 use_zero=self.use_zero,
                 inplace=False,
-                device=torch.device("cuda"),
+                device=get_current_device(),
             )
 
         if self.pp_size == 1:
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             if isinstance(v, torch.Tensor) and k != "step":
                 # First gather Zero shards.
                 if use_zero:
-                    v = v.cuda()
+                    v = v.to(get_current_device())
                     gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
                     dist.all_gather(gather_tensor, v, group=dp_group)
                     v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
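For context, this last hunk sits in the step that gathers ZeRO-sharded optimizer states before they are written to the checkpoint: each data-parallel rank holds a flat slice of the state, so the slices are all-gathered, concatenated, trimmed, and reshaped back to the parameter's shape. A self-contained sketch of that step, reusing the variable names from the diff (the function name gather_zero_shard is made up for illustration):

# Standalone sketch of the ZeRO-shard gather shown in the hunk above.
# Assumes torch.distributed is already initialized; gather_zero_shard is a
# hypothetical name, but its body mirrors the lines in the diff.
import torch
import torch.distributed as dist

from colossalai.utils import get_current_device


def gather_zero_shard(
    v: torch.Tensor, param: torch.Tensor, dp_group: dist.ProcessGroup, dp_size: int
) -> torch.Tensor:
    # Move the local shard onto the current accelerator (the commit replaces
    # the old v.cuda() with this device-agnostic call).
    v = v.to(get_current_device())
    # Each data-parallel rank contributes its flat slice of the optimizer state.
    gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
    dist.all_gather(gather_tensor, v, group=dp_group)
    # Concatenate the slices, drop any trailing padding, and restore the
    # parameter's original shape.
    return torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)

The [: param.numel()] slice is there because the flattened state may be padded so it divides evenly across ranks, so the gathered buffer can be longer than the parameter itself.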