diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 0284c92f3..c27d7a577 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -215,7 +215,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             assert hasattr(param, 'colo_attr')
             if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
                 dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
-                param.colo_attr.remove_torch_payload()
+                param.colo_attr.set_data_none()

         del self.param_list

@@ -252,11 +252,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            if param.grad is not None:
                param.grad = param.grad.to(target_device)

-           param.colo_attr = ShardedParamV2(param, rm_torch_payload=False)
+           param.colo_attr = ShardedParamV2(param, set_data_none=False)

            if self.shard_param:
                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
-           param.data = param.colo_attr.sharded_data_tensor.payload    # set param.data to payload
+           param.data = param.colo_attr.data_payload    # set param.data to payload

            # mark whether the param is replicated
            param.colo_attr.is_replicated = self.is_replicated
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 5c5c2c421..87564f9b8 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -260,7 +260,7 @@ class ShardedModelV2(nn.Module):
                 if not p.colo_attr.param_is_sharded:
                     tensor_list.append(p.colo_attr.sharded_data_tensor)
                     p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.colo_attr.remove_torch_payload()
+                    p.colo_attr.set_data_none()
             self.shard_strategy.shard(tensor_list, self.process_group)

         # 4. set all parameters' grad to None
@@ -357,8 +357,8 @@ class ShardedModelV2(nn.Module):
             assert param.colo_attr.saved_grad.is_null(
             ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'

-            param.colo_attr.saved_grad.reset_payload(grad)
-            param.colo_attr.sharded_data_tensor.reset_payload(grad)    # release the memory of param
+            param.colo_attr.reset_grad_payload(grad)
+            param.colo_attr.reset_grad_payload(grad)    # release the memory of param

            if param.colo_attr.is_replicated:
                param.colo_attr.sharded_data_tensor.is_sharded = True
@@ -367,9 +367,9 @@ class ShardedModelV2(nn.Module):
            fp32_grad = cast_tensor_to_fp32(grad)
            if param.colo_attr.saved_grad.is_null():
-               param.colo_attr.saved_grad.reset_payload(fp32_grad)
+               param.colo_attr.reset_grad_payload(fp32_grad)
            else:
-               param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
+               param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))

        # keep saved_grad in HOLD state
        param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
@@ -377,11 +377,11 @@ class ShardedModelV2(nn.Module):
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
         for p in self.sharded_params:
-            p.data = p.colo_attr.sharded_data_tensor.payload
+            p.data = p.colo_attr.data_payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
         self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
         for p in self.sharded_params:
-            p.colo_attr.remove_torch_payload()
+            p.colo_attr.set_data_none()
         return gathered_state_dict

     def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
diff --git a/colossalai/zero/sharded_model/utils.py b/colossalai/zero/sharded_model/utils.py
index 9777e0f63..69f5a23ac 100644
--- a/colossalai/zero/sharded_model/utils.py
+++ b/colossalai/zero/sharded_model/utils.py
@@ -14,6 +14,6 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
         shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
         if shard_flag:
             sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
-        param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
+        param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
         if shard_flag:
             sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 88464b0e1..087d2c1f5 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -266,8 +266,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                if shard_flag:
                    # we always shard replicated paramters
                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-               self.master_params[p] = StatefulTensor(
-                   cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
+               self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device)))
                if shard_flag:
                    # In this branch, there's no need to shard param
                    # So we gather here
@@ -296,10 +295,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # If we change p.grad directly
                # it may raise error because of different shape/dtype/device of p.data and p.grad
                # We just set p.data = p.colo_attr.saved_grad.payload here
-               p.data = p.colo_attr.saved_grad.payload
-               p.grad = p.colo_attr.saved_grad.payload
+               p.data = p.colo_attr.grad_payload
+               p.grad = p.colo_attr.grad_payload
                # Set p.data to empty tensor, in case of memory leaking
-               p.colo_attr.remove_torch_payload()
+               p.colo_attr.set_data_none()

    def _point_param_fp16_to_master_param(self):
        # assign master param pointers to p.data.
@@ -325,9 +324,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                p.data = self.master_params[p].payload
-               p.colo_attr.sharded_data_tensor.reset_payload(
-                   colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
-               p.colo_attr.remove_torch_payload()
+               p.colo_attr.reset_data_payload(
+                   colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
+               p.colo_attr.set_data_none()

            if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
                # We gather full fp16 param here
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 51c3d8556..87a09df44 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -10,10 +10,20 @@ from typing import List
 # empty tensor is expected to raise error when get used
 FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')

+EMPTY_TENSOR_DICT = {}
+
+
+def get_empty_tensor(device: torch.device, dtype: torch.dtype):
+    key = (device, dtype)
+    if key not in EMPTY_TENSOR_DICT:
+        EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
+
+    return EMPTY_TENSOR_DICT[key]
+

 class ShardedParamV2(object):

-    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
+    def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
@@ -25,24 +35,47 @@ class ShardedParamV2(object):
         # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
         # So we can not empty the .data at this time
         self.param = param
-        if rm_torch_payload:
-            self.remove_torch_payload()
+        if set_data_none:
+            self.set_data_none()

     def get_payload_tensors(self) -> List[StatefulTensor]:
         """returns stateful tensors kept by this class.
         """
         return [self._sharded_data_tensor]

-    def remove_torch_payload(self):
-        self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)
+    def set_data_none(self):
+        self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)
+
+    def set_grad_none(self):
+        self.saved_grad.set_null()

     @property
     def sharded_data_tensor(self):
         return self._sharded_data_tensor

+    @property
+    def data_payload(self):
+        return self.sharded_data_tensor.payload
+
+    @property
+    def grad_payload(self):
+        assert not self.saved_grad.is_null()
+        return self.saved_grad.payload
+
     @property
     def param_is_sharded(self):
-        return self._sharded_data_tensor.is_sharded
+        return self.sharded_data_tensor.is_sharded
+
+    def reset_data_payload(self, tensor: torch.Tensor):
+        assert type(tensor) is torch.Tensor
+        assert tensor.requires_grad is False
+        self.sharded_data_tensor.reset_payload(tensor)
+        self.set_data_none()
+
+    def reset_grad_payload(self, tensor: torch.Tensor):
+        assert type(tensor) is torch.Tensor
+        assert tensor.requires_grad is False
+        self.saved_grad.reset_payload(tensor)

     def get_memory_usage(self) -> Tuple[int, int]:
         """
@@ -63,11 +96,11 @@ class ShardedParamV2(object):
                 cpu_mem_use += t_cpu

         address_set = set()
-        _update_mem_use(self.sharded_data_tensor.payload)
-        address_set.add(self.sharded_data_tensor.payload.data_ptr())
+        _update_mem_use(self.data_payload)
+        address_set.add(self.data_payload.data_ptr())

         if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
-            _update_mem_use(self.saved_grad.payload)
+            _update_mem_use(self.grad_payload)
             address_set.add(self.saved_grad.data_ptr())

         if self.param.data is not None and self.param.data.data_ptr() not in address_set:
diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py
index 8e799c314..fde273320 100644
--- a/colossalai/zero/sharded_param/sharded_tensor.py
+++ b/colossalai/zero/sharded_param/sharded_tensor.py
@@ -9,6 +9,7 @@ class ShardedTensor(StatefulTensor):
         r"""
         A tensor sharded in multiple processes.
         Constructed from an existing torch.Tensor instance.
         """
+        assert tensor.requires_grad is False
         super().__init__(tensor, state)

         # kept the shape, numel and dtype of the init tensor.
@@ -17,6 +18,11 @@ class ShardedTensor(StatefulTensor):
         self._origin_dtype = tensor.dtype
         self._is_sharded = False

+    @property
+    def dtype(self) -> torch.dtype:
+        assert self._payload.dtype == self._origin_dtype
+        return self._payload.dtype
+
     @property
     def origin_numel(self) -> int:
         return self._origin_numel
diff --git a/colossalai/zero/sharded_param/tensorful_state.py b/colossalai/zero/sharded_param/tensorful_state.py
index 5bde388be..a108963e5 100644
--- a/colossalai/zero/sharded_param/tensorful_state.py
+++ b/colossalai/zero/sharded_param/tensorful_state.py
@@ -19,11 +19,11 @@ class StatefulTensor(object):
    https://arxiv.org/abs/2108.05818
    """

-    def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
+    def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
         self._state = state
         self._payload = tensor
         if self._state == TensorState.FREE:
-            assert self._payload is None, f"payload has to None if {self._state}"
+            assert self._payload is None, f"payload has to None if state is {self._state}"

     def data_ptr(self):
         if self._payload is None:
@@ -50,13 +50,13 @@ class StatefulTensor(object):
         self._payload = None

     @property
-    def payload(self) -> int:
+    def payload(self) -> Optional[torch.Tensor]:
         return self._payload

-    def copy_payload(self, tensor) -> int:
+    def copy_payload(self, tensor) -> None:
         self._payload.view(-1).copy_(tensor.view(-1))

-    def reset_payload(self, tensor) -> int:
+    def reset_payload(self, tensor) -> None:
         del self._payload
         self._payload = tensor
         self.trans_state(TensorState.HOLD)
@@ -67,15 +67,14 @@ class StatefulTensor(object):
     @property
     def dtype(self) -> torch.dtype:
-        assert self._payload.dtype == self._origin_dtype
-        return self._origin_dtype
+        return self._payload.dtype
+
+    @property
+    def shape(self):
+        return self._payload.shape

     def to(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")

     def to_(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
-
-    @property
-    def shape(self):
-        return self._payload.shape
diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/utils/zero_hook.py
index 14e502530..34d7e0d5a 100644
--- a/colossalai/zero/utils/zero_hook.py
+++ b/colossalai/zero/utils/zero_hook.py
@@ -60,7 +60,7 @@ class ZeroHook(BaseOpHook):
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.data = param.colo_attr.sharded_data_tensor.payload
+            param.data = param.colo_attr.data_payload
             assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

     def post_fwd_exec(self, module: torch.nn.Module, *args):
@@ -79,7 +79,7 @@ class ZeroHook(BaseOpHook):
         # remove torch payload
         for param in module.parameters(recurse=False):
-            param.colo_attr.remove_torch_payload()
+            param.colo_attr.set_data_none()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):

@@ -105,7 +105,7 @@ class ZeroHook(BaseOpHook):
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.data = param.colo_attr.sharded_data_tensor.payload
+            param.data = param.colo_attr.data_payload
             assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

     def post_bwd_exec(self, module: torch.nn.Module, input):
@@ -124,7 +124,7 @@ class ZeroHook(BaseOpHook):
         # remove torch payload
         for param in module.parameters(recurse=False):
-            param.colo_attr.remove_torch_payload()
+            param.colo_attr.set_data_none()

     def pre_iter(self):
         pass
diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
index 7c5f30b14..50963e641 100644
--- a/tests/test_moe/test_moe_zero_init.py
+++ b/tests/test_moe/test_moe_zero_init.py
@@ -77,10 +77,10 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
             assert param.colo_attr.is_replicated

         if param.colo_attr.param_is_sharded:
-            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-                f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+            assert param.colo_attr.data_payload.device.type == init_device.type, \
+                f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'
         else:
-            assert param.colo_attr.sharded_data_tensor.payload.device.type == 'cuda'
+            assert param.colo_attr.data_payload.device.type == 'cuda'


 def _run_dist(rank, world_size, port):
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
index af78c032c..2e3b620cf 100644
--- a/tests/test_moe/test_moe_zero_model.py
+++ b/tests/test_moe/test_moe_zero_model.py
@@ -37,7 +37,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
     # check whether parameters are identical in ddp
     for name, p in zero_model.named_parameters():
         if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
-            assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)
+            assert_equal_in_group(p.colo_attr.data_payload)

     model = MoeModel(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index 1ae25997b..cb39b8d7b 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -76,7 +76,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
     # check whether parameters are identical in ddp
     for name, p in zero_model.named_parameters():
         if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
-            assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
+            assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))

     model = MoeModel(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
@@ -100,7 +100,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
     for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
         if 'gate' in n:
             p.data = p.float()
-            p.data.copy_(zp.colo_attr.sharded_data_tensor.payload)
+            p.data.copy_(zp.colo_attr.data_payload)

     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
diff --git a/tests/test_zero/common.py b/tests/test_zero/common.py
index 993fed98e..d495cf018 100644
--- a/tests/test_zero/common.py
+++ b/tests/test_zero/common.py
@@ -94,7 +94,7 @@ def check_grads_padding(model, zero_model, loose=False):
     for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
         # zero_grad = zero_p.grad.clone().to(p.device)
         if zero_p.colo_attr.is_replicated:
-            zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
+            zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device)
             chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
             if rank >= len(chunks):
                 continue
@@ -102,7 +102,7 @@ def check_grads_padding(model, zero_model, loose=False):
             if zero_grad.size(0) > grad.size(0):
                 zero_grad = zero_grad[:grad.size(0)]
         else:
-            zero_grad = zero_p.colo_attr.saved_grad.payload
+            zero_grad = zero_p.colo_attr.grad_payload
             grad = p.grad.to(zero_grad.dtype)

         assert grad.dtype == zero_grad.dtype
@@ -127,7 +127,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
     rank = dist.get_rank()
     for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
         if zero_p.colo_attr.param_is_sharded:
-            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
+            zero_p = zero_p.colo_attr.data_payload.to(p.device).float()
             chunks = torch.flatten(p).chunk(dist.get_world_size())
             if rank >= len(chunks):
                 continue
@@ -135,7 +135,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
             if zero_p.size(0) > p.size(0):
                 zero_p = zero_p[:p.size(0)]
         else:
-            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
+            zero_p = zero_p.colo_attr.data_payload.to(p.device)

         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_found_inf.py
index fc02bb67d..45bdd6e01 100644
--- a/tests/test_zero/test_found_inf.py
+++ b/tests/test_zero/test_found_inf.py
@@ -55,7 +55,7 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
         data, label = data.cuda(), label.cuda()
         _run_step(zero_model, sharded_optim, data, label, criterion, False)
         for param in zero_model.parameters():
-            assert not has_inf_or_nan(param.colo_attr.sharded_data_tensor.payload)
+            assert not has_inf_or_nan(param.colo_attr.data_payload)


 def _run_dist(rank, world_size, port):
diff --git a/tests/test_zero/test_init_context.py b/tests/test_zero/test_init_context.py
index 34777c6b8..cbbc6a7f3 100644
--- a/tests/test_zero/test_init_context.py
+++ b/tests/test_zero/test_init_context.py
@@ -46,8 +46,8 @@ def run_model_test(init_device_type, shard_strategy_class):
             assert hasattr(param, 'colo_attr')
             assert param.colo_attr.sharded_data_tensor.dtype == torch.half
             assert param.colo_attr.sharded_data_tensor.is_sharded
-            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-                f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+            assert param.colo_attr.data_payload.device.type == init_device.type, \
+                f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'

         cuda_mem_use, _ = colo_model_mem_usage(model)
         model_data_cuda_mem_MB = cuda_mem_use / 1e6
diff --git a/tests/test_zero/test_shard_param.py b/tests/test_zero/test_shard_param.py
index e88a64c35..91c669af3 100644
--- a/tests/test_zero/test_shard_param.py
+++ b/tests/test_zero/test_shard_param.py
@@ -50,27 +50,27 @@ def _run_shard_param_v2(rank, world_size, port):
     param_ref = deepcopy(param)
     sparam = ShardedParamV2(param=param)

-    allclose(sparam.sharded_data_tensor.payload, param_ref.data)
+    allclose(sparam.data_payload, param_ref.data)

     # Test get memory usage
     sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"

-    sparam.remove_torch_payload()
+    sparam.set_data_none()
     assert (param.data.numel() == 0)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     # 4 is size of dummy tensor of param.data
     assert cpu_mem_use == 2 * 3 * 4 * 2

     sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
-    sparam.remove_torch_payload()
+    sparam.set_data_none()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0

     # append a grad to torch param
-    param.data = sparam.sharded_data_tensor.payload
+    param.data = sparam.data_payload
     param.grad = torch.randn(2, 3)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
diff --git a/tests/test_zero/test_stateful_tensor_mgr.py b/tests/test_zero/test_stateful_tensor_mgr.py
index ed76d27a7..bc8475914 100644
--- a/tests/test_zero/test_stateful_tensor_mgr.py
+++ b/tests/test_zero/test_stateful_tensor_mgr.py
@@ -34,7 +34,7 @@ def run_stm():
     colo_set_process_memory_fraction(fraction)
     model = Net()
     for p in model.parameters():
-        p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)
+        p.colo_attr = ShardedParamV2(p, set_data_none=True)
     GLOBAL_MODEL_DATA_TRACER.register_model(model)
     mem_collector = MemStatsCollector()
     stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
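
Note (reviewer sketch, not part of the patch): a minimal usage sketch of the renamed ShardedParamV2 accessors introduced above (set_data_none, data_payload, reset_data_payload, reset_grad_payload / grad_payload). It assumes a single-process CPU setup, an fp16 parameter, and the import path colossalai.zero.sharded_param; shapes and values are illustrative only.

    import torch
    from colossalai.zero.sharded_param import ShardedParamV2

    # Wrap an fp16 parameter; set_data_none=False keeps param.data alive for now.
    param = torch.nn.Parameter(torch.randn(2, 3).half())
    sparam = ShardedParamV2(param, set_data_none=False)

    # data_payload replaces sharded_data_tensor.payload as the read path.
    assert sparam.data_payload.shape == (2, 3)

    # reset_data_payload() swaps in a new shard (detached, same dtype as the
    # original) and empties param.data in one step via set_data_none().
    new_shard = torch.randn(2, 3).half().detach()
    sparam.reset_data_payload(new_shard)
    assert param.data.numel() == 0

    # Gradients go through reset_grad_payload() / grad_payload instead of
    # touching saved_grad.payload directly.
    sparam.reset_grad_payload(torch.zeros(2, 3))
    sparam.grad_payload.add_(torch.ones(2, 3))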