mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 15:11:20 +00:00
[fx/profiler] tuned the calculation of memory estimation (#1619)
* [fx] tuned the meta info and rotor solver. * [fx] remove import. * [fx] remove import. * [fx] remove import. * [fx] tune the meta calculations. * [fx] polish comments. * [fx] remove assertions. * [fx] modify test cases. * [fx] modify test cases. * [fx] optimize import. * [fx
This commit is contained in:
@@ -62,11 +62,8 @@ def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
|
||||
|
||||
def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
|
||||
x.requires_grad = requires_backward
|
||||
meta_x = MetaTensor(x.to('meta'))
|
||||
if isinstance(f, nn.Module):
|
||||
x_out, meta_out = f(x), f.to('meta')(meta_x)
|
||||
else:
|
||||
x_out, meta_out = f(x), f(meta_x)
|
||||
meta_x = MetaTensor(x)
|
||||
x_out, meta_out = f(x), f(meta_x)
|
||||
compare_all(x_out, meta_out)
|
||||
if requires_backward:
|
||||
x_out.sum().backward()
|
||||
|
Reference in New Issue
Block a user