mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator * polish code * polish code * polish code * polish code * polish code * polish code
This commit is contained in:
@@ -2,6 +2,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
BATCH = 2
|
||||
@@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def test_torchrec_deepfm_models():
|
||||
deepfm_models = model_zoo.get_sub_registry('deepfm')
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
@@ -2,6 +2,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
BATCH = 2
|
||||
@@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
|
||||
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def test_torchrec_dlrm_models():
|
||||
torch.backends.cudnn.deterministic = True
|
||||
dlrm_models = model_zoo.get_sub_registry('dlrm')
|
||||
|
Reference in New Issue
Block a user