mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-29 00:38:02 +00:00
fix (#6328)
This commit is contained in:
parent
63dc73d478
commit
559f15a4c9
@ -54,6 +54,7 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict):
|
|||||||
|
|
||||||
|
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
|
@clear_cache_before_run()
|
||||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||||
@parameterize("model_name", ["transformers_gpt_lm"])
|
@parameterize("model_name", ["transformers_gpt_lm"])
|
||||||
@parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
|
@parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
|
||||||
@ -114,7 +115,6 @@ def run_dist(rank, world_size, port):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@clear_cache_before_run()
|
|
||||||
@pytest.mark.parametrize("world_size", [1, 4])
|
@pytest.mark.parametrize("world_size", [1, 4])
|
||||||
def test_inference(world_size):
|
def test_inference(world_size):
|
||||||
spawn(run_dist, world_size)
|
spawn(run_dist, world_size)
|
||||||
|
Loading…
Reference in New Issue
Block a user