[fx] add balanced policy v2 (#1251)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] add balanced policy v2

* add unittest
This commit is contained in:
YuliangLiu0306
2022-07-15 14:54:26 +08:00
committed by GitHub
parent ca2d3f284f
commit e8acf55e8b
3 changed files with 54 additions and 3 deletions

View File

@@ -4,7 +4,8 @@ import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \
uniform_split_pass
uniform_split_pass, balanced_split_pass_v2
import pytest
MODEL_DIM = 16
@@ -43,6 +44,7 @@ def test_pipeline_passes():
model = MLP(MODEL_DIM)
data = torch.rand(BATCH_SIZE, MODEL_DIM)
pipeline_pass_test_helper(model, data, balanced_split_pass)
pipeline_pass_test_helper(model, data, balanced_split_pass_v2)
pipeline_pass_test_helper(model, data, uniform_split_pass)