Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-25 01:40:08 +00:00
[zero] sharded model supports the reuse of fp16 shard (#495)
* sharded model supports reuse of fp16 shard
* rename variable
* polish code
* polish code
* polish code
@@ -16,7 +16,8 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           offload_config=None,
                           gradient_predivide_factor=1.0,
                           use_memory_tracer=False,
-                          shard_strategy=TensorShardStrategy())
+                          shard_strategy=TensorShardStrategy(),
+                          reuse_fp16_shard=False)
 
 _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
                               initial_scale=2**5,
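The new reuse_fp16_shard key extends the model config shared by these tests. Below is a minimal sketch (not code from this commit) of how such a config might be forwarded to ShardedModelV2; the import paths, the build_sharded_model helper, and the exact ShardedModelV2 signature at this revision are assumptions.

import torch.nn as nn

from colossalai.zero.shard_utils import TensorShardStrategy       # assumed import path
from colossalai.zero.sharded_model import ShardedModelV2          # assumed import path

_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                          offload_config=None,
                          gradient_predivide_factor=1.0,
                          use_memory_tracer=False,
                          shard_strategy=TensorShardStrategy(),
                          reuse_fp16_shard=False)


def build_sharded_model(module: nn.Module, reuse_fp16_shard: bool = False) -> ShardedModelV2:
    # Hypothetical helper: override the default flag, then unpack the config
    # as keyword arguments when wrapping the module.
    cfg = dict(_ZERO_MODEL_CONFIG, reuse_fp16_shard=reuse_fp16_shard)
    return ShardedModelV2(module, **cfg)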
@@ -116,10 +117,13 @@ def check_params_padding(model, zero_model, loose=False):
         assert allclose(p, zero_p, loose=loose)
 
 
-def check_sharded_params_padding(model, zero_model, loose=False):
+def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+        if reuse_fp16_shard:
+            zero_p = zero_p.data.to(p.device).float()
+        else:
+            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
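The rename makes the helper cover both code paths: with reuse_fp16_shard=True the updated fp16 shard is read directly from zero_p.data, otherwise it is taken from the sharded payload as before. A hedged usage sketch follows; model, zero_model, and the loose tolerance stand in for whatever the surrounding test sets up.

# Hypothetical call sites; `model` and `zero_model` are placeholders for the
# unsharded reference model and the ZeRO sharded model built by the test.
check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=True)

# Default path: compare against zero_p.col_attr.sharded_data_tensor.payload,
# matching the behaviour of the old check_sharded_params_padding.
check_sharded_model_params(model, zero_model, loose=True)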