Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] polish sharded param name (#484)
* [zero] polish sharded param name
* polish code
* polish
* polish code
* polish
* polish
* polish
@@ -34,13 +34,13 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.data.device != self.computing_device:
-                param.col_attr.data.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
-            param.data = param.col_attr.data.payload
+            if param.col_attr.sharded_data_tensor.device != self.computing_device:
+                param.col_attr.sharded_data_tensor.to(self.computing_device)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            param.data = param.col_attr.sharded_data_tensor.payload
 
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
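For context, this pre-forward hook follows a gather-compute pattern: the shard wrappers are collected, all-gathered into full tensors, moved to the computing device, and param.data is pointed at the gathered payload. Below is a minimal, self-contained sketch of that pattern; ShardedTensorStub, the ad-hoc col_attr container, and pre_forward_gather are hypothetical stand-ins, and the collective gather itself is elided.

import torch
import torch.nn as nn

class ShardedTensorStub:
    """Hypothetical stand-in for the sharded_data_tensor wrapper: it owns
    a payload tensor that is either a shard or the gathered full tensor."""

    def __init__(self, tensor: torch.Tensor):
        self.payload = tensor
        self.origin_shape = tensor.shape  # full shape, recorded up front

    @property
    def device(self) -> torch.device:
        return self.payload.device

    def to(self, device) -> None:
        # Move the payload in place, mirroring how the hook calls .to()
        self.payload = self.payload.to(device)

def pre_forward_gather(module: nn.Module, computing_device: torch.device) -> None:
    # Collect the wrappers; a real shard strategy would all-gather here.
    tensor_list = [p.col_attr.sharded_data_tensor for p in module.parameters()]
    # shard_strategy.gather(tensor_list, process_group)  # collective elided
    # Move each payload to the computing device and expose it as param.data.
    for param in module.parameters():
        st = param.col_attr.sharded_data_tensor
        if st.device != computing_device:
            st.to(computing_device)
        param.data = st.payload

# Usage: attach stub attributes to a plain module, then run the hook logic.
model = nn.Linear(4, 2)
for p in model.parameters():
    p.col_attr = type('ColAttr', (), {})()  # ad-hoc col_attr container
    p.col_attr.sharded_data_tensor = ShardedTensorStub(p.data.clone())
pre_forward_gather(model, torch.device('cpu'))

Reassigning param.data rather than copying lets autograd run on the gathered tensor without a second allocation.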
@@ -49,7 +49,7 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters():
             param.col_attr.remove_torch_payload()
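The post-forward hook mirrors the gather: parameters are re-sharded and remove_torch_payload() drops the full tensor that backed param.data during the forward pass. A sketch of the memory-release idea, assuming (not confirmed by this diff) that the method swaps in an empty placeholder:

import torch
import torch.nn as nn

def release_full_payload(param: nn.Parameter) -> None:
    # Point param.data at an empty placeholder so the gathered full
    # tensor has no remaining reference and its storage can be freed.
    # This is a hypothetical reading of remove_torch_payload().
    param.data = torch.empty(0, dtype=param.dtype, device=param.device)

lin = nn.Linear(8, 8)
release_full_payload(lin.weight)
assert lin.weight.numel() == 0  # the 8x8 weight storage is now released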
@@ -58,13 +58,13 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.data.device != self.computing_device:
-                param.col_attr.data.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
-            param.data = param.col_attr.data.payload
+            if param.col_attr.sharded_data_tensor.device != self.computing_device:
+                param.col_attr.sharded_data_tensor.to(self.computing_device)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            param.data = param.col_attr.sharded_data_tensor.payload
             # Store local accumulated grad shard
             if param.grad is not None:
                 if param.col_attr.bwd_count == 0:
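The bwd_count == 0 branch is truncated by this hunk's context window. As a rough sketch of the bookkeeping such a counter enables, where the first backward pass stores the grad and later passes accumulate into it (saved_grad is a hypothetical buffer name, not taken from this diff):

import torch

def accumulate_grad(param: torch.nn.Parameter) -> None:
    # bwd_count == 0: first backward pass touching this param, so adopt
    # the freshly computed grad as the accumulation buffer.
    # bwd_count > 0: add the new full grad into the stored buffer.
    if param.col_attr.bwd_count == 0:
        param.col_attr.saved_grad = param.grad.data  # hypothetical name
    else:
        param.col_attr.saved_grad.add_(param.grad.data)
    param.grad = None  # the hook owns the accumulated grad from here on

In the real hook, param.col_attr.bwd_count += 1 then runs after the branch, as the next hunk shows.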
@@ -75,7 +75,7 @@ class ZeroHook(BaseOpHook):
                 else:
                     # We have stored local accumulated grad
                     # The grad here must be locally computed full grad in this backward pass
-                    assert param.grad.shape == param.col_attr.data.origin_shape
+                    assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
                 param.col_attr.bwd_count += 1
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
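The assertion encodes an invariant: because param.data was swapped to the gathered full tensor before backward, autograd produces a gradient with the full pre-shard shape, which must equal the recorded origin_shape. A tiny runnable check of the same invariant in plain PyTorch:

import torch

# The grad computed by autograd always matches the shape of the tensor
# it flowed through, i.e. the full (pre-shard) shape here.
full = torch.randn(4, 4, requires_grad=True)
origin_shape = full.shape  # recorded before any sharding would occur
full.sum().backward()
assert full.grad.shape == origin_shape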
@@ -84,7 +84,7 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters():
             param.col_attr.remove_torch_payload()