Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 13:30:19 +00:00)
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379)
* [gemini] remove process group dependency
* [gemini] remove tp part from colo tensor
* [gemini] patch inplace op
* [gemini] fix param op hook and update tests
* [test] remove useless tests
* [test] remove useless tests
* [misc] fix requirements
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [misc] update requirements
* [gemini] refactor gemini optimizer and gemini ddp (#4398)
* [gemini] update optimizer interface
* [gemini] renaming gemini optimizer
* [gemini] refactor gemini ddp class
* [example] update gemini related example
* [example] update gemini related example
* [plugin] fix gemini plugin args
* [test] update gemini ckpt tests
* [gemini] fix checkpoint io
* [example] fix opt example requirements
* [example] fix opt example
* [example] fix opt example
* [example] fix opt example
* [gemini] add static placement policy (#4443)
* [gemini] add static placement policy
* [gemini] fix param offload
* [test] update gemini tests
* [plugin] update gemini plugin
* [plugin] update gemini plugin docstr
* [misc] fix flash attn requirement
* [test] fix gemini checkpoint io test
* [example] update resnet example result (#4457)
* [example] update bert example result (#4458)
* [doc] update gemini doc (#4468)
* [example] update gemini related examples (#4473)
* [example] update gpt example
* [example] update dreambooth example
* [example] update vit
* [example] update opt
* [example] update palm
* [example] update vit and opt benchmark
* [hotfix] fix bert in model zoo (#4480)
* [hotfix] fix bert in model zoo
* [test] remove chatglm gemini test
* [test] remove sam gemini test
* [test] remove vit gemini test
* [hotfix] fix opt tutorial example (#4497)
* [hotfix] fix opt tutorial example
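The headline feature is the static placement policy from #4443. Below is a minimal sketch of how the post-commit plugin is typically driven; the keyword names (`placement_policy`, `offload_optim_frac`) and the launch call are taken from the Gemini plugin documentation of this era and should be treated as assumptions, not as code from this patch:

```python
# A minimal sketch, not taken from the patch. Assumes the post-commit
# GeminiPlugin keywords and a launcher that sets the torch.distributed
# env vars, e.g. `colossalai run --nproc_per_node 1 this_script.py`.
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = nn.Linear(32, 32).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# "static" keeps a fixed fraction of optimizer state on CPU up front,
# instead of migrating tensors at runtime as the older policies did.
plugin = GeminiPlugin(placement_policy="static", offload_optim_frac=0.5)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# After the refactor the boosted objects expose state_dict directly,
# which is exactly what the test diff below starts relying on.
full_sd = model.state_dict()
```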
@@ -60,12 +60,11 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
     new_booster.load_model(new_model, model_ckpt_path, strict=True)

     # Add prefix to get aligned with pytorch parameter names.
-    check_state_dict_equal(
-        model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
-        new_model.state_dict(), False)
+    check_state_dict_equal(model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
+                           new_model.state_dict(), False)

     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
+    check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False)

     # Check the new model/optimizer can successfully run.
     data = data_gen_fn()
@@ -124,13 +123,12 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
     new_booster.load_model(new_model, model_ckpt_path, strict=True)

     # Add prefix to get aligned with pytorch parameter names.
-    check_state_dict_equal(
-        new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
-        model.state_dict(), False)
+    check_state_dict_equal(new_model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
+                           model.state_dict(), False)

     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     old_state_dict = optimizer.state_dict()
-    new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
+    new_state_dict = new_optimizer.state_dict(only_rank_0=False)

     # Comparison of param_groups needs special care here,
     # since not all hyperparameters in Adam are used by HybridAdam
@@ -138,7 +136,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
     for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
         for k in hyperparameters_to_examine:
             assert k in old_group and k in new_group, \
-                f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
+                f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
             assert old_group[k] == new_group[k]
     check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)

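The `prefix='module.module.'` argument seen in both hunks compensates for the key-name offset introduced by the layers wrapped around the user module. A self-contained PyTorch illustration, where `Wrapper` is a hypothetical stand-in for the real wrapping rather than ColossalAI code:

```python
import torch.nn as nn

class Wrapper(nn.Module):
    # Hypothetical stand-in for the two wrapper layers (booster wrapper
    # around the Gemini module) that shift the checkpoint key names.
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

inner = nn.Linear(4, 4)
wrapped = Wrapper(Wrapper(inner))

# Double wrapping prepends 'module.module.' to every key; passing the
# same prefix to the inner module's state_dict() re-aligns the names,
# which is what the tests above do before calling check_state_dict_equal.
assert set(wrapped.state_dict()) == set(inner.state_dict(prefix='module.module.'))
```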