[gemini] support amp o3 for gemini (#4872)
* [gemini] support no reuse fp16 chunk
* [gemini] support no master weight for optim
* [gemini] support no master weight for gemini ddp
* [test] update gemini tests
* [test] update gemini tests
* [plugin] update gemini plugin
* [test] fix gemini checkpointio test
* [test] fix gemini checkpoint io
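The diff below only touches the tests, but it shows the knob this commit adds: a `master_weights` argument on `GeminiDDP`. A minimal sketch of passing the flag when wrapping a model, mirroring the call pattern in the diff; the import paths, the placement policy, and the toy module are my assumptions rather than part of this commit, and a distributed environment must already have been initialized (e.g. via `colossalai.launch`):

import torch.nn as nn

from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration

# Illustrative placeholders (not from the commit): any torch module and any
# placement policy the tests parameterize over.
model = nn.Linear(1024, 1024)
placement_config = {"placement_policy": "static"}

# Same chunk search call the tests use.
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)

# master_weights=False is the new option: no fp32 master copy is kept for the
# optimizer, i.e. the AMP "O3"-style mode this commit adds; True keeps the
# previous behavior.
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=False)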
@@ -27,7 +27,8 @@ def ignore_the_first_parameter(model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gathered", [True, False])
 @parameterize("model_name", ["gpt2", "bert"])
-def exam_state_dict(placement_config, keep_gathered, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -42,7 +43,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered
-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
     model.train()

     zero_dict = model.state_dict(only_rank_0=False)
@@ -57,7 +58,8 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gathered", [True, False])
 @parameterize("model_name", ["gpt2", "bert"])
-def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -72,7 +74,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered

-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)

     torch_dict = torch_model.state_dict()
     model.load_state_dict(torch_dict, strict=False)
@@ -86,7 +88,8 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):

 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["gpt2", "bert"])
-def exam_state_dict_shard(placement_config, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -95,7 +98,7 @@ def exam_state_dict_shard(placement_config, model_name: str):
     model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2

     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
-    model = GeminiDDP(model, config_dict, **placement_config)
+    model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights)
     model.train()

     zero_dict = model.state_dict(only_rank_0=False)
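The hunks above stop right after the state dicts are collected; the assertions that follow in each test are not shown here. A rough sketch of the kind of comparison such a checkpoint test typically makes, with an illustrative tolerance rather than the one the test actually uses (with `master_weights=False` the gathered values may be fp16, hence the cast):

import torch

# zero_dict / torch_dict are the dicts gathered in the hunks above.
for key, torch_value in torch_dict.items():
    assert key in zero_dict, f"{key} missing from the Gemini state dict"
    gemini_value = zero_dict[key].to(dtype=torch_value.dtype, device=torch_value.device)
    torch.testing.assert_close(torch_value, gemini_value, rtol=1e-3, atol=4e-3)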