Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 21:22:49 +00:00)
[gemini] update the gpt example (#2527)
@@ -32,16 +32,19 @@ def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Opt
     >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
     >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
     """
-    setattr(model, "_colo_zero_stage", zero_stage)
     assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"
 
     if gemini_config is None:
         gemini_config = dict()
 
     if zero_stage in [1, 2]:
-        return model
+        wrapped_model = model
     else:
-        return GeminiDDP(model, **gemini_config)
+        wrapped_model = GeminiDDP(model, **gemini_config)
+
+    setattr(wrapped_model, "_colo_zero_stage", zero_stage)
+
+    return wrapped_model
 
 
 def zero_optim_wrapper(model: nn.Module,
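The commit moves the stage bookkeeping after the wrapping step: `_colo_zero_stage` is now set on the object that is actually returned (the raw module for stages 1 and 2, the `GeminiDDP` wrapper for stage 3), so the attribute is no longer lost when the model gets wrapped. Below is a minimal usage sketch built from the docstring example in the diff above; the `colossalai.zero` import path, the launch call, and the stand-in `nn.Linear` model are assumptions for illustration, not part of this commit.

import colossalai
import torch
import torch.nn as nn
from colossalai.zero import zero_model_wrapper  # import path assumed

# Gemini/ZeRO-3 needs the distributed environment initialized first.
colossalai.launch_from_torch(config={})

# Stand-in model; any nn.Module works here.
model = nn.Linear(1024, 1024).cuda()

# Gemini config taken from the docstring example in the diff above.
config_dict = dict(device=torch.cuda.current_device(),
                   hidden_dim=1024,
                   placement_policy='auto')

# Stage 3 wraps the model in GeminiDDP; stages 1 and 2 return it unwrapped.
model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)

# After this commit, the stage marker lives on the returned object.
assert model._colo_zero_stage == 3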