fix (#6327)
parent 552778fb20
commit 63dc73d478
@@ -10,7 +10,7 @@ import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 from colossalai.utils import set_seed
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
@@ -53,6 +53,7 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict):
     return model
 
 
+@rerun_if_address_is_in_use()
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
@@ -104,6 +105,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
     train_iter()
     inference_iter()
     train_iter()
+    torch.cuda.empty_cache()
 
 
 def run_dist(rank, world_size, port):
@@ -112,8 +114,8 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
+@clear_cache_before_run()
 @pytest.mark.parametrize("world_size", [1, 4])
 @rerun_if_address_is_in_use()
 def test_inference(world_size):
     spawn(run_dist, world_size)
-
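
Why the added torch.cuda.empty_cache() helps: PyTorch's caching allocator keeps freed GPU blocks reserved for reuse, so memory released by one test run is not automatically returned to the driver and can crowd out the next run on the same device. The snippet below is illustrative only (it is not part of this commit) and shows how the reserved footprint drops once the cache is emptied.

# Minimal sketch (not part of the PR): how torch.cuda.empty_cache() releases
# cached allocator blocks after a workload, so later runs start with a
# smaller reserved-memory footprint.
import torch

def report(tag: str) -> None:
    # memory_allocated: bytes held by live tensors; memory_reserved: bytes the
    # caching allocator keeps from the driver, including already-freed blocks.
    print(f"{tag}: allocated={torch.cuda.memory_allocated() >> 20} MiB, "
          f"reserved={torch.cuda.memory_reserved() >> 20} MiB")

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, 256, device="cuda")  # roughly 1 GiB workload
    report("after allocation")
    del x                      # tensor freed, but its blocks stay cached
    report("after del")
    torch.cuda.empty_cache()   # return cached blocks to the driver
    report("after empty_cache")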
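
clear_cache_before_run is imported here from colossalai.testing and applied to test_inference so each invocation starts from a clean GPU-memory state rather than inheriting the cache left by earlier tests. The decorator below is a rough, hypothetical stand-in written for illustration; it is not ColossalAI's implementation, only the common pattern such a helper follows (collect garbage, empty the CUDA cache, then run the wrapped function).

# Illustrative sketch only: a decorator in the spirit of
# colossalai.testing.clear_cache_before_run. This is NOT the library's code.
import functools
import gc

import torch

def clear_cache_before_run_sketch():
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            gc.collect()                  # drop unreachable Python objects first
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached CUDA blocks to the driver
            return func(*args, **kwargs)
        return wrapper
    return decorator

# Hypothetical usage, mirroring how the real decorator wraps a test function.
@clear_cache_before_run_sketch()
def test_something():
    ...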