mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[Analyzer] fix analyzer tests (#3197)
This commit is contained in:
@@ -3,6 +3,8 @@ import torch
|
||||
from packaging import version
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from colossalai.testing.utils import parameterize
|
||||
|
||||
try:
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
except:
|
||||
@@ -56,9 +58,13 @@ class SiuModel(torch.nn.Module):
|
||||
self.linear = LinearModel(3, 3, bias)
|
||||
self.conv = ConvModel(3, 6, 3, bias)
|
||||
|
||||
def forward(self, x, select=0):
|
||||
def forward(self, x, select=torch.Tensor([0])):
|
||||
x = self.linear(x)
|
||||
x = checkpoint(self.conv, x, select)
|
||||
if select:
|
||||
x = checkpoint(self.conv, x, 0)
|
||||
else:
|
||||
x = checkpoint(self.conv, x, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@@ -75,10 +81,10 @@ class AddmmModel(torch.nn.Module):
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("bias_addition_split", [True, False])
|
||||
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
|
||||
@pytest.mark.parametrize("select", [0, 1])
|
||||
@parameterize("bias", [True, False])
|
||||
@parameterize("bias_addition_split", [True, False])
|
||||
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
|
||||
@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
|
||||
def test_siu_model(bias, bias_addition_split, shape, select):
|
||||
model = SiuModel(bias=bias)
|
||||
x = torch.rand(shape)
|
||||
@@ -87,18 +93,18 @@ def test_siu_model(bias, bias_addition_split, shape, select):
|
||||
concrete_args={'select': select},
|
||||
trace_act_ckpt=True,
|
||||
bias_addition_split=bias_addition_split)
|
||||
assert torch.allclose(model(x, select), gm(x, select)), 'original model and traced model should be the same!'
|
||||
assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!'
|
||||
if bias and bias_addition_split:
|
||||
assert '+' in gm.code, 'bias addition should be split!'
|
||||
else:
|
||||
assert '+' not in gm.code, 'bias addition should not be split!'
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize("alpha", [1, 2])
|
||||
@pytest.mark.parametrize("beta", [1, 2])
|
||||
@pytest.mark.parametrize("bias_addition_split", [True, False])
|
||||
@pytest.mark.parametrize("shape", [(3, 3), (5, 5)])
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@parameterize("alpha", [1, 2])
|
||||
@parameterize("beta", [1, 2])
|
||||
@parameterize("bias_addition_split", [True, False])
|
||||
@parameterize("shape", [(3, 3), (5, 5)])
|
||||
def test_addmm_model(alpha, beta, bias_addition_split, shape):
|
||||
model = AddmmModel(alpha=alpha, beta=beta)
|
||||
x = torch.rand(shape)
|
||||
@@ -111,4 +117,5 @@ def test_addmm_model(alpha, beta, bias_addition_split, shape):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_siu_model(True, True, (3, 3, 3))
|
||||
test_siu_model()
|
||||
test_addmm_model()
|
||||
|
Reference in New Issue
Block a user