[fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.

* [fx] remove import.

* [fx] remove import.

* [fx] remove import.

* [fx] tune the meta calculations.

* [fx] polish comments.

* [fx] remove assertions.

* [fx] modify test cases.

* [fx] modify test cases.

* [fx] optimize import.

* [fx
This commit is contained in:
Super Daniel
2022-09-23 10:59:47 +08:00
committed by GitHub
parent f7f2248771
commit d967779a32
16 changed files with 413 additions and 207 deletions

View File

@@ -62,11 +62,8 @@ def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
x.requires_grad = requires_backward
meta_x = MetaTensor(x.to('meta'))
if isinstance(f, nn.Module):
x_out, meta_out = f(x), f.to('meta')(meta_x)
else:
x_out, meta_out = f(x), f(meta_x)
meta_x = MetaTensor(x)
x_out, meta_out = f(x), f(meta_x)
compare_all(x_out, meta_out)
if requires_backward:
x_out.sum().backward()