mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
Merge branch 'main' into feature/shardformer
This commit is contained in:
@@ -18,12 +18,45 @@ from colossalai.testing import (
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
MODEL_PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 1.0
|
||||
}, # zero3
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.5
|
||||
}, # zero3-half
|
||||
]
|
||||
|
||||
OPTIM_PLACEMENT_CONFIGS = [
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.0
|
||||
}, # zero2
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 1.0
|
||||
}, # zero2-offload
|
||||
{
|
||||
'placement_policy': 'static',
|
||||
'shard_param_frac': 0.0,
|
||||
'offload_optim_frac': 0.5
|
||||
}, # zero2-offload-half
|
||||
]
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('placement_config', MODEL_PLACEMENT_CONFIGS)
|
||||
@parameterize('model_name', ['transformers_bert_for_sequence_classification'])
|
||||
@parameterize('use_safetensors', [False, True])
|
||||
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
|
||||
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
|
||||
from transformers import BertForSequenceClassification
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
bert_model = model_fn()
|
||||
@@ -32,7 +65,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
|
||||
pretrained_path = os.path.join(tempdir, 'pretrained')
|
||||
bert_model.config.save_pretrained(save_directory=pretrained_path)
|
||||
|
||||
plugin = GeminiPlugin(placement_policy=placement_policy)
|
||||
plugin = GeminiPlugin(**placement_config)
|
||||
booster = Booster(plugin=plugin)
|
||||
bert_model, _, _, _, _ = booster.boost(bert_model)
|
||||
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
|
||||
@@ -46,19 +79,19 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
|
||||
dist.barrier()
|
||||
|
||||
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
|
||||
check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32),
|
||||
check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
|
||||
new_bert_model.state_dict(), False)
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('placement_config', OPTIM_PLACEMENT_CONFIGS)
|
||||
@parameterize('shard', [False, True])
|
||||
@parameterize('model_name', ['transformers_gpt'])
|
||||
@parameterize('size_per_shard', [32])
|
||||
def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
|
||||
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
criterion = lambda x: x.mean()
|
||||
plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
|
||||
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14))
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
model = model_fn()
|
||||
@@ -87,12 +120,11 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_sha
|
||||
dist.barrier()
|
||||
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
|
||||
new_model.unwrap().state_dict(only_rank_0=False), False)
|
||||
check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
|
||||
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
|
||||
new_optimizer.unwrap().state_dict(only_rank_0=False), False)
|
||||
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False),
|
||||
False)
|
||||
|
||||
# Check the new model/optimizer can successfully run.
|
||||
data = data_gen_fn()
|
||||
|
@@ -60,12 +60,11 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
|
||||
new_booster.load_model(new_model, model_ckpt_path, strict=True)
|
||||
|
||||
# Add prefix to get aligned with pytorch parameter names.
|
||||
check_state_dict_equal(
|
||||
model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
|
||||
new_model.state_dict(), False)
|
||||
check_state_dict_equal(model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
|
||||
new_model.state_dict(), False)
|
||||
|
||||
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
|
||||
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
|
||||
|
||||
# Check the new model/optimizer can successfully run.
|
||||
data = data_gen_fn()
|
||||
@@ -124,13 +123,12 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
|
||||
new_booster.load_model(new_model, model_ckpt_path, strict=True)
|
||||
|
||||
# Add prefix to get aligned with pytorch parameter names.
|
||||
check_state_dict_equal(
|
||||
new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
|
||||
model.state_dict(), False)
|
||||
check_state_dict_equal(new_model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
|
||||
model.state_dict(), False)
|
||||
|
||||
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
old_state_dict = optimizer.state_dict()
|
||||
new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
|
||||
new_state_dict = new_optimizer.state_dict(only_rank_0=False)
|
||||
|
||||
# Comparison of param_groups needs special care here,
|
||||
# since not all hyperparameters in Adam are used by HybridAdam
|
||||
@@ -138,7 +136,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
|
||||
for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
|
||||
for k in hyperparameters_to_examine:
|
||||
assert k in old_group and k in new_group, \
|
||||
f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
|
||||
f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
|
||||
assert old_group[k] == new_group[k]
|
||||
check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)
|
||||
|
||||
|
@@ -16,19 +16,21 @@ from colossalai.testing import (
|
||||
)
|
||||
|
||||
|
||||
# stage 1 and 2 process the optimizer/mode the same way
|
||||
# only test 2 is fine
|
||||
@clear_cache_before_run()
|
||||
@parameterize('stage', [2])
|
||||
@parameterize('shard', [True, False])
|
||||
def check_low_level_zero_checkpointIO(stage: int, shard: bool):
|
||||
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
|
||||
@parameterize('offload', [False, True])
|
||||
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
|
||||
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
|
||||
booster = Booster(plugin=plugin)
|
||||
model = resnet18()
|
||||
criterion = lambda x: x.mean()
|
||||
optimizer = HybridAdam((model.parameters()), lr=0.001)
|
||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||
|
||||
x = torch.randn(4, 3, 224, 224)
|
||||
x = x.to('cuda')
|
||||
x = torch.randn(1, 3, 224, 224, device='cuda')
|
||||
output = model(x)
|
||||
loss = criterion(output)
|
||||
booster.backward(loss, optimizer)
|
||||
@@ -50,15 +52,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
|
||||
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
|
||||
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_low_level_zero_checkpointIO()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_low_level_zero_checkpointIO():
|
||||
spawn(run_dist, 2)
|
||||
|
||||
|
Reference in New Issue
Block a user