[zero] refactor ShardedParamV2 for convenience (#742)

This commit is contained in:
HELSON
2022-04-13 14:54:26 +08:00
committed by GitHub
parent 340e59f968
commit 22c4b88d56
16 changed files with 98 additions and 61 deletions

View File

@@ -215,7 +215,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
assert hasattr(param, 'colo_attr')
if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
param.colo_attr.remove_torch_payload()
param.colo_attr.set_data_none()
del self.param_list
@@ -252,11 +252,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if param.grad is not None:
param.grad = param.grad.to(target_device)
param.colo_attr = ShardedParamV2(param, rm_torch_payload=False)
param.colo_attr = ShardedParamV2(param, set_data_none=False)
if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.sharded_data_tensor.payload # set param.data to payload
param.data = param.colo_attr.data_payload # set param.data to payload
# mark whether the param is replicated
param.colo_attr.is_replicated = self.is_replicated

View File

@@ -260,7 +260,7 @@ class ShardedModelV2(nn.Module):
if not p.colo_attr.param_is_sharded:
tensor_list.append(p.colo_attr.sharded_data_tensor)
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
p.colo_attr.remove_torch_payload()
p.colo_attr.set_data_none()
self.shard_strategy.shard(tensor_list, self.process_group)
# 4. set all parameters' grad to None
@@ -357,8 +357,8 @@ class ShardedModelV2(nn.Module):
assert param.colo_attr.saved_grad.is_null(
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.saved_grad.reset_payload(grad)
param.colo_attr.sharded_data_tensor.reset_payload(grad) # release the memory of param
param.colo_attr.reset_grad_payload(grad)
param.colo_attr.reset_grad_payload(grad) # release the memory of param
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
@@ -367,9 +367,9 @@ class ShardedModelV2(nn.Module):
fp32_grad = cast_tensor_to_fp32(grad)
if param.colo_attr.saved_grad.is_null():
param.colo_attr.saved_grad.reset_payload(fp32_grad)
param.colo_attr.reset_grad_payload(fp32_grad)
else:
param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
# keep saved_grad in HOLD state
param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
@@ -377,11 +377,11 @@ class ShardedModelV2(nn.Module):
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
for p in self.sharded_params:
p.data = p.colo_attr.sharded_data_tensor.payload
p.data = p.colo_attr.data_payload
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
for p in self.sharded_params:
p.colo_attr.remove_torch_payload()
p.colo_attr.set_data_none()
return gathered_state_dict
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):

View File

@@ -14,6 +14,6 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
if shard_flag:
sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
if shard_flag:
sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])

View File

@@ -266,8 +266,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if shard_flag:
# we always shard replicated paramters
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device)))
if shard_flag:
# In this branch, there's no need to shard param
# So we gather here
@@ -296,10 +295,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# If we change p.grad directly
# it may raise error because of different shape/dtype/device of p.data and p.grad
# We just set p.data = p.colo_attr.saved_grad.payload here
p.data = p.colo_attr.saved_grad.payload
p.grad = p.colo_attr.saved_grad.payload
p.data = p.colo_attr.grad_payload
p.grad = p.colo_attr.grad_payload
# Set p.data to empty tensor, in case of memory leaking
p.colo_attr.remove_torch_payload()
p.colo_attr.set_data_none()
def _point_param_fp16_to_master_param(self):
# assign master param pointers to p.data.
@@ -325,9 +324,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.data = self.master_params[p].payload
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.remove_torch_payload()
p.colo_attr.reset_data_payload(
colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.set_data_none()
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
# We gather full fp16 param here

View File

@@ -10,10 +10,20 @@ from typing import List
# empty tensor is expected to raise error when get used
FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')
EMPTY_TENSOR_DICT = {}
def get_empty_tensor(device: torch.device, dtype: torch.dtype):
key = (device, dtype)
if key not in EMPTY_TENSOR_DICT:
EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
return EMPTY_TENSOR_DICT[key]
class ShardedParamV2(object):
def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
# This attribute must be initialized in ShardedModel
@@ -25,24 +35,47 @@ class ShardedParamV2(object):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# So we can not empty the .data at this time
self.param = param
if rm_torch_payload:
self.remove_torch_payload()
if set_data_none:
self.set_data_none()
def get_payload_tensors(self) -> List[StatefulTensor]:
"""returns stateful tensors kept by this class.
"""
return [self._sharded_data_tensor]
def remove_torch_payload(self):
self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)
def set_data_none(self):
self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)
def set_grad_none(self):
self.saved_grad.set_null()
@property
def sharded_data_tensor(self):
return self._sharded_data_tensor
@property
def data_payload(self):
return self.sharded_data_tensor.payload
@property
def grad_payload(self):
assert not self.saved_grad.is_null()
return self.saved_grad.payload
@property
def param_is_sharded(self):
return self._sharded_data_tensor.is_sharded
return self.sharded_data_tensor.is_sharded
def reset_data_payload(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.sharded_data_tensor.reset_payload(tensor)
self.set_data_none()
def reset_grad_payload(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.saved_grad.reset_payload(tensor)
def get_memory_usage(self) -> Tuple[int, int]:
"""
@@ -63,11 +96,11 @@ class ShardedParamV2(object):
cpu_mem_use += t_cpu
address_set = set()
_update_mem_use(self.sharded_data_tensor.payload)
address_set.add(self.sharded_data_tensor.payload.data_ptr())
_update_mem_use(self.data_payload)
address_set.add(self.data_payload.data_ptr())
if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
_update_mem_use(self.saved_grad.payload)
_update_mem_use(self.grad_payload)
address_set.add(self.saved_grad.data_ptr())
if self.param.data is not None and self.param.data.data_ptr() not in address_set:

View File

@@ -9,6 +9,7 @@ class ShardedTensor(StatefulTensor):
r"""
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
"""
assert tensor.requires_grad is False
super().__init__(tensor, state)
# kept the shape, numel and dtype of the init tensor.
@@ -17,6 +18,11 @@ class ShardedTensor(StatefulTensor):
self._origin_dtype = tensor.dtype
self._is_sharded = False
@property
def dtype(self) -> torch.dtype:
assert self._payload.dtype == self._origin_dtype
return self._payload.dtype
@property
def origin_numel(self) -> int:
return self._origin_numel

View File

@@ -19,11 +19,11 @@ class StatefulTensor(object):
https://arxiv.org/abs/2108.05818
"""
def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
self._state = state
self._payload = tensor
if self._state == TensorState.FREE:
assert self._payload is None, f"payload has to None if {self._state}"
assert self._payload is None, f"payload has to None if state is {self._state}"
def data_ptr(self):
if self._payload is None:
@@ -50,13 +50,13 @@ class StatefulTensor(object):
self._payload = None
@property
def payload(self) -> int:
def payload(self) -> Optional[torch.Tensor]:
return self._payload
def copy_payload(self, tensor) -> int:
def copy_payload(self, tensor) -> None:
self._payload.view(-1).copy_(tensor.view(-1))
def reset_payload(self, tensor) -> int:
def reset_payload(self, tensor) -> None:
del self._payload
self._payload = tensor
self.trans_state(TensorState.HOLD)
@@ -67,15 +67,14 @@ class StatefulTensor(object):
@property
def dtype(self) -> torch.dtype:
assert self._payload.dtype == self._origin_dtype
return self._origin_dtype
return self._payload.dtype
@property
def shape(self):
return self._payload.shape
def to(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
def to_(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
@property
def shape(self):
return self._payload.shape

View File

@@ -60,7 +60,7 @@ class ZeroHook(BaseOpHook):
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
param.data = param.colo_attr.sharded_data_tensor.payload
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
def post_fwd_exec(self, module: torch.nn.Module, *args):
@@ -79,7 +79,7 @@ class ZeroHook(BaseOpHook):
# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()
param.colo_attr.set_data_none()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
@@ -105,7 +105,7 @@ class ZeroHook(BaseOpHook):
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
param.data = param.colo_attr.sharded_data_tensor.payload
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
def post_bwd_exec(self, module: torch.nn.Module, input):
@@ -124,7 +124,7 @@ class ZeroHook(BaseOpHook):
# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()
param.colo_attr.set_data_none()
def pre_iter(self):
pass