Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
@@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup
 from torch.nn import Parameter
 from torch.optim import Optimizer
 
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
 from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import OptimizerWrapper
@@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
-from colossalai.utils import disposable, get_current_device, is_ddp_ignored
+from colossalai.utils import disposable, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager
 from .gemini_ddp import GeminiDDP
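The import hunks above swap colossalai.utils.get_current_device for the backend-agnostic accelerator API. A minimal sketch of how the new call is used, assuming a CUDA build (on other backends get_accelerator() resolves to the matching accelerator object):

import torch
from colossalai.accelerator import get_accelerator

# get_accelerator() returns the accelerator singleton for the active backend;
# get_current_device() yields a torch.device bound to the local rank,
# e.g. a CUDA device on a GPU build.
device = get_accelerator().get_current_device()
buf = torch.zeros(1, device=device)  # allocated on the active accelerator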
@@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper):
 
                 grad_chunk.l2_norm = None  # clear l2 norm
 
-        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
         for group, part_norm in group_to_norm.items():
             comm_buffer.fill_(part_norm)
             dist.all_reduce(comm_buffer, group=group)
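The hunk above sits in the gradient-norm reduction: each process group's partial norm is copied into a one-element buffer and all-reduced across ranks. A minimal sketch of that pattern; group_to_norm is assumed to map each ProcessGroup to its partial squared norm as in the surrounding code, and the final accumulation line is an assumption about what follows the hunk:

import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator

# Reuse one single-element buffer instead of allocating a tensor per group.
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
norm_sqr = 0.0
for group, part_norm in group_to_norm.items():  # group_to_norm: from the surrounding method
    comm_buffer.fill_(part_norm)
    dist.all_reduce(comm_buffer, group=group)  # sum partial norms across ranks
    norm_sqr += comm_buffer.item()  # assumed accumulation into the global norm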
@@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper):
                     continue
 
                 if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
-                    self.chunk_manager.move_chunk(chunk32, get_current_device())
+                    self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device())
                     # stores grad now
-                    self.chunk_manager.move_chunk(chunk16, get_current_device())
-                    self.module.set_chunk_grad_device(chunk16, get_current_device())
+                    self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device())
+                    self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device())
                     fp32_params_used_cuda_margin_mem += chunk32.payload_mem
 
         for group in self.param_groups:
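For context, the hunk above gates chunk placement on spare accelerator memory: while the used margin stays below the available margin, the fp32 master chunk and its fp16 twin (which now stores gradients) are both moved onto the device. A hedged sketch of that gating logic; paired_chunks, chunk_manager, and available_margin_mem are hypothetical stand-ins for the surrounding class state:

from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()
used_margin_mem = 0.0
for chunk16, chunk32 in paired_chunks:  # hypothetical (fp16 chunk, fp32 chunk) pairs
    if used_margin_mem + chunk32.payload_mem < available_margin_mem:  # available_margin_mem: hypothetical budget
        chunk_manager.move_chunk(chunk32, device)
        chunk_manager.move_chunk(chunk16, device)  # fp16 chunk stores grads now
        used_margin_mem += chunk32.payload_mem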
@@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 state = self.optim.state[fake_param]
                 for k, v in state.items():
                     if isinstance(v, torch.Tensor):
-                        state[k] = v.to(get_current_device())
+                        state[k] = v.to(get_accelerator().get_current_device())
 
     def _register_states_(self):
         for group in self.optim.param_groups:
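The hunk above relocates cached optimizer state (e.g. Adam's momentum and variance buffers) onto the accelerator. A minimal standalone sketch of the same pattern on a plain torch optimizer, assuming a ColossalAI install:

import torch
from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()
p = torch.nn.Parameter(torch.randn(4))
optim = torch.optim.Adam([p])
p.grad = torch.zeros_like(p)
optim.step()  # populates optim.state with exp_avg / exp_avg_sq tensors

for param_state in optim.state.values():
    for k, v in param_state.items():
        if isinstance(v, torch.Tensor):
            param_state[k] = v.to(device)  # move each state buffer to the accelerator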
@@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self,
         param_id: int,
         state_names: list,
-        device: torch.device = get_current_device(),
+        device: torch.device = get_accelerator().get_current_device(),
         dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """
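One Python subtlety in the signature above: a default argument is evaluated once, when the def runs (here, at import), not on every call, so the default device is fixed at class-definition time. A small, ColossalAI-independent illustration of that behavior:

import torch

def pick_device() -> torch.device:
    print("pick_device evaluated")
    return torch.device("cpu")

def collect_states(device: torch.device = pick_device()):  # prints once, at def time
    return device

collect_states()  # no print here: the default computed at def time is reused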