[fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.

* [fx] remove import.

* [fx] remove import.

* [fx] remove import.

* [fx] tune the meta calculations.

* [fx] polish comments.

* [fx] remove assertions.

* [fx] modify test cases.

* [fx] modify test cases.

* [fx] optimize import.

* [fx
This commit is contained in:
Super Daniel
2022-09-23 10:59:47 +08:00
committed by GitHub
parent f7f2248771
commit d967779a32
16 changed files with 413 additions and 207 deletions

View File

@@ -30,17 +30,17 @@ tmm_models = [
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_torchvision_models():
for m in tm_models:
model = m().to('meta')
data = torch.rand(1000, 3, 224, 224, device='meta')
model(MetaTensor(data)).sum().backward()
model = m()
data = torch.rand(100000, 3, 224, 224, device='meta')
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_timm_models():
for m in tmm_models:
model = m().to('meta')
data = torch.rand(1000, 3, 224, 224, device='meta')
model(MetaTensor(data)).sum().backward()
model = m()
data = torch.rand(100000, 3, 224, 224, device='meta')
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
if __name__ == '__main__':