This commit is contained in:
flybird11111 2025-05-26 18:10:57 +08:00 committed by GitHub
parent 63dc73d478
commit 559f15a4c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -54,6 +54,7 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
@ -114,7 +115,6 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@clear_cache_before_run()
@pytest.mark.parametrize("world_size", [1, 4])
def test_inference(world_size):
spawn(run_dist, world_size)