Mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] support amp o3 for gemini (#4872)
* [gemini] support no reuse fp16 chunk
* [gemini] support no master weight for optim
* [gemini] support no master weight for gemini ddp
* [test] update gemini tests
* [test] update gemini tests
* [plugin] update gemini plugin
* [test] fix gemini checkpointio test
* [test] fix gemini checkpoint io
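The hunks below add a master_weights switch so Gemini can run without keeping an fp32 master copy of the parameters (apex-AMP O3-style training). A minimal sketch of turning it on, assuming a distributed environment is already initialized; the linear module is a placeholder and the import paths follow the test files, so they may differ across versions:

    import torch
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero import GeminiDDP, GeminiOptimizer
    from colossalai.zero.gemini.chunk import search_chunk_configuration

    model = torch.nn.Linear(1024, 1024).cuda()  # placeholder module

    # Chunk planning, mirroring the tests below.
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)

    # master_weights=False keeps only the fp16 chunks; the optimizer updates them
    # directly instead of updating an fp32 master copy.
    model = GeminiDDP(model, config_dict, pin_memory=True, master_weights=False)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)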
@@ -58,9 +58,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
     dist.barrier()

     new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
-    check_state_dict_equal(
-        bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False
-    )
+    check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)


 @clear_cache_before_run()
@@ -100,7 +98,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     dist.barrier()

     booster.load_model(new_model, model_ckpt_path)
-    check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
+    check_state_dict_equal(
+        model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
+    )

     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     check_state_dict_equal(
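The new ignore_dtype=True argument is needed because, without an fp32 master copy, the gathered Gemini state dict stays in fp16 while the reloaded reference model is fp32. check_state_dict_equal comes from the shared test utilities; the helper below is only a hedged illustration of the comparison it implies:

    import torch

    def state_dicts_match(d1, d2, ignore_dtype=False, rtol=1e-3, atol=1e-3):
        # Illustrative only: compare two state dicts, optionally tolerating dtype mismatches.
        if d1.keys() != d2.keys():
            return False
        for key in d1:
            t1, t2 = d1[key], d2[key]
            if not ignore_dtype and t1.dtype != t2.dtype:
                return False
            # cast to a common dtype/device before comparing values
            if not torch.allclose(t1.float().cpu(), t2.float().cpu(), rtol=rtol, atol=atol):
                return False
        return True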
@@ -136,7 +136,7 @@ def exam_lazy_from_pretrained():
     booster.save_model(model, save_path, shard=False)
     dist.barrier()
     state_dict = torch.load(save_path, map_location="cpu")
-    check_state_dict_equal(state_dict, orig_state_dict, False)
+    check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)


 def run_dist(rank, world_size, port):
@@ -60,9 +60,10 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):

     # Add prefix to get aligned with pytorch parameter names.
     check_state_dict_equal(
-        model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+        model.state_dict(only_rank_0=False, prefix="module.module."),
         new_model.state_dict(),
         False,
+        ignore_dtype=True,
     )

     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
@@ -125,9 +126,10 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):

     # Add prefix to get aligned with pytorch parameter names.
     check_state_dict_equal(
-        new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+        new_model.state_dict(only_rank_0=False, prefix="module.module."),
         model.state_dict(),
         False,
+        ignore_dtype=True,
     )

     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
@@ -27,6 +27,8 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
     chunk_manager = model.chunk_manager
     param_list = [p for p in model.parameters()]
     chunk_list = chunk_manager.get_chunks(param_list)
+    if not model.reuse_fp16_chunk:
+        chunk_list = [chunk.grad_chunk for chunk in chunk_list]
     for chunk in chunk_list:
         chunk_manager.access_chunk(chunk)

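With reuse_fp16_chunk disabled (the mode this commit adds), gradients are accumulated in dedicated grad chunks instead of overwriting the fp16 parameter chunks, so the check has to gather those chunks before reading them. A short sketch of that gathering step, using only the attributes visible in the hunk above:

    def gather_grad_chunks(model):
        # Make the chunks that hold gradients accessible on this rank.
        chunk_manager = model.chunk_manager
        chunk_list = chunk_manager.get_chunks([p for p in model.parameters()])
        if not model.reuse_fp16_chunk:
            # gradients live in separate grad chunks when fp16 chunks are not reused
            chunk_list = [chunk.grad_chunk for chunk in chunk_list]
        for chunk in chunk_list:
            chunk_manager.access_chunk(chunk)
        return chunk_list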
@@ -36,13 +38,15 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):

 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gather", [False, True])
-@parameterize("model_name", ["gpt2", "bert", "albert"])
+@parameterize("model_name", ["gpt2", "bert"])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("master_weights", [False, True])
 def exam_gpt_fwd_bwd(
     placement_config,
     keep_gather,
     model_name: str,
     use_grad_checkpoint: bool = False,
+    master_weights: bool = True,
 ):
     init_device = get_current_device()
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -60,12 +64,14 @@ def exam_gpt_fwd_bwd(
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gather
-    model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config)
+    model = GeminiDDP(
+        model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
+    )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)

     rank = dist.get_rank()
-    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1)
+    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, master_weights=master_weights)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[rank])
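The apex baseline keeps opt_level="O2" but now forwards the same master_weights flag, so when Gemini drops its fp32 master copy the reference does too (O2 without master weights is essentially what apex calls O3). A hedged sketch of one Gemini-side training step with the optimizer built above; data, label, and criterion are placeholders for the test components:

    def gemini_step(model, zero_optim, data, label, criterion):
        # One forward/backward/update step through GeminiDDP + GeminiOptimizer.
        zero_optim.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        zero_optim.backward(loss)  # the wrapper applies loss scaling and runs backward over the chunks
        zero_optim.step()
        return loss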
@@ -106,4 +112,4 @@ def test_gpt(world_size):


 if __name__ == "__main__":
-    test_gpt(4)
+    test_gpt(1)
@@ -78,7 +78,11 @@ def exam_grad_clipping(placement_config, model_name: str):
         init_device = None

     model = GeminiDDP(
-        model, chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, **placement_config
+        model,
+        chunk_config_dict=config_dict,
+        chunk_init_device=init_device,
+        pin_memory=True,
+        **placement_config,
     )

     optimizer = HybridAdam(model.parameters(), lr=1e-3)
@@ -44,7 +44,7 @@ BF16_IGNORED_KEYS = [


 def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
-    zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
+    zero_dict = model.state_dict(only_rank_0=False)
     torch_dict = torch_model.state_dict()

     for key, value in torch_dict.items():
@@ -27,7 +27,8 @@ def ignore_the_first_parameter(model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gathered", [True, False])
 @parameterize("model_name", ["gpt2", "bert"])
-def exam_state_dict(placement_config, keep_gathered, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -42,7 +43,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered
-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
     model.train()

     zero_dict = model.state_dict(only_rank_0=False)
@@ -57,7 +58,8 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gathered", [True, False])
 @parameterize("model_name", ["gpt2", "bert"])
-def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -72,7 +74,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered

-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)

     torch_dict = torch_model.state_dict()
     model.load_state_dict(torch_dict, strict=False)
@@ -86,7 +88,8 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):

 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["gpt2", "bert"])
-def exam_state_dict_shard(placement_config, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -95,7 +98,7 @@ def exam_state_dict_shard(placement_config, model_name: str):
     model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2

     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
-    model = GeminiDDP(model, config_dict, **placement_config)
+    model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights)
     model.train()

     zero_dict = model.state_dict(only_rank_0=False)