[gemini] support amp o3 for gemini (#4872)

* [gemini] support no reuse fp16 chunk

* [gemini] support no master weight for optim

* [gemini] support no master weight for gemini ddp

* [test] update gemini tests

* [test] update gemini tests

* [plugin] update gemini plugin

* [test] fix gemini checkpointio test

* [test] fix gemini checkpoint io
Hongxin Liu authored on 2023-10-12 10:39:08 +08:00, committed by GitHub
parent c1fab951e7
commit df63564184
15 changed files with 222 additions and 114 deletions
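The diffs below all revolve around the new `master_weights` switch on `GeminiDDP`: when it is `False`, no fp32 master copy of the parameters is kept and updates are applied to the fp16 chunks directly, which is what makes the apex-O3-style setup in the title possible. A minimal sketch of how the tests wire this up, assuming a process group is already initialized (the tests spawn workers via `run_dist`) and using the same helpers that appear in the hunks; import paths follow the test files of this era and may differ in other versions:

```python
import torch
import torch.distributed as dist

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration


def build_gemini(model: torch.nn.Module, master_weights: bool = False):
    """Wrap `model` in GeminiDDP, optionally without fp32 master weights."""
    world_size = dist.get_world_size()

    # Same chunk-search pattern as in the tests below.
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]["chunk_size"] = 5000

    # master_weights=False: no separate fp32 master copy of the parameters is
    # kept; the optimizer updates the fp16 chunks directly.
    model = GeminiDDP(model, config_dict, pin_memory=True, master_weights=master_weights)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
    return model, zero_optim
```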

View File

@@ -58,9 +58,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
-check_state_dict_equal(
-bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False
-)
+check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)
@clear_cache_before_run()

@@ -100,7 +98,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
dist.barrier()
booster.load_model(new_model, model_ckpt_path)
-check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
+check_state_dict_equal(
+model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
+)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(

@@ -136,7 +136,7 @@ def exam_lazy_from_pretrained():
booster.save_model(model, save_path, shard=False)
dist.barrier()
state_dict = torch.load(save_path, map_location="cpu")
-check_state_dict_equal(state_dict, orig_state_dict, False)
+check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)
def run_dist(rank, world_size, port):
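The `ignore_dtype=True` argument added above is needed because, without forcing `dtype=torch.float32` in `state_dict`, a Gemini checkpoint may now come back in half precision while the reference model stays fp32. A hedged sketch of what such a dtype-insensitive comparison amounts to (this is not the repo's `check_state_dict_equal` implementation, just the idea):

```python
import torch


def state_dicts_close_ignoring_dtype(sd_a: dict, sd_b: dict, rtol: float = 1e-3, atol: float = 1e-3) -> bool:
    """Compare two state dicts, tolerating fp16-vs-fp32 dtype mismatches."""
    if sd_a.keys() != sd_b.keys():
        return False
    for key in sd_a:
        # Promote both sides to fp32 so an fp16 checkpoint can be checked
        # against an fp32 reference within a loose tolerance.
        a = sd_a[key].to(torch.float32)
        b = sd_b[key].to(torch.float32)
        if not torch.allclose(a, b, rtol=rtol, atol=atol):
            return False
    return True
```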

View File

@@ -60,9 +60,10 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
# Add prefix to get aligned with pytorch parameter names.
check_state_dict_equal(
-model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+model.state_dict(only_rank_0=False, prefix="module.module."),
new_model.state_dict(),
False,
+ignore_dtype=True,
)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)

@@ -125,9 +126,10 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
# Add prefix to get aligned with pytorch parameter names.
check_state_dict_equal(
-new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+new_model.state_dict(only_rank_0=False, prefix="module.module."),
model.state_dict(),
False,
+ignore_dtype=True,
)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)

View File

@@ -27,6 +27,8 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
+if not model.reuse_fp16_chunk:
+chunk_list = [chunk.grad_chunk for chunk in chunk_list]
for chunk in chunk_list:
chunk_manager.access_chunk(chunk)
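The two lines added in this hunk matter because, when fp16 chunks are no longer reused as gradient storage, the gradients live in separate `grad_chunk` objects. A small sketch of that access pattern, using only the attributes visible in the hunk:

```python
def gather_grad_chunks(model):
    """Return the chunks that actually hold gradients for a GeminiDDP model."""
    chunk_manager = model.chunk_manager
    chunk_list = chunk_manager.get_chunks(list(model.parameters()))
    if not model.reuse_fp16_chunk:
        # Without chunk reuse, each parameter chunk has a dedicated grad chunk.
        chunk_list = [chunk.grad_chunk for chunk in chunk_list]
    for chunk in chunk_list:
        # Make the (possibly offloaded) chunk resident before reading it.
        chunk_manager.access_chunk(chunk)
    return chunk_list
```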
@@ -36,13 +38,15 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gather", [False, True])
-@parameterize("model_name", ["gpt2", "bert", "albert"])
+@parameterize("model_name", ["gpt2", "bert"])
@parameterize("use_grad_checkpoint", [False, True])
+@parameterize("master_weights", [False, True])
def exam_gpt_fwd_bwd(
placement_config,
keep_gather,
model_name: str,
use_grad_checkpoint: bool = False,
+master_weights: bool = True,
):
init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name)

@@ -60,12 +64,14 @@ def exam_gpt_fwd_bwd(
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = keep_gather
-model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config)
+model = GeminiDDP(
+model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
+)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
rank = dist.get_rank()
-amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1)
+amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, master_weights=master_weights)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[rank])

@@ -106,4 +112,4 @@ def test_gpt(world_size):
if __name__ == "__main__":
-test_gpt(4)
+test_gpt(1)
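The apex reference model in this test now mirrors Gemini's switch by forwarding `master_weights` into the O2 config; with `master_weights=False` this is effectively the O3-style "no fp32 master copy" setup named in the commit title. A sketch of that baseline using `apex.amp.initialize` directly rather than the `convert_to_apex_amp` test helper (assumes apex is installed and one CUDA device per rank):

```python
import torch
from apex import amp
from torch.nn.parallel import DistributedDataParallel as DDP


def build_apex_baseline(torch_model: torch.nn.Module, rank: int, master_weights: bool):
    """Apex AMP reference model matching Gemini's master_weights switch."""
    torch_model = torch_model.cuda(rank)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    # O2 keeps parameters in fp16; disabling master_weights also drops the
    # fp32 master copy, which is what plain O3 does.
    torch_model, torch_optim = amp.initialize(
        torch_model,
        torch_optim,
        opt_level="O2",
        keep_batchnorm_fp32=False,
        loss_scale=1,
        master_weights=master_weights,
    )
    return DDP(torch_model, device_ids=[rank]), torch_optim
```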

View File

@@ -78,7 +78,11 @@ def exam_grad_clipping(placement_config, model_name: str):
init_device = None
model = GeminiDDP(
-model, chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, **placement_config
+model,
+chunk_config_dict=config_dict,
+chunk_init_device=init_device,
+pin_memory=True,
+**placement_config,
)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

View File

@@ -44,7 +44,7 @@ BF16_IGNORED_KEYS = [
def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
-zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
+zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():

View File

@@ -27,7 +27,8 @@ def ignore_the_first_parameter(model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [True, False])
@parameterize("model_name", ["gpt2", "bert"])
-def exam_state_dict(placement_config, keep_gathered, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -42,7 +43,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = keep_gathered
-model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
model.train()
zero_dict = model.state_dict(only_rank_0=False)

@@ -57,7 +58,8 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [True, False])
@parameterize("model_name", ["gpt2", "bert"])
-def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -72,7 +74,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = keep_gathered
-model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
torch_dict = torch_model.state_dict()
model.load_state_dict(torch_dict, strict=False)

@@ -86,7 +88,8 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2", "bert"])
-def exam_state_dict_shard(placement_config, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -95,7 +98,7 @@ def exam_state_dict_shard(placement_config, model_name: str):
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
-model = GeminiDDP(model, config_dict, **placement_config)
+model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
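Taken together, these state-dict tests cover the round trip that matters for the no-master-weight mode: a plain fp32 torch state dict is loaded into the Gemini model and read back on every rank. A short sketch of that round trip using only the calls that appear in the hunks (the Gemini model is assumed to be built as in the earlier sketch):

```python
import torch


def roundtrip_state_dict(gemini_model, torch_model: torch.nn.Module) -> dict:
    """Load an fp32 state dict into GeminiDDP and gather it back on every rank."""
    # strict=False mirrors the tests above.
    gemini_model.load_state_dict(torch_model.state_dict(), strict=False)

    # only_rank_0=False gathers the full state dict on all ranks; with
    # master_weights=False the returned tensors may be fp16 rather than fp32.
    return gemini_model.state_dict(only_rank_0=False)
```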