[gemini] update the gpt example (#2527)

This commit is contained in:
HELSON
2023-01-30 17:58:05 +08:00
committed by GitHub
parent ecbad93b65
commit 66dfcf5281
4 changed files with 75 additions and 98 deletions

View File

@@ -32,16 +32,19 @@ def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Opt
>>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
>>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
"""
-    setattr(model, "_colo_zero_stage", zero_stage)
     assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"
     if gemini_config is None:
         gemini_config = dict()
     if zero_stage in [1, 2]:
-        return model
+        wrapped_model = model
     else:
-        return GeminiDDP(model, **gemini_config)
+        wrapped_model = GeminiDDP(model, **gemini_config)
+    setattr(wrapped_model, "_colo_zero_stage", zero_stage)
+    return wrapped_model
def zero_optim_wrapper(model: nn.Module,