[fx] provide a stable but not accurate enough version of profiler. (#1547)

* [fx] compute memory stat and flop count for MetaInfoProp.

* [fx] modify node attribute.

* [fx] modify ckpt_chen.

* [fx] fix compatibility.

* [fx] fix import error.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip if torch 1.11.0.

* [fx] recover MetaInfoProp support for PyTorch 1.11.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix import error.
This commit is contained in:
Super Daniel
2022-09-07 11:21:04 +08:00
committed by GitHub
parent 7d49e7b2db
commit 4f59693207
38 changed files with 776 additions and 263 deletions

View File

@@ -2,15 +2,12 @@ from typing import Any, Callable, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.fx.profiler import MetaTensor
from colossalai import META_COMPATIBILITY
import pytest
try:
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
INCOMPATIBLE = False # version > 1.12.0
except:
INCOMPATIBLE = True
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
aten = torch.ops.aten
@@ -56,7 +53,7 @@ registered_meta = {
}
def compare_all(tensor: torch.Tensor, meta_tensor: MetaTensor) -> Any:
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
assert tensor.stride() == meta_tensor.stride(
@@ -77,7 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
compare_all(x.grad, meta_x.grad)
@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_meta_aten():
for (aten_op, requires_backward), v in registered_meta.items():
for f, x in v: