Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 13:00:52 +00:00
[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)
* [fx] move meta registration
* [fx] fix tests.
* [fx] fix test.
* [fx] fix.
* [meta] refactor meta registration.py.
* [fx] add compatibility descriptions.
* [fx] polish import.
* [fx] add a decorator.
* [fx] fix tests.
* [fx] remove print.
* [fx] edit raise error.
* [fx] edit raise error.
* [fx] add type hint.
* [fx] fix import in experimental.
* [rpc] remove color debug.
* [meta] fix naming.
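The recurring change across the hunks below swaps the module-level META_COMPATIBILITY flag for a callable check, is_compatible_with_meta(), imported from colossalai.fx._compatibility. As a rough, hedged sketch of what that module could contain — the function name comes from the diff, but the version test and the compatibility decorator (the log only says "[fx] add a decorator" and "[fx] edit raise error") are assumptions inferred from the tests' skip reason, 'torch version is lower than 1.12.0':

# Sketch only, with assumptions noted inline; this is not the actual
# colossalai/fx/_compatibility.py.
from functools import wraps

import torch
from packaging import version


def is_compatible_with_meta() -> bool:
    # Assumption: the skipif reasons in the tests below suggest meta-tensor
    # support is gated on torch >= 1.12.0.
    return version.parse(torch.__version__.split('+')[0]) >= version.parse('1.12.0')


def compatibility(is_backward_compatible: bool = False):
    # Hypothetical decorator; its name, signature, and raise behavior are
    # guesses from the commit log, not the library's confirmed API.
    def decorator(func):

        @wraps(func)
        def wrapper(*args, **kwargs):
            if not is_backward_compatible and not is_compatible_with_meta():
                raise RuntimeError(f'{func.__name__} requires torch >= 1.12.0')
            return func(*args, **kwargs)

        return wrapper

    return decorator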
@@ -1,18 +1,20 @@
import copy
import colossalai
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
import torchvision.models as tm
import torch.fx
import colossalai
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.core import global_context as gpc
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.algorithms import solver_rotor
from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.core import global_context as gpc
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor

try:
@@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0):
    graph = tracer.trace(model, meta_args={"x": data})
    graph.set_codegen(ActivationCheckpointCodeGen())
    gm = ColoGraphModule(model, graph, model.__class__.__name__)
    if META_COMPATIBILITY:
    if is_compatible_with_meta():
        data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
        MetaInfoProp(gm).run(data_meta)
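The two hunks above establish the pattern this commit applies across the test suite: import the check, guard the MetaTensor import behind it, and wrap the sample input before running MetaInfoProp. A condensed sketch, assuming the colossalai.fx APIs behave exactly as the tests use them (profile_if_possible is a made-up helper name, not part of the codebase):

import torch

from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor


def profile_if_possible(gm, data: torch.Tensor):
    # Hypothetical helper: skip meta profiling on older torch; otherwise
    # mirror the module's real parameter device into the MetaTensor's
    # fake device, as the hunk above does.
    if is_compatible_with_meta():
        data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
        MetaInfoProp(gm).run(data_meta)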
@@ -1,20 +1,22 @@
from typing import Callable
import copy
import re
from typing import Callable

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from torch.fx import GraphModule

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor

try:
@@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule):
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
                               model_cls: Callable[[], torch.nn.Module]):
    criterion = torch.nn.MSELoss()
    data = torch.rand(2, 3, 32, 32)
    label = torch.rand(2, 5)
    m.cuda()
    data = torch.rand(2, 3, 32, 32).cuda()
    label = torch.rand(2, 5).cuda()
    loss = criterion(m(data), label)
    loss.backward()
    loss = criterion(gm(data), label)
@@ -77,7 +80,7 @@ def _run_ckpt_solver(rank):
    m = model_cls(num_classes=5)
    graph = tracer.trace(root=m)
    gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
    MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
    MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())
    codegen = ActivationCheckpointCodeGen()
    gm.graph.set_codegen(codegen)
    if solver == solver_rotor:
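Note the API shift in the hunk above: MetaTensor(data, fake_device='cuda') becomes MetaTensor(data).cuda(), which suggests the wrapper's fake device can now be relocated with the ordinary .cuda() call. A minimal sketch of the two spellings, assuming that reading of the rewrite is correct:

import torch

from colossalai.fx._compatibility import is_compatible_with_meta

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor

    data = torch.rand(2, 3, 32, 32)
    # Old spelling: the fake device is fixed at construction time.
    meta_old = MetaTensor(data, fake_device='cuda')
    # New spelling in this commit: construct, then move the fake device.
    meta_new = MetaTensor(data).cuda()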
@@ -1,13 +1,14 @@
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
import pytest
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor, linearize
from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
from colossalai.fx.passes.algorithms import linearize, solver_rotor
from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor

try:
@@ -1,13 +1,17 @@
import torch
import torch.nn as nn
import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
from colossalai.fx.passes.utils import get_comm_size
from colossalai import META_COMPATIBILITY
import pytest
import torch
import torch.nn as nn
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.utils import get_comm_size
from torch.fx import symbolic_trace

is_compatible = is_compatible_with_meta()
if is_compatible:
    from colossalai.fx.profiler import MetaTensor

MODEL_DIM = 16
BATCH_SIZE = 8
@@ -31,12 +35,12 @@ class MLP(torch.nn.Module):
        return x


@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_comm_size_compute():
    from colossalai.fx.profiler import MetaTensor
    model = MLP(MODEL_DIM)
    input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
    gm = symbolic_trace(model)
    if is_compatible:
        input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
    MetaInfoProp(gm).run(input_sample)
    annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
@@ -1,12 +1,11 @@
from typing import Any, Callable, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai import META_COMPATIBILITY

import pytest
import torch
import torch.nn as nn
from colossalai.fx._compatibility import is_compatible_with_meta

if META_COMPATIBILITY:
if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

aten = torch.ops.aten
@@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
        compare_all(x.grad, meta_x.grad)


@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_meta_aten():
    for (aten_op, requires_backward), v in registered_meta.items():
        for f, x in v:
@@ -1,10 +1,10 @@
import torchvision.models as tm
import pytest
import timm.models as tmm
import torch
from colossalai import META_COMPATIBILITY
import pytest
import torchvision.models as tm
from colossalai.fx._compatibility import is_compatible_with_meta

if META_COMPATIBILITY:
if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

tm_models = [
@@ -27,7 +27,7 @@ tmm_models = [
]


@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models():
    for m in tm_models:
        model = m()
@@ -35,7 +35,7 @@ def test_torchvision_models():
        model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()


@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models():
    for m in tmm_models:
        model = m()
@@ -1,10 +1,10 @@
import torchvision.models as tm
import pytest
import timm.models as tmm
import torch
from colossalai import META_COMPATIBILITY
import pytest
import torchvision.models as tm
from colossalai.fx._compatibility import is_compatible_with_meta

if META_COMPATIBILITY:
if is_compatible_with_meta():
    from colossalai.fx import meta_trace

tm_models = [
@@ -27,7 +27,7 @@ tmm_models = [
]


@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models_trace():
    for m in tm_models:
        model = m()
@@ -35,7 +35,7 @@ def test_torchvision_models_trace():
        graph = meta_trace(model, torch.device('cpu'), data)


@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models_trace():
    for m in tmm_models:
        model = m()
@@ -1,7 +1,10 @@
import torch
from torch.fx import symbolic_trace
from colossalai import META_COMPATIBILITY
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
from torch.fx import symbolic_trace

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

BATCH_SIZE = 2
DIM_IN = 4
@@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
def test_meta_info_prop():
    model = torch.nn.Linear(DIM_IN, DIM_OUT)
    input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
    if META_COMPATIBILITY:
        from colossalai.fx.profiler import MetaTensor
    if is_compatible_with_meta():
        input_sample = MetaTensor(input_sample, fake_device='cpu')
    orig_output = model(input_sample)
    gm = symbolic_trace(model)
@@ -1,19 +1,17 @@
import os
import argparse
import os
import warnings

import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch.optim import SGD, Adam, RMSprop, Optimizer
from torch._C._distributed_rpc import _is_current_rpc_agent_set
import torch.distributed as dist
from colorama import Back, Style

from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.logging import disable_existing_loggers
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
from torch import nn
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.optim import SGD, Adam, Optimizer, RMSprop

rpc_is_initialized = _is_current_rpc_agent_set