[zero] sharded model supports the reuse of fp16 shard (#495)

* sharded model supports reusing fp16 shard

* rename variable

* polish code

* polish code

* polish code
This commit is contained in:
ver217
2022-03-23 14:59:59 +08:00
committed by GitHub
parent f24b5ed201
commit 9ec1ce6ab1
7 changed files with 62 additions and 42 deletions

View File

@@ -16,7 +16,8 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
offload_config=None,
gradient_predivide_factor=1.0,
use_memory_tracer=False,
shard_strategy=TensorShardStrategy())
shard_strategy=TensorShardStrategy(),
reuse_fp16_shard=False)
_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
initial_scale=2**5,
@@ -116,10 +117,13 @@ def check_params_padding(model, zero_model, loose=False):
assert allclose(p, zero_p, loose=loose)
def check_sharded_params_padding(model, zero_model, loose=False):
def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
if reuse_fp16_shard:
zero_p = zero_p.data.to(p.device).float()
else:
zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue