[zero] polish sharded param name (#484)

* [zero] polish sharded param name

* polish code

* polish

* polish code

* polish

* polish

* polish

Author: Jiarui Fang
Date: 2022-03-22 14:36:16 +08:00
Committed by: GitHub
Parent: 9caa8b6481
Commit: b334822163
12 changed files with 55 additions and 222 deletions


@@ -34,13 +34,13 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.data.device != self.computing_device:
-                param.col_attr.data.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
-            param.data = param.col_attr.data.payload
+            if param.col_attr.sharded_data_tensor.device != self.computing_device:
+                param.col_attr.sharded_data_tensor.to(self.computing_device)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            param.data = param.col_attr.sharded_data_tensor.payload

         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
@@ -49,7 +49,7 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters():
             param.col_attr.remove_torch_payload()
@@ -58,13 +58,13 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.data.device != self.computing_device:
-                param.col_attr.data.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
-            param.data = param.col_attr.data.payload
+            if param.col_attr.sharded_data_tensor.device != self.computing_device:
+                param.col_attr.sharded_data_tensor.to(self.computing_device)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            param.data = param.col_attr.sharded_data_tensor.payload
             # Store local accumulated grad shard
             if param.grad is not None:
                 if param.col_attr.bwd_count == 0:
@@ -75,7 +75,7 @@ class ZeroHook(BaseOpHook):
                 else:
                     # We have stored local accumulated grad
                     # The grad here must be locally computed full grad in this backward pass
-                    assert param.grad.shape == param.col_attr.data.origin_shape
+                    assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
             param.col_attr.bwd_count += 1
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
@@ -84,7 +84,7 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters():
             param.col_attr.remove_torch_payload()
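
Note: the hunks above only rename the attribute (col_attr.data -> col_attr.sharded_data_tensor); the hook logic itself is unchanged. As a rough illustration of the access pattern after the rename, a minimal, self-contained Python sketch follows. _ToyShardedTensor, _ToyColAttr, and pre_forward are hypothetical stand-ins invented for this example, not ColossalAI APIs; only the sharded_data_tensor name, its payload/device/origin_shape members, and the move-then-repoint pattern are taken from the diff.

import torch


class _ToyShardedTensor:
    """Illustrative stand-in for the wrapper exposed as col_attr.sharded_data_tensor."""

    def __init__(self, tensor: torch.Tensor):
        self.payload = tensor              # storage backing the (possibly sharded) parameter
        self.origin_shape = tensor.shape   # full shape, used by the grad-shape assert above

    @property
    def device(self) -> torch.device:
        return self.payload.device

    def to(self, device: torch.device) -> None:
        # In-place move of the payload, mirroring the hook's .to(computing_device) call.
        self.payload = self.payload.to(device)


class _ToyColAttr:
    """Hypothetical col_attr container carrying only the renamed attribute."""

    def __init__(self, tensor: torch.Tensor):
        self.sharded_data_tensor = _ToyShardedTensor(tensor)


def pre_forward(params, computing_device: torch.device) -> None:
    # Same pattern as the pre-forward hook above: move each shard to the
    # computing device, then re-point param.data at the shard's payload.
    for param in params:
        st = param.col_attr.sharded_data_tensor
        if st.device != computing_device:
            st.to(computing_device)
        param.data = st.payload


# Usage: attach a toy col_attr to a parameter and run the pre-forward step.
p = torch.nn.Parameter(torch.randn(4, 4))
p.col_attr = _ToyColAttr(p.data)
pre_forward([p], torch.device("cpu"))
assert p.data.data_ptr() == p.col_attr.sharded_data_tensor.payload.data_ptr()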