mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[fx] provide a stable but not accurate enough version of profiler. (#1547)
* [fx] compute memory stat and flop count for MetaInfoProp. * [fx] modify node attribute. * [fx] modify ckpt_chen. * [fx] fix compatibility. * [fx] fix import error. * [fx] skip test for MetaInfoProp. * [fx] skip test for MetaInfoProp. * [fx] skip test for MetaInfoProp. * [fx] skip test for MetaInfoProp. * [fx] skip if torch 1.11.0. * [fx] recover MetaInfoProp support for PyTorch 1.11. * [fx] provide a stable but not accurate enough version of profiler. * [fx] provide a stable but not accurate enough version of profiler. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix compatibility in tests. * [fx] fix import error.
This commit is contained in:
@@ -2,15 +2,12 @@ from typing import Any, Callable, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai import META_COMPATIBILITY
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
|
||||
INCOMPATIBLE = False # version > 1.12.0
|
||||
except:
|
||||
INCOMPATIBLE = True
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
@@ -56,7 +53,7 @@ registered_meta = {
|
||||
}
|
||||
|
||||
|
||||
def compare_all(tensor: torch.Tensor, meta_tensor: MetaTensor) -> Any:
|
||||
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
|
||||
assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
|
||||
assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
|
||||
assert tensor.stride() == meta_tensor.stride(
|
||||
@@ -77,7 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
|
||||
compare_all(x.grad, meta_x.grad)
|
||||
|
||||
|
||||
@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
|
||||
def test_meta_aten():
|
||||
for (aten_op, requires_backward), v in registered_meta.items():
|
||||
for f, x in v:
|
||||
|
Reference in New Issue
Block a user