[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)

* fix master param sync for hybrid plugin

* rewrite unwrap for ddp/fsdp

* rewrite unwrap for zero/gemini

* rewrite unwrap for hybrid plugin

* fix geemini unwrap

* fix bugs
This commit is contained in:
Baizhou Zhang
2023-09-20 18:29:37 +08:00
committed by GitHub
parent 7b9b86441f
commit c0a033700c
14 changed files with 141 additions and 171 deletions

View File

@@ -186,10 +186,6 @@ class GeminiDDP(ModelWrapper):
for p in params_to_ignore:
p._ddp_to_ignore = True
def unwrap(self):
# as save/load state dict is overwrited, only return self
return self
def _get_non_persistent_buffers_set(
self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):