Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-02 17:46:42 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,3 +1,3 @@
 from .low_level_optim import LowLevelZeroOptimizer

-__all__ = ['LowLevelZeroOptimizer']
+__all__ = ["LowLevelZeroOptimizer"]
@@ -44,8 +44,8 @@ def shuffle_by_round_robin(tensor_list, num_partitions):
     for partition_id in range(partitions_count):
         partition_tensors = partitions[partition_id]
         for item in partition_tensors:
-            tensor_index_mapping[item['index']] = len(new_tensor_list)
-            new_tensor_list.append(item['tensor'])
+            tensor_index_mapping[item["index"]] = len(new_tensor_list)
+            new_tensor_list.append(item["tensor"])

     return new_tensor_list, tensor_index_mapping

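For orientation, a minimal sketch of the round-robin reshuffle this hunk reformats; the helper below is illustrative only and assumes each partition item carries an "index" and a "tensor" entry, as in the diff:

```python
# Illustrative sketch: rebuild a flat tensor list from round-robin partitions
# and remember where each original tensor ended up.
def shuffle_by_round_robin_sketch(partitions: dict):
    new_tensor_list = []
    tensor_index_mapping = {}
    for partition_id in range(len(partitions)):
        for item in partitions[partition_id]:
            # map the tensor's original position to its new position
            tensor_index_mapping[item["index"]] = len(new_tensor_list)
            new_tensor_list.append(item["tensor"])
    return new_tensor_list, tensor_index_mapping
```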
@@ -107,11 +107,13 @@ def split_by_dtype(tensor_list):
     return buckets


-def reduce_tensor_dp_group(tensor: torch.Tensor,
-                           dtype: Optional[torch.dtype] = None,
-                           dst_local_rank: Optional[int] = None,
-                           dst_global_rank: Optional[int] = None,
-                           group: Optional[dist.ProcessGroup] = None):
+def reduce_tensor_dp_group(
+    tensor: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+    dst_local_rank: Optional[int] = None,
+    dst_global_rank: Optional[int] = None,
+    group: Optional[dist.ProcessGroup] = None,
+):
     """
     Reduce the tensor in the data parallel process group

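As a hedged sketch of what a data-parallel reduce of this shape does (the helper name, defaults, and averaging step below are assumptions, not the function's actual body):

```python
from typing import Optional

import torch
import torch.distributed as dist


def reduce_tensor_dp_sketch(
    tensor: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    dst_global_rank: int = 0,
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    # Optionally cast to a communication dtype before reducing.
    buf = tensor.to(dtype) if dtype is not None and tensor.dtype != dtype else tensor
    # Sum onto the destination rank over the data-parallel group, then average there.
    dist.reduce(buf, dst=dst_global_rank, op=dist.ReduceOp.SUM, group=group)
    if dist.get_rank() == dst_global_rank:
        buf.div_(dist.get_world_size(group))
        if buf is not tensor:
            tensor.copy_(buf)
    return tensor
```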
@@ -173,7 +175,7 @@ def has_inf_or_nan(tensor):
             raise
         return True
     else:
-        if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
+        if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum:
             return True
         return False

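The unchanged logic relies on NaN being the only float that compares unequal to itself; a tiny standalone illustration (hypothetical helper name):

```python
def has_inf_or_nan_sketch(tensor_sum: float) -> bool:
    # inf/-inf are caught by the equality checks; NaN is detected by `x != x`,
    # so no math.isnan / math.isinf calls are needed.
    return tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum


assert has_inf_or_nan_sketch(float("nan")) and has_inf_or_nan_sketch(float("inf"))
assert not has_inf_or_nan_sketch(1.0)
```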
@@ -184,8 +186,7 @@ def release_param_grad(tensor_list):


 def calculate_global_norm_from_list(norm_list):
-    """ Compute total from a list of norms
-    """
+    """Compute total from a list of norms"""
    total_norm = 0.0
    for norm in norm_list:
        total_norm += norm**2.0
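A global gradient norm is conventionally the square root of the summed squared per-group norms, which is what the accumulation above feeds into; a minimal sketch:

```python
import math


def global_norm_sketch(norm_list) -> float:
    # Global L2 norm over per-group norms: sqrt(sum of squares).
    return math.sqrt(sum(norm**2.0 for norm in norm_list))


# e.g. two groups with norms 3 and 4 give a global norm of 5
assert global_norm_sketch([3.0, 4.0]) == 5.0
```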
@@ -221,7 +222,7 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
         total_norm = 0.0
         for g in gradients:
             param_norm = g.data.double().norm(2)
-            total_norm += param_norm.item()**2
+            total_norm += param_norm.item() ** 2

         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
@@ -230,9 +231,9 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
         if tp_group is not None:
             dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)

-        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
+        total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)

-        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
+        if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
             total_norm = -1

         return total_norm
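A condensed, assumption-laden sketch of the norm computation touched here: square the per-gradient norms locally, all-reduce the sum over the data- and tensor-parallel groups, then take the norm-type root:

```python
import torch
import torch.distributed as dist


def compute_grad_norm_sketch(gradients, dp_group=None, tp_group=None, norm_type: float = 2.0) -> float:
    # Local contribution: sum of squared per-gradient norms.
    total = sum(float(g.data.double().norm(2)) ** 2 for g in gradients)
    total_cuda = torch.tensor([total], device="cuda", dtype=torch.float32)
    # Sum the partial norms across data-parallel and tensor-parallel ranks.
    dist.all_reduce(total_cuda, op=dist.ReduceOp.SUM, group=dp_group)
    if tp_group is not None:
        dist.all_reduce(total_cuda, op=dist.ReduceOp.SUM, group=tp_group)
    total_norm = total_cuda[0].item() ** (1.0 / norm_type)
    # Signal overflow with -1, mirroring the inf/NaN guard in the diff.
    if total_norm in (float("inf"), -float("inf")) or total_norm != total_norm:
        total_norm = -1
    return total_norm
```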
@@ -3,4 +3,4 @@ from .gradient_store import GradientStore
 from .parameter_store import ParameterStore
 from .tensor_bucket import TensorBucket

-__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
+__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"]
@@ -3,7 +3,6 @@ from torch.distributed import ProcessGroup


 class BaseStore:
-
     def __init__(self, torch_pg: ProcessGroup):
         self._world_size = dist.get_world_size(group=torch_pg)
         self._local_rank = dist.get_rank(group=torch_pg)
@@ -9,7 +9,6 @@ from .base_store import BaseStore


 class BucketStore(BaseStore):
-
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)

@@ -38,8 +37,7 @@ class BucketStore(BaseStore):
         return self._num_elements_in_bucket

     def reset_num_elements_in_bucket(self):
-        """Set the number of elements in bucket to zero.
-        """
+        """Set the number of elements in bucket to zero."""

         self._num_elements_in_bucket = 0

@@ -54,7 +52,7 @@ class BucketStore(BaseStore):

         self._param_list.append(param)
         self._padding_size.append(padding_size)
-        self._num_elements_in_bucket += (param.numel() + padding_size)
+        self._num_elements_in_bucket += param.numel() + padding_size
         self.current_group_id = group_id

         # number of tensors in current bucket
@@ -119,8 +117,7 @@ class BucketStore(BaseStore):
         return self.grad_to_param_mapping[id(grad)]

     def reset(self):
-        """Reset the bucket storage after reduction, only release the tensors have been reduced
-        """
+        """Reset the bucket storage after reduction, only release the tensors have been reduced"""
         cur_offset = self.offset_list.pop(0)
         self._param_list = self._param_list[cur_offset:]
         self._padding_size = self._padding_size[cur_offset:]
@@ -1,13 +1,11 @@
 from typing import List

 from torch import Tensor
 from torch._utils import _flatten_dense_tensors

 from .base_store import BaseStore


 class GradientStore(BaseStore):
-
     def __init__(self, *args, partition_grad: bool = False):
         super().__init__(*args)
         """
@@ -5,7 +5,6 @@ from .base_store import BaseStore


 class ParameterStore(BaseStore):
-
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)

@@ -2,7 +2,6 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


 class TensorBucket:
-
     def __init__(self, size):
         self._max_size = size
         self._current_size = 0
@@ -26,8 +25,7 @@ class TensorBucket:
         tensor_size = tensor.numel()

         if not allow_oversize and self.will_exceed_max_size(tensor_size):
-            msg = f"The param bucket max size {self._max_size} is exceeded" \
-                + f"by tensor (size {tensor_size})"
+            msg = f"The param bucket max size {self._max_size} is exceeded" + f"by tensor (size {tensor_size})"
             raise RuntimeError(msg)

         self._bucket.append(tensor)
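The size guard above protects a bucket that is typically flushed by flattening its tensors into one buffer so a single collective covers them; a hedged sketch (not the class's actual flush method):

```python
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def flush_bucket_sketch(bucket: list, group=None):
    # One flat buffer -> one all-reduce instead of one collective per tensor.
    flat = _flatten_dense_tensors(bucket)
    dist.all_reduce(flat, group=group)
    # Scatter the reduced values back into the original tensors.
    for tensor, reduced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
        tensor.copy_(reduced)
```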
@@ -17,6 +17,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
+
 # from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils.cuda import get_current_device

@@ -32,19 +33,21 @@ from .bookkeeping import BucketStore, GradientStore, ParameterStore


 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
-
-    def __init__(self,
-                 num_working_param_groups: int,
-                 grad_store: GradientStore,
-                 initial_scale: float = 2**16,
-                 min_scale: float = 1,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 max_scale: float = 2**32) -> None:
-        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
-                         max_scale)
+    def __init__(
+        self,
+        num_working_param_groups: int,
+        grad_store: GradientStore,
+        initial_scale: float = 2**16,
+        min_scale: float = 1,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        max_scale: float = 2**32,
+    ) -> None:
+        super().__init__(
+            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+        )
         self.num_working_param_groups = num_working_param_groups
         self.grad_store = grad_store

@@ -57,32 +60,31 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):


 class LowLevelZeroOptimizer(OptimizerWrapper):
-    """Optimizer used for ZeRO-1 and ZeRO-2.
-    """
+    """Optimizer used for ZeRO-1 and ZeRO-2."""

     def __init__(
-            self,
-            optimizer: Optimizer,
-            initial_scale: int = 2**16,  # grad scaler config
-            min_scale: int = 1,
-            growth_factor: float = 2.,
-            backoff_factor: float = .5,
-            growth_interval: int = 2000,
-            hysteresis: int = 2,
-            max_scale: int = 2**24,
-            clip_grad_norm: float = 0.0,  # grad clipping
-            verbose: bool = False,
-            reduce_bucket_size: int = 1024 * 1024,  # communication
-            communication_dtype: Optional[torch.dtype] = None,
-            overlap_communication: bool = False,
-            partition_grad: bool = False,  # stage 2 flag
-            cpu_offload: bool = False,  # cpu offload
-            dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
-            tp_process_group: Optional[ProcessGroup] = None,  # if using tp
-            forced_dtype: Optional[torch.dtype] = None):
-
+        self,
+        optimizer: Optimizer,
+        initial_scale: int = 2**16,  # grad scaler config
+        min_scale: int = 1,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+        hysteresis: int = 2,
+        max_scale: int = 2**24,
+        clip_grad_norm: float = 0.0,  # grad clipping
+        verbose: bool = False,
+        reduce_bucket_size: int = 1024 * 1024,  # communication
+        communication_dtype: Optional[torch.dtype] = None,
+        overlap_communication: bool = False,
+        partition_grad: bool = False,  # stage 2 flag
+        cpu_offload: bool = False,  # cpu offload
+        dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
+        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        forced_dtype: Optional[torch.dtype] = None,
+    ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
-        self._dtype = self.optim.param_groups[0]['params'][0].dtype
+        self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._logger = get_dist_logger()
         self._verbose = verbose

@@ -115,7 +117,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         if forced_dtype:
             for group in self.optim.param_groups:
-                group_params = group['params']
+                group_params = group["params"]
                 for param in group_params:
                     param.data = param.data.to(forced_dtype)
             self._dtype = forced_dtype
@@ -134,7 +136,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
             group_params = list()
-            for param in param_group['params']:
+            for param in param_group["params"]:
                 if param.requires_grad:
                     group_params.append(param)

@@ -148,7 +150,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # need to replace the params in the `params` field in the optimizer
             # so that when the optimizer calls step(), it only updates the tensors
             # managed by this data parallel rank
-            param_group['params'] = master_param_current_rank
+            param_group["params"] = master_param_current_rank

         # intialize communication stream for
         # communication-compuation overlapping
@@ -164,15 +166,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # initialize mixed precision mixin
         self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
         if self._dtype is torch.float16:
-            self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
-                                                                             self._grad_store,
-                                                                             initial_scale=initial_scale,
-                                                                             min_scale=min_scale,
-                                                                             growth_factor=growth_factor,
-                                                                             backoff_factor=backoff_factor,
-                                                                             growth_interval=growth_interval,
-                                                                             hysteresis=hysteresis,
-                                                                             max_scale=max_scale)
+            self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(
+                self.num_param_groups,
+                self._grad_store,
+                initial_scale=initial_scale,
+                min_scale=min_scale,
+                growth_factor=growth_factor,
+                backoff_factor=backoff_factor,
+                growth_interval=growth_interval,
+                hysteresis=hysteresis,
+                max_scale=max_scale,
+            )
         elif self._dtype is torch.bfloat16:
             self.mixed_precision_mixin = BF16MixedPrecisionMixin()

@@ -185,17 +189,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         return len(self._working_param_groups)

     def _sanity_checks(self):
-        assert torch.cuda.is_available(), 'CUDA is required'
+        assert torch.cuda.is_available(), "CUDA is required"
         for param_group in self.optim.param_groups:
-            group_params = param_group['params']
+            group_params = param_group["params"]
             for param in group_params:
-                assert param.dtype == self._dtype, \
-                    f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
+                assert (
+                    param.dtype == self._dtype
+                ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

     def _create_master_param_current_rank(self, param_list):
         # split each param evenly by world size
         params_current_rank = []
-        device = 'cpu' if self._cpu_offload else get_current_device()
+        device = "cpu" if self._cpu_offload else get_current_device()

         for param in param_list:
             padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
@@ -275,8 +280,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     sync_tensor(flat_grads_per_rank[rank], grad_list)
                     for grad in grad_list:
                         param_id = self._bucket_store.get_param_id_of_grad(grad)
-                        if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id,
-                                                                                      param_id)) < self._world_size:
+                        if (
+                            len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id))
+                            < self._world_size
+                        ):
                             self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
                         else:
                             self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
@@ -307,8 +314,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # if full, will reduce the grads already in the bucket
         # or got a grad of param from another group
         # after reduction, the bucket will be empty
-        if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
-                group_id != self._bucket_store.current_group_id:
+        if (
+            self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size
+            or group_id != self._bucket_store.current_group_id
+        ):
             self._run_reduction()

         padding_size = self._param_store.get_param_padding_size(param)
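The wrapped condition flushes the bucket when the incoming parameter would overflow it or belongs to a different parameter group; a toy illustration with hypothetical names:

```python
class ReduceBucketSketch:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.num_elements = 0
        self.current_group_id = None

    def should_flush(self, incoming_size: int, group_id: int) -> bool:
        # Flush before adding if the bucket would overflow or the group changes.
        return self.num_elements + incoming_size > self.capacity or (
            self.current_group_id is not None and group_id != self.current_group_id
        )
```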
@@ -319,8 +328,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     ################################

     def backward(self, loss, retain_graph=False):
-        assert not(self._partition_grads and not self.require_grad_sync), \
-            "ZeRO2(partition_grads) and no_sync are not compatible"
+        assert not (
+            self._partition_grads and not self.require_grad_sync
+        ), "ZeRO2(partition_grads) and no_sync are not compatible"

         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
@@ -339,8 +349,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self.zero_grad()

     def backward_by_grad(self, tensor, grad):
-        assert not(self._partition_grads and not self.require_grad_sync), \
-            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+        assert not (
+            self._partition_grads and not self.require_grad_sync
+        ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

         if self.mixed_precision_mixin is not None:
             grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
@@ -380,14 +391,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     ####################

     def step(self, closure=None):
-        assert closure is None, 'closure is not supported by step()'
+        assert closure is None, "closure is not supported by step()"
         if not self.require_grad_sync:
             return

         if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
             self._grad_store.reset_all_gradients()
             if self._verbose:
-                self._logger.info(f'Found overflow. Skip step')
+                self._logger.info(f"Found overflow. Skip step")
             self.zero_grad()
             return

@@ -428,7 +439,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             self._grad_store.reset_grads_by_group_id(group_id)

             # update the params in the optimizer
-            self.optim.param_groups[group_id]['params'] = real_master_params[group_id]
+            self.optim.param_groups[group_id]["params"] = real_master_params[group_id]

         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
@@ -445,16 +456,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # update working partition updated by the current rank
         dtype = real_working_params[0][0].dtype
         for group_id in range(self.num_param_groups):
-            master_working_param = self.optim.param_groups[group_id]['params']
+            master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
                 all_splited_param = [
                     torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
                 ]
                 dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
-                working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param))
+                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))

-            self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
+            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

     #############################
     # Mixed Precision Utilities #
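The loop above reassembles each working parameter from its per-rank shards after the optimizer step; a minimal sketch of that gather-and-copy pattern (helper name is hypothetical):

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten


def restore_working_param_sketch(working_param: torch.Tensor, local_shard: torch.Tensor, group=None):
    world_size = dist.get_world_size(group)
    # Gather every rank's (padded) shard of the flat master parameter.
    shards = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(shards, local_shard, group=group)
    # Concatenate, drop the padding, and copy back into the working tensor.
    full = flatten(shards)[: working_param.numel()]
    working_param.data.copy_(full.reshape_as(working_param))
```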
@@ -466,14 +477,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self.mixed_precision_mixin is not None:
             div_scale = self.mixed_precision_mixin.get_grad_div_scale()

-        if self._clip_grad_norm > 0.:
+        if self._clip_grad_norm > 0.0:
             # norm is in fact norm*scale
             clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
             if clip > 1:
                 div_scale = clip * div_scale

         for grad in grad_groups_flat:
-            grad.data.mul_(1. / div_scale)
+            grad.data.mul_(1.0 / div_scale)

     ############################
     # Gradient Synchronization #
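The reformatted block folds loss-scale removal and gradient clipping into a single division factor; roughly, under the same naming assumptions:

```python
def unscale_and_clip_factor_sketch(total_norm: float, loss_scale: float, max_norm: float) -> float:
    # total_norm still carries the loss scale, so the clip ratio is computed
    # against the unscaled norm; the returned divisor removes the scale and,
    # if needed, also shrinks gradients to respect max_norm.
    div_scale = loss_scale
    if max_norm > 0.0:
        clip = ((total_norm / div_scale) + 1e-6) / max_norm
        if clip > 1:
            div_scale = clip * div_scale
    return div_scale


# gradients would then be rescaled in place: grad.mul_(1.0 / div_scale)
```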
@@ -518,18 +529,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         def pack_group(group):
             nonlocal start_index
-            packed = {k: v for k, v in group.items() if k != 'params'}
+            packed = {k: v for k, v in group.items() if k != "params"}
             param_mappings.update(
-                {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
-            packed['params'] = [param_mappings[id(p)] for p in group['params']]
-            start_index += len(packed['params'])
+                {id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings}
+            )
+            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
+            start_index += len(packed["params"])
             return packed

         param_groups = [pack_group(g) for g in self.optim.param_groups]
         # Remap state to use order indices as keys
         packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}

-        return {'state': packed_state, 'param_groups': param_groups}
+        return {"state": packed_state, "param_groups": param_groups}

     def state_dict(self) -> Dict:
         """Return a state_dict same with DDP
@@ -541,14 +553,15 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for param, state in self.optim.state.items():
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != 'step':
+                if isinstance(v, torch.Tensor) and k != "step":
                     working_param = self._param_store.master_to_working_param[id(param)]
                     gather_tensor = [
-                        torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
+                        torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
                     ]
                     dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
-                    param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
-                        working_param).cpu()
+                    param_state = (
+                        torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
+                    )
                     zero_state[param][k] = param_state

         states_dict = self._pack_state(zero_state)
@@ -562,16 +575,16 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             state_dict (dict): A pytorch form state_dict
         """
         zero_state_dict = copy.deepcopy(state_dict)
-        for param_idx, state in zero_state_dict['state'].items():
+        for param_idx, state in zero_state_dict["state"].items():
             for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != 'step':
+                if isinstance(v, torch.Tensor) and k != "step":
                     padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
                     with torch.no_grad():
                         v = v.flatten()
                         if padding_size > 0:
                             v = torch.nn.functional.pad(v, [0, padding_size])
                         v_list = v.split(v.numel() // self._world_size)
-                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
+                        zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()

         self.optim.load_state_dict(zero_state_dict)

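Loading reverses the gather performed by state_dict(): each full optimizer-state tensor is padded to a multiple of the world size, split evenly, and only the local shard is kept. A standalone sketch of that step:

```python
import torch


def shard_state_tensor_sketch(v: torch.Tensor, world_size: int, local_rank: int) -> torch.Tensor:
    # Pad so the flat tensor divides evenly across ranks, then keep this rank's slice.
    padding_size = (world_size - v.numel() % world_size) % world_size
    with torch.no_grad():
        v = v.flatten()
        if padding_size > 0:
            v = torch.nn.functional.pad(v, [0, padding_size])
        v_list = v.split(v.numel() // world_size)
        return v_list[local_rank].detach().clone()
```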
@@ -588,7 +601,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         ret_block = dict()
         ret_block_size = 0

-        local_states = self.optim.state_dict()['state']
+        local_states = self.optim.state_dict()["state"]
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)
@@ -601,11 +614,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             working_param = self._param_store.master_to_working_param[id(master_param)]

             for k, v in states.items():
-                if isinstance(v, torch.Tensor) and k != 'step':
-                    state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
+                if isinstance(v, torch.Tensor) and k != "step":
+                    state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)]
                     dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
-                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
-                        working_param).cpu()
+                    state_tensor = (
+                        torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
+                    )
                     current_block_size += state_tensor.numel()
                     current_block[k] = state_tensor
