Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-17 15:11:20 +00:00)
@@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import (
     BF16MixedPrecisionMixin,
     FP16MixedPrecisionMixin,
@@ -22,9 +22,6 @@ from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
-# from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
-
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
 
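
Note: every hunk below applies the same substitution; helpers that used to be imported from colossalai.utils.device are now reached through the accelerator singleton imported above. A minimal sketch of the replacement pattern, using only calls that appear in this diff (variable names are illustrative):

    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()            # backend-specific accelerator object (CUDA or NPU here)
    device = accelerator.get_current_device()  # replaces colossalai.utils.device.get_current_device()
    stream = accelerator.Stream()              # replaces device_utils.Stream()
    accelerator.current_stream()               # replaces device_utils.current_stream()
    accelerator.synchronize()                  # replaces device_utils.synchronize()
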
@@ -183,7 +180,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # initialize communication stream for
         # communication-computation overlapping
         if self._overlap_communication:
-            self._comm_stream = device_utils.Stream()
+            self._comm_stream = get_accelerator().Stream()
 
         # reduction hook is only used if overlapping communication
         # or stage 2 is used
@@ -217,7 +214,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         return len(self._working_param_groups)
 
     def _sanity_checks(self):
-        assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
+        assert get_accelerator().name in ["cuda", "npu"], "device is required"
         for param_group in self.optim.param_groups:
             group_params = param_group["params"]
             for param in group_params:
@@ -228,7 +225,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _create_master_param_current_rank(self, param_list):
         # split each param evenly by world size
         params_current_rank = []
-        device = "cpu" if self._cpu_offload else get_current_device()
+        device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()
 
         for param in param_list:
             padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
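
Note: the padding_size expression above rounds each parameter up to a multiple of the data-parallel world size so that the flattened tensor splits into equal shards, one per rank. A small worked example (the numbers are illustrative, not from the source):

    world_size = 4
    numel = 10                                                      # elements in one parameter
    padding_size = (world_size - numel % world_size) % world_size   # -> 2
    # 10 + 2 = 12 elements split evenly into 4 shards of 3 elements,
    # so every rank owns a same-sized slice of the flattened parameter.
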
@@ -340,11 +337,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if len(moe_grad_list) > 0:
                     moe_flat_grads.record_stream(stream)
                 # waiting for ops in the default stream finishing
-                stream.wait_stream(device_utils.current_stream())
+                stream.wait_stream(get_accelerator().current_stream())
             else:
-                stream = device_utils.current_stream()
+                stream = get_accelerator().current_stream()
 
-            with device_utils.stream(stream):
+            with get_accelerator().stream(stream):
                 group_id = self._bucket_store.current_group_id
 
                 if self.moe_extra_dp_pg is None:
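
Note: the hunk above keeps the communication-computation overlap unchanged and only swaps the stream API: the reduction runs on a side stream that first waits on the default stream, and tensors handed to it are pinned with record_stream so the caching allocator does not reuse their memory early. A sketch of the same pattern written against plain torch.cuda for illustration (it assumes an initialized process group and a CUDA device; the diff itself goes through get_accelerator() instead):

    import torch
    import torch.distributed as dist

    comm_stream = torch.cuda.Stream()

    def reduce_on_side_stream(flat_grads: torch.Tensor, group=None):
        # Keep the allocator from recycling flat_grads while the side stream still uses it.
        flat_grads.record_stream(comm_stream)
        # Do not start reducing before the default stream has produced the gradients.
        comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(comm_stream):
            dist.all_reduce(flat_grads, group=group)
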
@@ -486,7 +483,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         # clear reduced grads
         if self._overlap_communication:
-            device_utils.synchronize()
+            get_accelerator().synchronize()
 
         self.zero_grad()
 
@@ -505,7 +502,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         # clear reduced grads
         if self._overlap_communication:
-            device_utils.synchronize()
+            get_accelerator().synchronize()
 
         self.zero_grad()
 
@@ -621,7 +618,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
         # update working partition updated by the current rank
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
@@ -661,7 +658,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         norm_type = float(norm_type)
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
 
@@ -673,7 +672,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             # Sum across all model parallel GPUs.
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
             )
             torch.distributed.all_reduce(
                 total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
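
Note: the two hunks above only change where the device comes from; the norm computation itself is the usual two-step reduction: for the infinity norm the per-rank maxima are combined with a MAX all-reduce, while for a finite p-norm each rank sums |g|^p locally and the sums are combined with a SUM all-reduce before the p-th root is taken. A self-contained sketch of that logic (assuming an initialized process group; this is not the optimizer's actual method):

    import torch
    import torch.distributed as dist

    def global_grad_norm(gradients, norm_type=2.0, group=None):
        device = gradients[0].device
        if norm_type == float("inf"):
            # Combine per-rank maxima with a MAX all-reduce.
            local_max = max(g.data.abs().max().item() for g in gradients)
            total = torch.tensor([local_max], device=device, dtype=torch.float)
            dist.all_reduce(total, op=dist.ReduceOp.MAX, group=group)
            return total.item()
        # Sum |g|^p locally, SUM-all-reduce, then take the p-th root.
        exponentiated = sum(g.data.float().norm(norm_type) ** norm_type for g in gradients)
        total = torch.tensor([float(exponentiated)], device=device, dtype=torch.float)
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=group)
        return total.item() ** (1.0 / norm_type)
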
@@ -765,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             Dict: the pytorch form state_dict
         """
         zero_state = dict()
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
@@ -827,7 +826,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         ret_block = dict()
         ret_block_size = 0
 
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]
         for param_idx, states in local_states.items():
             current_block_size = 0