Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix gemini unwrap
* fix bugs
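The bullets above describe rewriting the unwrapping logic for the various wrapped model types. As a rough illustration of the general pattern only (not ColossalAI's actual implementation, whose `ModelWrapper.unwrap()` handles FSDP, zero/gemini, and hybrid wrappers as well), the sketch below peels a plain PyTorch DDP wrapper off a model before its `state_dict` is accessed.

# Minimal sketch of the unwrapping pattern, assuming a torch DDP-style wrapper
# that stores the user module in `.module`; ColossalAI's own unwrap logic
# covers more wrapper types and may differ in detail.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap_model(model: nn.Module) -> nn.Module:
    """Strip DDP wrappers until the underlying user module is reached."""
    while isinstance(model, DDP):
        model = model.module  # DDP exposes the wrapped module via `.module`
    return model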
@@ -61,9 +61,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
 
     if plugin_type == "gemini":
-        check_state_dict_equal(
-            model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False
-        )
+        check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
     else:
         check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
     dist.barrier()
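For the gemini plugin, the hunk above replaces the `model.unwrap().state_dict(...)` comparison with a direct `model.state_dict(...)` call, relying on the rewritten unwrapping logic inside the wrapper. Conceptually, the test helper `check_state_dict_equal` verifies that two state dicts agree key by key; the snippet below is a minimal sketch of that idea in plain PyTorch, with a hypothetical helper name, not the helper's real implementation.

# Minimal sketch of comparing two state dicts key by key, illustrating what the
# test checks conceptually; the real `check_state_dict_equal` in ColossalAI may
# additionally handle sharding, device placement, and rank-0-only collection.
import torch


def state_dicts_equal(d1: dict, d2: dict) -> bool:
    """Return True if both state dicts have the same keys and equal tensors."""
    if d1.keys() != d2.keys():
        return False
    return all(torch.equal(d1[k].cpu(), d2[k].cpu()) for k in d1)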