Merge pull request #5310 from hpcaitech/feature/npu

Feature/npu
Authored by Frank Lee on 2024-01-29 13:49:39 +08:00, committed by GitHub.
271 changed files with 3567 additions and 8915 deletions
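The hunks below migrate device handling in the low-level ZeRO optimizer from `colossalai.utils.device` to the new `colossalai.accelerator.get_accelerator()` entry point, so the same code path can run on CUDA GPUs and Ascend NPUs. A minimal sketch of the resulting call pattern is shown next; it uses only the accelerator methods that appear in this diff (`name`, `get_current_device`, `Stream`, `current_stream`, `stream`, `wait_stream`, `synchronize`) and assumes a ColossalAI build that ships `colossalai.accelerator`:

# Minimal sketch of the accelerator-based device handling introduced by this PR.
# Only methods visible in the hunks below are used; this is an illustration, not
# the full API surface.
import torch
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()
print(accelerator.name)                    # e.g. "cuda" on GPU hosts, "npu" on Ascend
device = accelerator.get_current_device()  # replaces colossalai.utils.device.get_current_device()

# Streams and synchronization also go through the accelerator object instead of device_utils:
comm_stream = accelerator.Stream()
comm_stream.wait_stream(accelerator.current_stream())
with accelerator.stream(comm_stream):
    t = torch.ones(4, device=device)
accelerator.synchronize()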


@@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import (
     BF16MixedPrecisionMixin,
     FP16MixedPrecisionMixin,
@@ -22,9 +22,6 @@ from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
-# from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -183,7 +180,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # initialize communication stream for
         # communication-computation overlapping
         if self._overlap_communication:
-            self._comm_stream = device_utils.Stream()
+            self._comm_stream = get_accelerator().Stream()
         # reduction hook is only used if overlapping communication
         # or stage 2 is used
@@ -217,7 +214,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         return len(self._working_param_groups)
     def _sanity_checks(self):
-        assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
+        assert get_accelerator().name in ["cuda", "npu"], "device is required"
         for param_group in self.optim.param_groups:
             group_params = param_group["params"]
             for param in group_params:
@@ -228,7 +225,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _create_master_param_current_rank(self, param_list):
         # split each param evenly by world size
         params_current_rank = []
-        device = "cpu" if self._cpu_offload else get_current_device()
+        device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()
         for param in param_list:
             padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
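The padding rule in the hunk above pads each parameter so its flattened storage splits into equal shards, one per rank. A worked example with illustrative numbers (not taken from the source):

# Worked example of the padding rule used in _create_master_param_current_rank.
world_size = 4
numel = 10                                             # hypothetical parameter with 10 elements
padding_size = (world_size - numel % world_size) % world_size
assert padding_size == 2                               # 10 + 2 = 12 splits into 4 shards of 3
assert (numel + padding_size) % world_size == 0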
@@ -340,11 +337,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if len(moe_grad_list) > 0:
                     moe_flat_grads.record_stream(stream)
                 # waiting for ops in the default stream finishing
-                stream.wait_stream(device_utils.current_stream())
+                stream.wait_stream(get_accelerator().current_stream())
             else:
-                stream = device_utils.current_stream()
+                stream = get_accelerator().current_stream()
-            with device_utils.stream(stream):
+            with get_accelerator().stream(stream):
                 group_id = self._bucket_store.current_group_id
                 if self.moe_extra_dp_pg is None:
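The hunk above implements communication-computation overlap: the reduction is enqueued on a side stream that first waits on the default stream, and tensors handed across streams call `record_stream` so their memory is not reused too early. A standalone sketch of that pattern, where `flat_grads`, `dp_group`, and `comm_stream` are placeholders and the `all_reduce` stands in for whatever collective the bucket actually issues:

# Sketch of the stream-overlap pattern shown above (placeholders, not the real helper).
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator

def reduce_on_side_stream(flat_grads: torch.Tensor, dp_group, overlap: bool, comm_stream):
    if overlap:
        stream = comm_stream
        # keep the gradient buffer alive while the side stream still uses it
        flat_grads.record_stream(stream)
        # do not start reducing before the default stream has produced the grads
        stream.wait_stream(get_accelerator().current_stream())
    else:
        stream = get_accelerator().current_stream()
    with get_accelerator().stream(stream):
        dist.all_reduce(flat_grads, group=dp_group)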
@@ -486,7 +483,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # clear reduced grads
         if self._overlap_communication:
-            device_utils.synchronize()
+            get_accelerator().synchronize()
         self.zero_grad()
@@ -505,7 +502,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # clear reduced grads
         if self._overlap_communication:
-            device_utils.synchronize()
+            get_accelerator().synchronize()
         self.zero_grad()
@@ -621,7 +618,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
         # update working partition updated by the current rank
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
@@ -661,7 +658,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         norm_type = float(norm_type)
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
@@ -673,7 +672,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # Sum across all model parallel GPUs.
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
             )
             torch.distributed.all_reduce(
                 total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
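The two hunks above compute a global gradient norm across data-parallel ranks: the inf-norm takes a MAX all-reduce of per-rank maxima, while a finite p-norm all-reduces the sum of |g|^p and takes the p-th root afterwards. A self-contained sketch of that logic, where `gradients` and `dp_pg` are placeholders and the helper only mirrors what the diff shows rather than reproducing the exact method:

# Sketch of the distributed gradient-norm computation shown above.
from math import inf
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator

def compute_global_norm(gradients, dp_pg, norm_type: float = 2.0) -> float:
    device = get_accelerator().get_current_device()
    if norm_type == inf:
        local_max = max(grad.data.abs().max() for grad in gradients)
        total = torch.tensor([float(local_max)], device=device, dtype=torch.float)
        dist.all_reduce(total, op=dist.ReduceOp.MAX, group=dp_pg)   # max over ranks
        return total.item()
    local_sum = sum(grad.data.double().abs().pow(norm_type).sum().item() for grad in gradients)
    total = torch.tensor([float(local_sum)], device=device, dtype=torch.float)
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=dp_pg)       # sum of |g|^p over ranks
    return total.item() ** (1.0 / norm_type)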
@@ -765,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             Dict: the pytorch form state_dict
         """
         zero_state = dict()
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
@@ -827,7 +826,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         ret_block = dict()
         ret_block_size = 0
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]
         for param_idx, states in local_states.items():
             current_block_size = 0