[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions

View File

@@ -1,12 +1,17 @@
import pytest
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \
uniform_split_pass, balanced_split_pass_v2
import pytest
from colossalai.fx.passes.adding_split_node_pass import (
balanced_split_pass,
balanced_split_pass_v2,
split_with_split_nodes_pass,
uniform_split_pass,
)
from colossalai.testing import clear_cache_before_run
MODEL_DIM = 16
BATCH_SIZE = 8
@@ -39,6 +44,7 @@ def pipeline_pass_test_helper(model, data, pass_func):
assert output.equal(origin_output)
@clear_cache_before_run()
def test_pipeline_passes():
model = MLP(MODEL_DIM)
data = torch.rand(BATCH_SIZE, MODEL_DIM)