[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)

* [fx] move meta registration

* [fx] fix tests.

* [fx] fix test.

* [fx] fix.

* [meta] refactor _meta_registration.py.

* [fx] add compatibility descriptions.

* [fx] polish import.

* [fx] add a decorator.

* [fx] fix tests.

* [fx] remove print.

* [fx] edit raise error.

* [fx] edit raise error.

* [fx] add type hint.

* [fx] fix import in experimental.

* [rpc] remove color debug.

* [meta] fix naming.
Authored by Super Daniel on 2022-10-18 10:44:23 +08:00, committed by GitHub
parent e8d8eda5e7
commit 393f594051
32 changed files with 351 additions and 310 deletions
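Throughout the diffs below, the module-level flag META_COMPATIBILITY is replaced by a call to colossalai.fx._compatibility.is_compatible_with_meta(). The helper's implementation is not shown in this excerpt; a minimal sketch, assuming it only gates on the torch >= 1.12.0 threshold quoted in the skipif reasons below, might look like:

# Hypothetical sketch of colossalai.fx._compatibility.is_compatible_with_meta;
# not the actual source. Assumes the only requirement is torch >= 1.12.0.
from packaging import version

import torch


def is_compatible_with_meta() -> bool:
    # Strip any local build suffix such as "+cu113" before comparing.
    return version.parse(torch.__version__.split('+')[0]) >= version.parse('1.12.0')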

View File

@@ -1,18 +1,20 @@
 import copy
+import colossalai
+import pytest
 import torch
+import torch.fx
 import torch.multiprocessing as mp
 import torchvision.models as tm
-import torch.fx
-import colossalai
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.core import global_context as gpc
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.passes.algorithms import solver_rotor
 from colossalai.fx.passes.algorithms.operation import Sequence
-from colossalai.core import global_context as gpc
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.utils import free_port
-import pytest
-from colossalai import META_COMPATIBILITY
-if META_COMPATIBILITY:
+if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
 try:
@@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0):
     graph = tracer.trace(model, meta_args={"x": data})
     graph.set_codegen(ActivationCheckpointCodeGen())
     gm = ColoGraphModule(model, graph, model.__class__.__name__)
-    if META_COMPATIBILITY:
+    if is_compatible_with_meta():
         data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
         MetaInfoProp(gm).run(data_meta)
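The commit messages above also mention adding a decorator with compatibility descriptions and editing the raised errors ("[fx] add a decorator", "[fx] edit raise error"). The decorator itself is not visible in this excerpt; one plausible shape, assuming it simply guards calls behind the same check, is:

# A sketch only: the actual decorator in colossalai.fx._compatibility is not
# shown in this excerpt, and its real name and signature may differ.
import functools
from typing import Callable

from colossalai.fx._compatibility import is_compatible_with_meta


def requires_meta(func: Callable) -> Callable:    # hypothetical name
    """Raise a descriptive error when a meta-dependent function is called
    on a torch build without meta-tensor support."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_compatible_with_meta():
            raise RuntimeError(f'{func.__qualname__} requires meta tensor support '
                               '(torch >= 1.12.0).')
        return func(*args, **kwargs)

    return wrapper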

View File

@@ -1,20 +1,22 @@
-from typing import Callable
 import copy
 import re
+from typing import Callable
+import colossalai
+import pytest
 import torch
 import torch.multiprocessing as mp
 import torchvision.models as tm
-from torch.fx import GraphModule
-import colossalai
-from colossalai.fx import ColoTracer
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-import pytest
-from colossalai import META_COMPATIBILITY
-if META_COMPATIBILITY:
+from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+from torch.fx import GraphModule
+if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
 try:
@@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule):
 def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
                                model_cls: Callable[[], torch.nn.Module]):
     criterion = torch.nn.MSELoss()
-    data = torch.rand(2, 3, 32, 32)
-    label = torch.rand(2, 5)
+    m.cuda()
+    data = torch.rand(2, 3, 32, 32).cuda()
+    label = torch.rand(2, 5).cuda()
     loss = criterion(m(data), label)
     loss.backward()
     loss = criterion(gm(data), label)
@@ -77,7 +80,7 @@ def _run_ckpt_solver(rank):
     m = model_cls(num_classes=5)
     graph = tracer.trace(root=m)
     gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
-    MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
+    MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())
     codegen = ActivationCheckpointCodeGen()
     gm.graph.set_codegen(codegen)
     if solver == solver_rotor:
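The hunk above changes how a MetaTensor's fake device is set: instead of passing fake_device='cuda' at construction time, the tensor-like .cuda() method is called afterwards. Both spellings appear in this commit; an illustrative comparison (assuming a CUDA device is available):

import torch

from colossalai.fx.profiler.tensor import MetaTensor

data = torch.rand(2, 3, 32, 32)
meta_old = MetaTensor(data, fake_device='cuda')    # old style: explicit kwarg
meta_new = MetaTensor(data).cuda()                 # new style: tensor-like .cuda()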

View File

@@ -1,13 +1,14 @@
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+import pytest
 import torch
 import torchvision.models as tm
 from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.algorithms import solver_rotor, linearize
-from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
-import pytest
-from colossalai import META_COMPATIBILITY
-if META_COMPATIBILITY:
+from colossalai.fx.passes.algorithms import linearize, solver_rotor
+from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
 try:

View File

@@ -1,13 +1,17 @@
-import torch
-import torch.nn as nn
 import colossalai
 import colossalai.nn as col_nn
-from torch.fx import symbolic_trace
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
-from colossalai.fx.passes.utils import get_comm_size
 from colossalai import META_COMPATIBILITY
 import pytest
+import torch
+import torch.nn as nn
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass)
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.passes.utils import get_comm_size
+from torch.fx import symbolic_trace
+is_compatible = is_compatible_with_meta()
+if is_compatible:
+    from colossalai.fx.profiler import MetaTensor
MODEL_DIM = 16
BATCH_SIZE = 8
@@ -31,12 +35,12 @@ class MLP(torch.nn.Module):
         return x
 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_comm_size_compute():
-    from colossalai.fx.profiler import MetaTensor
     model = MLP(MODEL_DIM)
-    input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
+    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
     gm = symbolic_trace(model)
+    if is_compatible:
+        input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(input_sample)
     annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
     split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
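The guarded wrapping above, which builds the sample on the 'meta' device first and upgrades it to a MetaTensor only when the torch build supports it, is the pattern this commit applies across the test suite. In isolation, as a sketch with a stand-in linear module:

import torch
from torch.fx import symbolic_trace

from colossalai.fx._compatibility import is_compatible_with_meta

model = torch.nn.Linear(16, 16)    # stand-in module for illustration
gm = symbolic_trace(model)
input_sample = torch.rand(8, 16, device='meta')
if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor
    # The fake device follows the traced module's parameters.
    input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)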

View File

@@ -1,12 +1,11 @@
 from typing import Any, Callable, Union
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from colossalai import META_COMPATIBILITY
 import pytest
+import torch
+import torch.nn as nn
+from colossalai.fx._compatibility import is_compatible_with_meta
-if META_COMPATIBILITY:
+if is_compatible_with_meta():
     from colossalai.fx.profiler import MetaTensor
 aten = torch.ops.aten
@@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
         compare_all(x.grad, meta_x.grad)
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_meta_aten():
     for (aten_op, requires_backward), v in registered_meta.items():
         for f, x in v:
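The loop above implies the shape of registered_meta in this test: keys pair an aten op with a requires-backward flag, and values are lists of (callable, sample input) cases. A hypothetical entry, for illustration only:

import torch

aten = torch.ops.aten

# Assumed structure, inferred from the loop above; not the test's real table.
registered_meta = {
    (aten.relu.default, True): [    # (aten op, requires_backward)
        (lambda x: torch.relu(x), torch.rand(2, 3, requires_grad=True)),
    ],
}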

View File

@@ -1,10 +1,10 @@
-import torchvision.models as tm
+import pytest
 import timm.models as tmm
 import torch
-from colossalai import META_COMPATIBILITY
-import pytest
+import torchvision.models as tm
+from colossalai.fx._compatibility import is_compatible_with_meta
-if META_COMPATIBILITY:
+if is_compatible_with_meta():
     from colossalai.fx.profiler import MetaTensor
tm_models = [
@@ -27,7 +27,7 @@ tmm_models = [
 ]
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_torchvision_models():
     for m in tm_models:
         model = m()
@@ -35,7 +35,7 @@ def test_torchvision_models():
         model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_timm_models():
     for m in tmm_models:
         model = m()

View File

@@ -1,10 +1,10 @@
-import torchvision.models as tm
+import pytest
 import timm.models as tmm
 import torch
-from colossalai import META_COMPATIBILITY
-import pytest
+import torchvision.models as tm
+from colossalai.fx._compatibility import is_compatible_with_meta
-if META_COMPATIBILITY:
+if is_compatible_with_meta():
     from colossalai.fx import meta_trace
tm_models = [
@@ -27,7 +27,7 @@ tmm_models = [
 ]
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_torchvision_models_trace():
     for m in tm_models:
         model = m()
@@ -35,7 +35,7 @@ def test_torchvision_models_trace():
         graph = meta_trace(model, torch.device('cpu'), data)
-@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
 def test_timm_models_trace():
     for m in tmm_models:
         model = m()
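As the tests above suggest, meta_trace takes a model, the fake device, and example inputs, and traces the graph without allocating real activation memory. A sketch of that usage; the shape of data is an assumption, since the tests build it off-screen:

import torch
import torchvision.models as tm

from colossalai.fx import meta_trace

model = tm.resnet18()
data = torch.rand(2, 3, 224, 224, device='meta')    # assumed sample input
graph = meta_trace(model, torch.device('cpu'), data)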

View File

@@ -1,7 +1,10 @@
 import torch
-from torch.fx import symbolic_trace
-from colossalai import META_COMPATIBILITY
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
+from torch.fx import symbolic_trace
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor
 BATCH_SIZE = 2
 DIM_IN = 4
@@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
 def test_meta_info_prop():
     model = torch.nn.Linear(DIM_IN, DIM_OUT)
     input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
-    if META_COMPATIBILITY:
-        from colossalai.fx.profiler import MetaTensor
+    if is_compatible_with_meta():
         input_sample = MetaTensor(input_sample, fake_device='cpu')
     orig_output = model(input_sample)
     gm = symbolic_trace(model)
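meta_check, named in the hunk header above, compares MetaInfoProp's recorded TensorMetadata against the real forward output. Assuming the metadata records at least shape and dtype (the exact fields are not shown in this diff), the comparison amounts to:

# Sketch of the comparison; TensorMetadata's full field list is not shown here.
import torch

from colossalai.fx.passes.meta_info_prop import TensorMetadata


def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
    assert meta_info_spec.shape == orig_tensor.shape
    assert meta_info_spec.dtype == orig_tensor.dtype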

View File

@@ -1,19 +1,17 @@
-import os
 import argparse
+import os
 import warnings
 import torch
-from torch import nn
-import torch.multiprocessing as mp
-import torch.distributed.rpc as rpc
-from torch.optim import SGD, Adam, RMSprop, Optimizer
-from torch._C._distributed_rpc import _is_current_rpc_agent_set
 import torch.distributed as dist
-from colorama import Back, Style
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.logging import disable_existing_loggers
+import torch.distributed.rpc as rpc
+import torch.multiprocessing as mp
 from colossalai import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.pipeline.pipeline_process_group import ppg
+from torch import nn
+from torch._C._distributed_rpc import _is_current_rpc_agent_set
+from torch.optim import SGD, Adam, Optimizer, RMSprop
 rpc_is_initialized = _is_current_rpc_agent_set
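The alias on the last line gives torch's private agent check a readable name; note the colorama import is dropped here, matching the "[rpc] remove color debug" message above. A typical guarded teardown built on the alias (an assumed usage, since the call sites fall outside this excerpt) would be:

import torch.distributed.rpc as rpc
from torch._C._distributed_rpc import _is_current_rpc_agent_set

rpc_is_initialized = _is_current_rpc_agent_set

if rpc_is_initialized():
    rpc.shutdown()    # only shut down when an RPC agent was actually set up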