[Analyzer] fix analyzer tests (#3197)

This commit is contained in:
YuliangLiu0306
2023-03-22 13:38:11 +08:00
committed by GitHub
parent f57d34958b
commit 019a847432
7 changed files with 60 additions and 100 deletions

View File

@@ -3,6 +3,8 @@ import torch
from packaging import version
from torch.utils.checkpoint import checkpoint
from colossalai.testing.utils import parameterize
try:
from colossalai._analyzer.fx import symbolic_trace
except:
@@ -56,9 +58,13 @@ class SiuModel(torch.nn.Module):
self.linear = LinearModel(3, 3, bias)
self.conv = ConvModel(3, 6, 3, bias)
def forward(self, x, select=0):
def forward(self, x, select=torch.Tensor([0])):
x = self.linear(x)
x = checkpoint(self.conv, x, select)
if select:
x = checkpoint(self.conv, x, 0)
else:
x = checkpoint(self.conv, x, 1)
return x
@@ -75,10 +81,10 @@ class AddmmModel(torch.nn.Module):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@pytest.mark.parametrize("select", [0, 1])
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
def test_siu_model(bias, bias_addition_split, shape, select):
model = SiuModel(bias=bias)
x = torch.rand(shape)
@@ -87,18 +93,18 @@ def test_siu_model(bias, bias_addition_split, shape, select):
concrete_args={'select': select},
trace_act_ckpt=True,
bias_addition_split=bias_addition_split)
assert torch.allclose(model(x, select), gm(x, select)), 'original model and traced model should be the same!'
assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!'
if bias and bias_addition_split:
assert '+' in gm.code, 'bias addition should be split!'
else:
assert '+' not in gm.code, 'bias addition should not be split!'
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("alpha", [1, 2])
@pytest.mark.parametrize("beta", [1, 2])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3), (5, 5)])
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@parameterize("alpha", [1, 2])
@parameterize("beta", [1, 2])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3), (5, 5)])
def test_addmm_model(alpha, beta, bias_addition_split, shape):
model = AddmmModel(alpha=alpha, beta=beta)
x = torch.rand(shape)
@@ -111,4 +117,5 @@ def test_addmm_model(alpha, beta, bias_addition_split, shape):
if __name__ == '__main__':
test_siu_model(True, True, (3, 3, 3))
test_siu_model()
test_addmm_model()