mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[test] merge old components to test to model zoo (#4945)
* [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test
This commit is contained in:
@@ -7,7 +7,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import set_seed
|
||||
from colossalai.zero import GeminiDDP
|
||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
PLACEMENT_CONFIGS = [
|
||||
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
|
||||
@@ -26,15 +26,16 @@ def ignore_the_first_parameter(model: torch.nn.Module):
|
||||
|
||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||
@parameterize("keep_gathered", [True, False])
|
||||
@parameterize("model_name", ["gpt2", "bert"])
|
||||
@parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"])
|
||||
@parameterize("master_weights", [False, True])
|
||||
def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
|
||||
model = model_builder()
|
||||
|
||||
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
|
||||
|
||||
torch_model = model_builder()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
@@ -54,29 +55,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei
|
||||
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
|
||||
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
|
||||
|
||||
|
||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||
@parameterize("keep_gathered", [True, False])
|
||||
@parameterize("model_name", ["gpt2", "bert"])
|
||||
@parameterize("master_weights", [False, True])
|
||||
def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
model = model_builder()
|
||||
|
||||
set_seed(451)
|
||||
torch_model = model_builder() # get a different model
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]["chunk_size"] = 5000
|
||||
config_dict[world_size]["keep_gathered"] = keep_gathered
|
||||
|
||||
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
|
||||
|
||||
torch_dict = torch_model.state_dict()
|
||||
# check load state dict
|
||||
model.load_state_dict(torch_dict, strict=False)
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
|
||||
@@ -85,23 +64,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str, maste
|
||||
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
|
||||
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
|
||||
|
||||
|
||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||
@parameterize("model_name", ["gpt2", "bert"])
|
||||
@parameterize("master_weights", [False, True])
|
||||
def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
model = model_builder()
|
||||
|
||||
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
|
||||
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
|
||||
model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights)
|
||||
model.train()
|
||||
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
# check state dict shard
|
||||
accumulated_keys = set()
|
||||
# ensure number of shards > 1
|
||||
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
|
||||
@@ -116,8 +79,6 @@ def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
exam_state_dict()
|
||||
exam_load_state_dict()
|
||||
exam_state_dict_shard()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
Reference in New Issue
Block a user