mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[fx/profiler] tuned the calculation of memory estimation (#1619)
* [fx] tuned the meta info and rotor solver. * [fx] remove import. * [fx] remove import. * [fx] remove import. * [fx] tune the meta calculations. * [fx] polish comments. * [fx] remove assertions. * [fx] modify test cases. * [fx] modify test cases. * [fx] optimize import. * [fx
This commit is contained in:
@@ -30,17 +30,17 @@ tmm_models = [
|
||||
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
|
||||
def test_torchvision_models():
|
||||
for m in tm_models:
|
||||
model = m().to('meta')
|
||||
data = torch.rand(1000, 3, 224, 224, device='meta')
|
||||
model(MetaTensor(data)).sum().backward()
|
||||
model = m()
|
||||
data = torch.rand(100000, 3, 224, 224, device='meta')
|
||||
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
|
||||
def test_timm_models():
|
||||
for m in tmm_models:
|
||||
model = m().to('meta')
|
||||
data = torch.rand(1000, 3, 224, 224, device='meta')
|
||||
model(MetaTensor(data)).sum().backward()
|
||||
model = m()
|
||||
data = torch.rand(100000, 3, 224, 224, device='meta')
|
||||
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user