hpcaitech/ColossalAI (https://github.com/hpcaitech/ColossalAI.git)

Commit: [zero] refactor ShardedParamV2 for convenience (#742)
@@ -215,7 +215,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             assert hasattr(param, 'colo_attr')
             if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
                 dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
-                param.colo_attr.remove_torch_payload()
+                param.colo_attr.set_data_none()

         del self.param_list

@@ -252,11 +252,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         if param.grad is not None:
             param.grad = param.grad.to(target_device)

-        param.colo_attr = ShardedParamV2(param, rm_torch_payload=False)
+        param.colo_attr = ShardedParamV2(param, set_data_none=False)

         if self.shard_param:
             self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-            param.data = param.colo_attr.sharded_data_tensor.payload  # set param.data to payload
+            param.data = param.colo_attr.data_payload  # set param.data to payload

         # mark whether the param is replicated
         param.colo_attr.is_replicated = self.is_replicated
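The hunk above re-points param.data at the shard through the new data_payload property instead of reaching into sharded_data_tensor.payload directly. Below is a minimal standalone sketch of that shard-and-repoint pattern in plain PyTorch; shard_param_ and its arguments are illustrative, not ColossalAI API.

import torch
import torch.nn as nn


def shard_param_(param: nn.Parameter, rank: int, world_size: int) -> None:
    """Keep only this rank's chunk of the parameter and point param.data at it."""
    flat = param.data.view(-1)
    shard = flat.chunk(world_size)[rank].clone()   # this rank's shard (the "payload")
    param.data = shard                             # param.data now holds just the payload


p = nn.Parameter(torch.randn(4, 4))
shard_param_(p, rank=0, world_size=4)
print(p.data.shape)   # torch.Size([4])

The real code delegates the chunking to shard_strategy.shard and keeps the shard inside a ShardedTensor; only the final `param.data = ...` assignment is mirrored here.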
@@ -260,7 +260,7 @@ class ShardedModelV2(nn.Module):
             if not p.colo_attr.param_is_sharded:
                 tensor_list.append(p.colo_attr.sharded_data_tensor)
                 p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                p.colo_attr.remove_torch_payload()
+                p.colo_attr.set_data_none()
         self.shard_strategy.shard(tensor_list, self.process_group)

         # 4. set all parameters' grad to None
@@ -357,8 +357,8 @@ class ShardedModelV2(nn.Module):
             assert param.colo_attr.saved_grad.is_null(
             ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'

-            param.colo_attr.saved_grad.reset_payload(grad)
-            param.colo_attr.sharded_data_tensor.reset_payload(grad)  # release the memory of param
+            param.colo_attr.reset_grad_payload(grad)
+            param.colo_attr.reset_grad_payload(grad)  # release the memory of param

             if param.colo_attr.is_replicated:
                 param.colo_attr.sharded_data_tensor.is_sharded = True
@@ -367,9 +367,9 @@ class ShardedModelV2(nn.Module):
             fp32_grad = cast_tensor_to_fp32(grad)

             if param.colo_attr.saved_grad.is_null():
-                param.colo_attr.saved_grad.reset_payload(fp32_grad)
+                param.colo_attr.reset_grad_payload(fp32_grad)
             else:
-                param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
+                param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))

         # keep saved_grad in HOLD state
         param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
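The reduction hunks above swap direct saved_grad.payload access for the new reset_grad_payload / grad_payload helpers, but the accumulate-or-initialize logic is unchanged. A minimal standalone sketch of that logic with plain tensors (GradBuffer and reduce_grad are hypothetical stand-ins, not ColossalAI classes):

import torch
from typing import Optional


class GradBuffer:
    """Minimal stand-in for the saved_grad stateful tensor: holds an fp32 gradient payload."""

    def __init__(self) -> None:
        self.payload: Optional[torch.Tensor] = None

    def is_null(self) -> bool:
        return self.payload is None

    def reset_payload(self, tensor: torch.Tensor) -> None:
        self.payload = tensor


def reduce_grad(buf: GradBuffer, grad: torch.Tensor) -> None:
    fp32_grad = grad.float()                               # analogue of cast_tensor_to_fp32
    if buf.is_null():
        buf.reset_payload(fp32_grad)                       # first micro-batch: take ownership
    else:
        buf.payload.add_(fp32_grad.view_as(buf.payload))   # later micro-batches: accumulate in place


buf = GradBuffer()
reduce_grad(buf, torch.ones(4, dtype=torch.float16))
reduce_grad(buf, torch.ones(4, dtype=torch.float16))
print(buf.payload)   # tensor([2., 2., 2., 2.])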
@@ -377,11 +377,11 @@ class ShardedModelV2(nn.Module):
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
         for p in self.sharded_params:
-            p.data = p.colo_attr.sharded_data_tensor.payload
+            p.data = p.colo_attr.data_payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
         self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
         for p in self.sharded_params:
-            p.colo_attr.remove_torch_payload()
+            p.colo_attr.set_data_none()
         return gathered_state_dict

     def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
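state_dict() here follows a borrow-and-return pattern: gather the shards, point each p.data at the full payload, snapshot the module state dict, then re-shard and drop p.data again via set_data_none. A standalone sketch of the same pattern in plain PyTorch (gathered_state_dict, full_weights and placeholder are illustrative names):

import copy
import torch
import torch.nn as nn


def gathered_state_dict(module: nn.Module, full_weights: dict, placeholder: torch.Tensor) -> dict:
    """Temporarily point each param at its full tensor, snapshot state_dict, then restore placeholders."""
    for name, p in module.named_parameters():
        p.data = full_weights[name]            # analogue of p.data = p.colo_attr.data_payload
    state = copy.deepcopy(module.state_dict())
    for name, p in module.named_parameters():
        p.data = placeholder                   # analogue of p.colo_attr.set_data_none()
    return state


m = nn.Linear(2, 2, bias=False)
full = {name: p.data.clone() for name, p in m.named_parameters()}
empty = torch.empty(0)
sd = gathered_state_dict(m, full, empty)
print(sd['weight'].shape, m.weight.data.numel())   # torch.Size([2, 2]) 0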
@@ -14,6 +14,6 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
         shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
         if shard_flag:
             sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
-        param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
+        param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
        if shard_flag:
             sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
@@ -266,8 +266,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if shard_flag:
                     # we always shard replicated paramters
                     self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = StatefulTensor(
-                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
+                self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device)))
                 if shard_flag:
                     # In this branch, there's no need to shard param
                     # So we gather here
@@ -296,10 +295,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
                 # We just set p.data = p.colo_attr.saved_grad.payload here
-                p.data = p.colo_attr.saved_grad.payload
-                p.grad = p.colo_attr.saved_grad.payload
+                p.data = p.colo_attr.grad_payload
+                p.grad = p.colo_attr.grad_payload
                 # Set p.data to empty tensor, in case of memory leaking
-                p.colo_attr.remove_torch_payload()
+                p.colo_attr.set_data_none()

     def _point_param_fp16_to_master_param(self):
         # assign master param pointers to p.data.
@@ -325,9 +324,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                 p.data = self.master_params[p].payload
-                p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
-                p.colo_attr.remove_torch_payload()
+                p.colo_attr.reset_data_payload(
+                    colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
+                p.colo_attr.set_data_none()

                 if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
                     # We gather full fp16 param here
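In the hunk above the fp32 master copy is cast to fp16, detached and cloned onto the shard's device, and handed to reset_data_payload, which also nulls p.data in one call (the old code needed reset_payload plus remove_torch_payload). A standalone sketch of just the cast-and-clone step (write_back_master is a hypothetical helper, plain PyTorch):

import torch


def write_back_master(master_fp32: torch.Tensor, shard_device: torch.device) -> torch.Tensor:
    """Clone the updated fp32 master into a detached fp16 tensor on the shard's device."""
    return master_fp32.half().detach().clone().to(shard_device)


master = torch.full((4,), 0.1, dtype=torch.float32)   # as if just updated by the optimizer step
new_payload = write_back_master(master, torch.device('cpu'))
print(new_payload.dtype)   # torch.float16

The added .detach() matches the new reset_data_payload assertion that the payload must not require grad.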
@@ -10,10 +10,20 @@ from typing import List
 # empty tensor is expected to raise error when get used
 FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')

+EMPTY_TENSOR_DICT = {}
+
+
+def get_empty_tensor(device: torch.device, dtype: torch.dtype):
+    key = (device, dtype)
+    if key not in EMPTY_TENSOR_DICT:
+        EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
+
+    return EMPTY_TENSOR_DICT[key]
+

 class ShardedParamV2(object):

-    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
+    def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
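The added get_empty_tensor caches one zero-element placeholder per (device, dtype) pair, so set_data_none can drop a parameter's storage without allocating a fresh tensor each time. The snippet below is the same logic, lightly adapted so it runs standalone (the device argument of the legacy BoolTensor constructor is dropped):

import torch

# empty tensor is expected to raise an error when it gets used
FAKE_EMPTY_TENSOR = torch.BoolTensor([])

EMPTY_TENSOR_DICT = {}


def get_empty_tensor(device: torch.device, dtype: torch.dtype):
    key = (device, dtype)
    if key not in EMPTY_TENSOR_DICT:
        EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
    return EMPTY_TENSOR_DICT[key]


t1 = get_empty_tensor(torch.device('cpu'), torch.float16)
t2 = get_empty_tensor(torch.device('cpu'), torch.float16)
print(t1.numel(), t1.dtype, t1 is t2)   # 0 torch.float16 True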
@@ -25,24 +35,47 @@ class ShardedParamV2(object):
         # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
         # So we can not empty the .data at this time
         self.param = param
-        if rm_torch_payload:
-            self.remove_torch_payload()
+        if set_data_none:
+            self.set_data_none()

     def get_payload_tensors(self) -> List[StatefulTensor]:
         """returns stateful tensors kept by this class.
         """
         return [self._sharded_data_tensor]

-    def remove_torch_payload(self):
-        self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)
+    def set_data_none(self):
+        self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)
+
+    def set_grad_none(self):
+        self.saved_grad.set_null()

     @property
     def sharded_data_tensor(self):
         return self._sharded_data_tensor

     @property
+    def data_payload(self):
+        return self.sharded_data_tensor.payload
+
+    @property
+    def grad_payload(self):
+        assert not self.saved_grad.is_null()
+        return self.saved_grad.payload
+
+    @property
     def param_is_sharded(self):
-        return self._sharded_data_tensor.is_sharded
+        return self.sharded_data_tensor.is_sharded
+
+    def reset_data_payload(self, tensor: torch.Tensor):
+        assert type(tensor) is torch.Tensor
+        assert tensor.requires_grad is False
+        self.sharded_data_tensor.reset_payload(tensor)
+        self.set_data_none()
+
+    def reset_grad_payload(self, tensor: torch.Tensor):
+        assert type(tensor) is torch.Tensor
+        assert tensor.requires_grad is False
+        self.saved_grad.reset_payload(tensor)

     def get_memory_usage(self) -> Tuple[int, int]:
         """
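Taken together, the refactored surface is: data_payload / grad_payload for reads, reset_data_payload / reset_grad_payload for writes, and set_data_none / set_grad_none for releasing storage. A hedged usage sketch based only on the signatures visible in this diff; the import path is assumed for the 0.1.x-era package layout and may differ:

import torch
# assumed import path for this era of the codebase; adjust if the module has moved
from colossalai.zero.sharded_param import ShardedParamV2

param = torch.nn.Parameter(torch.randn(4, 4))
attr = ShardedParamV2(param, set_data_none=False)

print(attr.data_payload.shape)            # payload of the sharded data tensor, still (4, 4)
attr.reset_grad_payload(torch.zeros(4, 4))
print(attr.grad_payload.shape)            # saved_grad payload; grad_payload asserts it is not null
attr.set_data_none()                      # param.data now points at a cached empty tensor
print(param.data.numel())                 # 0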
@@ -63,11 +96,11 @@ class ShardedParamV2(object):
             cpu_mem_use += t_cpu

         address_set = set()
-        _update_mem_use(self.sharded_data_tensor.payload)
-        address_set.add(self.sharded_data_tensor.payload.data_ptr())
+        _update_mem_use(self.data_payload)
+        address_set.add(self.data_payload.data_ptr())

         if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
-            _update_mem_use(self.saved_grad.payload)
+            _update_mem_use(self.grad_payload)
             address_set.add(self.saved_grad.data_ptr())

         if self.param.data is not None and self.param.data.data_ptr() not in address_set:
@@ -9,6 +9,7 @@ class ShardedTensor(StatefulTensor):
         r"""
         A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
+        assert tensor.requires_grad is False
         super().__init__(tensor, state)

         # kept the shape, numel and dtype of the init tensor.
@@ -17,6 +18,11 @@ class ShardedTensor(StatefulTensor):
         self._origin_dtype = tensor.dtype
         self._is_sharded = False

+    @property
+    def dtype(self) -> torch.dtype:
+        assert self._payload.dtype == self._origin_dtype
+        return self._payload.dtype
+
     @property
     def origin_numel(self) -> int:
         return self._origin_numel
@@ -19,11 +19,11 @@ class StatefulTensor(object):
     https://arxiv.org/abs/2108.05818
     """

-    def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
+    def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
         self._state = state
         self._payload = tensor
         if self._state == TensorState.FREE:
-            assert self._payload is None, f"payload has to None if {self._state}"
+            assert self._payload is None, f"payload has to None if state is {self._state}"

     def data_ptr(self):
         if self._payload is None:
@@ -50,13 +50,13 @@ class StatefulTensor(object):
         self._payload = None

     @property
-    def payload(self) -> int:
+    def payload(self) -> Optional[torch.Tensor]:
         return self._payload

-    def copy_payload(self, tensor) -> int:
+    def copy_payload(self, tensor) -> None:
         self._payload.view(-1).copy_(tensor.view(-1))

-    def reset_payload(self, tensor) -> int:
+    def reset_payload(self, tensor) -> None:
         del self._payload
         self._payload = tensor
         self.trans_state(TensorState.HOLD)
@@ -67,15 +67,14 @@ class StatefulTensor(object):

     @property
     def dtype(self) -> torch.dtype:
-        assert self._payload.dtype == self._origin_dtype
-        return self._origin_dtype
+        return self._payload.dtype
+
+    @property
+    def shape(self):
+        return self._payload.shape

     def to(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")

     def to_(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
-
-    @property
-    def shape(self):
-        return self._payload.shape
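The StatefulTensor hunks are mostly annotation fixes (payload is Optional[torch.Tensor], the mutators return None) plus a dtype property that no longer depends on _origin_dtype. A toy stand-in with the corrected signatures, runnable on its own (MiniStatefulTensor is illustrative, not the real class, which has more states and checks):

import torch
from enum import Enum
from typing import Optional


class TensorState(Enum):
    FREE = 0
    HOLD = 1


class MiniStatefulTensor:
    """Toy stand-in for StatefulTensor: a payload plus a coarse lifecycle state."""

    def __init__(self, tensor: Optional[torch.Tensor], state: TensorState = TensorState.HOLD) -> None:
        self._state = state
        self._payload = tensor
        if self._state == TensorState.FREE:
            assert self._payload is None, f"payload has to be None if state is {self._state}"

    @property
    def payload(self) -> Optional[torch.Tensor]:
        return self._payload

    def reset_payload(self, tensor: torch.Tensor) -> None:
        del self._payload
        self._payload = tensor
        self._state = TensorState.HOLD

    @property
    def dtype(self) -> torch.dtype:
        return self._payload.dtype


t = MiniStatefulTensor(None, TensorState.FREE)
t.reset_payload(torch.zeros(2, dtype=torch.float16))
print(t.dtype, t.payload.shape)   # torch.float16 torch.Size([2])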
@@ -60,7 +60,7 @@ class ZeroHook(BaseOpHook):
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.data = param.colo_attr.sharded_data_tensor.payload
+            param.data = param.colo_attr.data_payload
             assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

     def post_fwd_exec(self, module: torch.nn.Module, *args):
@@ -79,7 +79,7 @@ class ZeroHook(BaseOpHook):

         # remove torch payload
         for param in module.parameters(recurse=False):
-            param.colo_attr.remove_torch_payload()
+            param.colo_attr.set_data_none()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):

@@ -105,7 +105,7 @@ class ZeroHook(BaseOpHook):
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.data = param.colo_attr.sharded_data_tensor.payload
+            param.data = param.colo_attr.data_payload
             assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

     def post_bwd_exec(self, module: torch.nn.Module, input):
@@ -124,7 +124,7 @@ class ZeroHook(BaseOpHook):

         # remove torch payload
         for param in module.parameters(recurse=False):
-            param.colo_attr.remove_torch_payload()
+            param.colo_attr.set_data_none()

     def pre_iter(self):
         pass
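The ZeroHook changes are the same rename seen elsewhere: param.data is pointed at data_payload before the op and released with set_data_none afterwards. A standalone sketch of that swap using ordinary PyTorch forward hooks (PAYLOADS, pre_fwd and post_fwd are illustrative; the real hook also gathers shards and records memory stats):

import torch
import torch.nn as nn

PAYLOADS = {}            # param -> real (gathered) tensor
EMPTY = torch.empty(0)   # placeholder, analogue of the cached empty tensor


def pre_fwd(module: nn.Module, inputs):
    # analogue of ZeroHook.pre_fwd_exec: point param.data at the payload
    for p in module.parameters(recurse=False):
        p.data = PAYLOADS[p]


def post_fwd(module: nn.Module, inputs, output):
    # analogue of ZeroHook.post_fwd_exec: release the full tensor again
    for p in module.parameters(recurse=False):
        p.data = EMPTY


layer = nn.Linear(3, 3, bias=False)
for p in layer.parameters():
    PAYLOADS[p] = p.data.clone()
    p.data = EMPTY                      # parameters start "empty"

layer.register_forward_pre_hook(pre_fwd)
layer.register_forward_hook(post_fwd)

out = layer(torch.randn(1, 3))
print(out.shape, layer.weight.data.numel())   # torch.Size([1, 3]) 0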