This commit is contained in:
flybird11111 2025-05-26 16:05:28 +08:00 committed by GitHub
parent 552778fb20
commit 63dc73d478
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -10,7 +10,7 @@ import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
@ -53,6 +53,7 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict):
return model
@rerun_if_address_is_in_use()
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
@ -104,6 +105,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal
train_iter()
inference_iter()
train_iter()
torch.cuda.empty_cache()
def run_dist(rank, world_size, port):
@ -112,8 +114,8 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@clear_cache_before_run()
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_if_address_is_in_use()
def test_inference(world_size):
spawn(run_dist, world_size)