Merge pull request #5310 from hpcaitech/feature/npu

Feature/npu
Frank Lee
2024-01-29 13:49:39 +08:00
committed by GitHub
271 changed files with 3567 additions and 8915 deletions


@@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup
from torch.nn import Parameter
from torch.optim import Optimizer
+from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import OptimizerWrapper
@@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
)
-from colossalai.utils import disposable, get_current_device, is_ddp_ignored
+from colossalai.utils import disposable, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
@@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper):
grad_chunk.l2_norm = None # clear l2 norm
-comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
+comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
for group, part_norm in group_to_norm.items():
comm_buffer.fill_(part_norm)
dist.all_reduce(comm_buffer, group=group)
@@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper):
continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
-self.chunk_manager.move_chunk(chunk32, get_current_device())
+self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device())
# stores grad now
-self.chunk_manager.move_chunk(chunk16, get_current_device())
-self.module.set_chunk_grad_device(chunk16, get_current_device())
+self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device())
+self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device())
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
for group in self.param_groups:
@@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper):
state = self.optim.state[fake_param]
for k, v in state.items():
if isinstance(v, torch.Tensor):
-state[k] = v.to(get_current_device())
+state[k] = v.to(get_accelerator().get_current_device())
def _register_states_(self):
for group in self.optim.param_groups:
@@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper):
self,
param_id: int,
state_names: list,
-device: torch.device = get_current_device(),
+device: torch.device = get_accelerator().get_current_device(),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""