Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-22 13:41:43 +00:00)

[FX] refactor experimental tracer and adapt it with hf models (#3157)

* pass gpt trace and meta_prop
* pass t5 trace and meta_prop
* [FX] refactor experimental tracer and adapt it with hf models
* pass all mainstream model zoo
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* skip tests
* fix CI
* using packaging version
* polish

This commit is contained in:
parent b429529365
commit f57d34958b
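A recurring change in this commit is the "using packaging version" item: registrations and op tables are only populated when the installed torch is new enough, with empty fallbacks otherwise. A minimal sketch of that gate, assuming only that `torch` and `packaging` are installed (the table name below is illustrative, not from the diff):

# Hedged sketch: gate optional registrations on the installed torch version.
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    # new enough torch: populate the table with real entries
    SUPPORTED_OPS = [torch.ops.aten.mm.default]
else:
    # older torch: fall back to an empty table so the module still imports cleanly
    SUPPORTED_OPS = []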
@@ -6,11 +6,15 @@
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
+from packaging import version
 from torch.utils._pytree import tree_map
 
 aten = torch.ops.aten
 
-meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+try:
+    meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+except AttributeError:
+    meta_lib = None
 
 meta_table = {}
 
@@ -50,6 +54,7 @@ def register_meta(op, register_dispatcher=True):
     return wrapper
 
 
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
 # ============================== Convolutions ======================================
 # https://github.com/pytorch/pytorch/pull/79834
 @register_meta(aten.convolution.default)
@@ -178,7 +183,6 @@ def meta_conv(
     out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
     return out
 
-
 @register_meta(aten._convolution.default)
 def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
                padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
@@ -186,13 +190,11 @@ def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Ten
     out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
     return out
 
-
 @register_meta(aten.convolution_backward.default)
 def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
                        padding, dilation, transposed, output_padding, groups, output_mask):
     return new_like(input), new_like(weight), new((bias_sizes))
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
 @register_meta(aten._adaptive_avg_pool2d_backward.default)
 def meta_adaptive_avg_pool2d_backward(
@@ -201,7 +203,6 @@ def meta_adaptive_avg_pool2d_backward(
 ):
     return new_like(input)
 
-
 # ================================ RNN =============================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
 @register_meta(aten._cudnn_rnn.default)
@@ -254,7 +255,6 @@ def meta_cuda_rnn(
 
     return output, hy, cy, reserve, weight_buf
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
 @register_meta(aten._cudnn_rnn_backward.default)
 def meta_cudnn_rnn_backward(input: torch.Tensor,
@@ -267,7 +267,6 @@ def meta_cudnn_rnn_backward(input: torch.Tensor,
     return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
         ())    # (grad_input, grad_weight, grad_hx, grad_cx)
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
 # ============================== Activations =======================================
 _unregistered_ewise = [
@@ -280,12 +279,10 @@ _unregistered_ewise = [
     aten.hardtanh_backward.default,
 ]
 
-
 @register_meta(_unregistered_ewise)
 def meta_unregistered_ewise(input: torch.Tensor, *args):
     return new_like(input)
 
-
 # ============================== Normalization =====================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
 @register_meta(aten.native_batch_norm.default)
@@ -293,14 +290,12 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
     n_input = input.size(1)
     return new_like(input), new((n_input)), new((n_input))
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
 @register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
-                     save_invstd, train, eps, output_mask):
+def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
+                     save_mean, save_invstd, train, eps, output_mask):
     return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
 @register_meta(aten.cudnn_batch_norm.default)
 def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
@@ -308,7 +303,6 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
     return new_like(input), new((n_input)), new((n_input)), new(
         (0), dtype=torch.uint8)    # (output, running_mean, running_var, reserve)
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
 # NB: CuDNN only implements the backward algorithm for batchnorm
 # in training mode (evaluation mode batchnorm has a different algorithm),
@@ -318,21 +312,18 @@ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.
                            save_mean, save_invstd, eps, reserve):
     return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
 @register_meta(aten.native_layer_norm.default)
 def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
     bs, n_input = input.size(0), input.size(1)
     return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))    # (output, running_mean, running_var)
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
 @register_meta(aten.native_layer_norm_backward.default)
 def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
                      grad_input_mask):
     return new_like(input), new_like(weight), new_like(bias)    # (dX, dgamma, dbeta)
 
-
 # ================================== Misc ==========================================
 # Maybe incorrect
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp
@@ -340,32 +331,27 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
 def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
     return new_like(input)
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 @register_meta(aten.eye.m_out)
 def meta_eye(n: int, m: int, out: torch.Tensor):
     return out
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 @register_meta(aten.roll.default)
 def meta_roll(input: torch.Tensor, shifts, dims):
     return input
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
 @register_meta(aten._local_scalar_dense.default)
 def meta_local_scalar_dense(self: torch.Tensor):
     return 0
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
 @register_meta(aten.where.self)
 def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
     result_type = torch.result_type(self, other)
     return new_like(condition + self + other, dtype=result_type)
 
-
 @register_meta(aten.index.Tensor)
 def meta_index_Tensor(self, indices):
     assert indices, "at least one index must be provided"
@@ -455,7 +441,6 @@ def meta_index_Tensor(self, indices):
     replacement_shape = list(index.shape)
     return self.new_empty(before_shape + replacement_shape + after_shape)
 
-
 # ============================== Embedding =========================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
 @register_meta(aten.embedding_dense_backward.default)
@@ -466,7 +451,6 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
                            device=grad_output.device,
                            layout=grad_output.layout)
 
-
 # ============================== Dropout ===========================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
 @register_meta(aten.native_dropout.default)
@@ -474,7 +458,6 @@ def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = Fal
     # notice that mask is bool
     return new_like(input), new_like(input, dtype=torch.bool)    # (output, mask)
 
-
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
 @register_meta(aten.native_dropout_backward.default)
 def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
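The hunks above register "meta" implementations: shape-and-dtype-only kernels for aten ops, gated behind the torch 1.12 check. A hedged illustration of what such a registration buys (this snippet uses matmul, whose meta support is built into torch, rather than the conv/RNN/batch-norm ops registered above):

# Hedged illustration (torch >= 1.12): on the 'meta' device, shapes and dtypes
# propagate without allocating real storage -- exactly what the registered
# meta_* functions above provide for ops that lack built-in meta kernels.
import torch

x = torch.empty(64, 128, device='meta')
w = torch.empty(128, 256, device='meta')
y = x @ w
print(y.shape, y.device)    # torch.Size([64, 256]) meta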
@@ -1,5 +1,6 @@
 import torch
 import torch.distributed as dist
+from packaging import version
 
 aten = torch.ops.aten
 
@@ -49,6 +50,7 @@ _DistCommMethod = [
     "scatter",
 ]
 
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
 # TODO: dive deep here
 # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
 _AliasATen = [
@@ -86,3 +88,7 @@ _MaybeInplaceATen = [
     aten.unsqueeze.default,
     aten.as_strided.default,
 ]
+else:
+    _AliasATen = []
+    _InplaceATen = []
+    _MaybeInplaceATen = []
@@ -11,6 +11,7 @@ from numbers import Number
 from typing import Any, Callable, List, Optional, Union
 
 import torch
+from packaging import version
 from torch.utils._pytree import tree_map
 
 from .meta_tensor import MetaTensor
@@ -403,6 +404,7 @@ def zero_flop_jit(*args):
     return 0
 
 
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
 flop_mapping = {
     # gemm
     aten.mm.default: matmul_flop_jit,
@@ -534,3 +536,7 @@ zero_flop_aten = [
 
 for op in zero_flop_aten:
     flop_mapping[op] = zero_flop_jit
+else:
+    flop_mapping = {}
+    elementwise_flop_aten = {}
+    zero_flop_aten = {}
@ -1,4 +1,3 @@
|
|||||||
from .bias_addition import *
|
|
||||||
from .node_util import MetaInfo
|
from .node_util import MetaInfo
|
||||||
from .symbolic_profile import symbolic_profile
|
from .symbolic_profile import symbolic_profile
|
||||||
from .symbolic_trace import symbolic_trace
|
from .tracer.symbolic_trace import symbolic_trace
|
||||||
@@ -1,4 +1,7 @@
+import linecache
 import os
+import sys
+import traceback
 import warnings
 from pathlib import Path
 from typing import Any, Dict, Optional, Union
@@ -6,11 +9,74 @@ from typing import Any, Dict, Optional, Union
 import torch
 import torch.fx
 import torch.nn as nn
-from torch.fx.graph import PythonCode, _PyTreeCodeGen
-from torch.fx.graph_module import _exec_with_source, _forward_from_src, _WrappedCall
+from torch.fx.graph import PythonCode
+
+try:
+    from torch.fx.graph import _PyTreeCodeGen
+    SUPPORT_PT_CODEGEN = True
+except ImportError:
+    SUPPORT_PT_CODEGEN = False
+
+from torch.fx.graph_module import _exec_with_source, _forward_from_src
 from torch.nn.modules.module import _addindent
 
 
+# This is a copy of torch.fx.graph_module._WrappedCall.
+# It should be removed when we stop supporting torch < 1.12.0.
+class _WrappedCall:
+
+    def __init__(self, cls, cls_call):
+        self.cls = cls
+        self.cls_call = cls_call
+
+    # Previously, if an error occurred when valid
+    # symbolically-traced code was run with an invalid input, the
+    # user would see the source of the error as coming from
+    # `File "<eval_with_key_N">`, where N is some number. We use
+    # this function to generate a more informative error message. We
+    # return the traceback itself, a message explaining that the
+    # error occurred in a traced Module's generated forward
+    # function, and five lines of context surrounding the faulty
+    # line
+    @staticmethod
+    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
+        # auxiliary variables (for readability)
+        err_lineno = frame_summary.lineno
+        assert err_lineno is not None
+        line = frame_summary.line
+        assert line is not None
+        err_line_len = len(line)
+        all_src_lines = linecache.getlines(frame_summary.filename)
+
+        # constituent substrings of the error message
+        tb_repr = traceback.format_exc()
+        custom_msg = ("Call using an FX-traced Module, "
+                      f"line {err_lineno} of the traced Module's "
+                      "generated forward function:")
+        before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+        marker = "~" * err_line_len + "~~~ <--- HERE"
+        err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+
+        # joined message
+        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
+
+    def __call__(self, obj, *args, **kwargs):
+        try:
+            if self.cls_call is not None:
+                return self.cls_call(obj, *args, **kwargs)
+            else:
+                return super(self.cls, obj).__call__(*args, **kwargs)    # type: ignore[misc]
+        except Exception as e:
+            assert e.__traceback__
+            topmost_framesummary: traceback.FrameSummary = \
+                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]    # type: ignore[arg-type]
+            if "eval_with_key" in topmost_framesummary.filename:
+                print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
+                raise e.with_traceback(None)
+            else:
+                raise e
+
+
 class ColoGraphModule(torch.fx.GraphModule):
     """
     ColoGraphGraphModule is an nn.Module generated from an fx.Graph.
@@ -65,7 +131,7 @@ class ColoGraphModule(torch.fx.GraphModule):
         called after editing the contained ``graph``, otherwise the generated
         code of this ``GraphModule`` will be out of date.
         """
-        if isinstance(self._graph._codegen, _PyTreeCodeGen):
+        if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec
         python_code = self._graph.python_code(root_module='self')
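The `_PyTreeCodeGen` import above is wrapped in try/except so that `ColoGraphModule` keeps working on torch versions that predate it, and `SUPPORT_PT_CODEGEN` is consulted before touching codegen internals. A minimal sketch of the same feature-detection idiom, assuming nothing beyond torch itself:

# Hedged sketch: probe for an optional, private torch.fx symbol once at import
# time and branch on the flag at use sites.
try:
    from torch.fx.graph import _PyTreeCodeGen    # absent on older torch releases
    SUPPORT_PT_CODEGEN = True
except ImportError:
    _PyTreeCodeGen = None
    SUPPORT_PT_CODEGEN = False

def uses_pytree_codegen(graph) -> bool:
    # Safe on every torch version: the isinstance check only runs when the class exists.
    return SUPPORT_PT_CODEGEN and isinstance(graph._codegen, _PyTreeCodeGen)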
@@ -20,7 +20,7 @@ def union(a, b):
     return {**a, **b}
 
 
-def compute_size_in_bytes(elem: torch.Tensor | Dict | List | Tuple | int) -> int:
+def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Compute the size of a tensor or a collection of tensors in bytes.
 
     Args:
@@ -195,8 +195,8 @@ class MetaInfo:
             s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
         if self.output_size:
            s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
-        if self.total_size:
-            s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
+        # if self.total_size:
+        #     s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
         if self.temp_size:
            s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
         if self.backward_size:
@@ -111,7 +111,24 @@ class ShapeProp(torch.fx.Interpreter):
             with self.global_hook:
                 r = getattr(self, n.op)(n.target, args, kwargs)
 
-        unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
+        def unwrap_fn(elem):
+
+            def _convert_meta(t: torch.Tensor):
+                if t.device == 'meta':
+                    return t
+                else:
+                    return t.to('meta')
+
+            if isinstance(elem, MetaTensor):
+                return _convert_meta(elem._tensor)
+
+            elif isinstance(elem, torch.Tensor):
+                return _convert_meta(elem)
+
+            else:
+                return elem
+
+        # unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
         is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
         n_info = MetaInfo(n)
         n_info.outputs = _normalize_tuple(r)
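The lambda is replaced by an `unwrap_fn` that also moves plain `torch.Tensor` leaves onto the meta device, so shape propagation never touches real storage. A hedged mini-example of how such a function is applied with `tree_map` over nested arguments (the `to_meta` name is illustrative):

# Hedged mini-example: tree_map walks nested containers and converts tensor leaves.
import torch
from torch.utils._pytree import tree_map

def to_meta(elem):
    return elem.to('meta') if isinstance(elem, torch.Tensor) else elem

args = ({'x': torch.randn(2, 3)}, [torch.ones(4)], 'not-a-tensor')
meta_args = tree_map(to_meta, args)
print(meta_args[0]['x'].device)    # meta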
colossalai/_analyzer/fx/tracer/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+from .bias_addition import *
+from .custom_leaf_module import *
@@ -4,11 +4,10 @@ graph construction to deal with the compatibility between bias-addition and all-
 """
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.utils import _pair, _single, _triple
 
-from .symbolic_trace import register_tracer_impl
+from .tracer import register_tracer_impl
 
 __all__ = []
 
colossalai/_analyzer/fx/tracer/custom_leaf_module.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import torch
+
+from .tracer import register_leaf_module, register_leaf_module_impl
+
+try:
+    import apex
+    register_leaf_module(apex.normalization.FusedLayerNorm)
+    register_leaf_module(apex.normalization.FusedRMSNorm)
+    register_leaf_module(apex.normalization.MixedFusedLayerNorm)
+    register_leaf_module(apex.normalization.MixedFusedRMSNorm)
+
+    @register_leaf_module_impl(apex.normalization.FusedLayerNorm)
+    @register_leaf_module_impl(apex.normalization.FusedRMSNorm)
+    @register_leaf_module_impl(apex.normalization.MixedFusedLayerNorm)
+    @register_leaf_module_impl(apex.normalization.MixedFusedRMSNorm)
+    def torch_nn_normalize(self, input: torch.Tensor):
+        # check shape
+        if isinstance(self, torch.nn.BatchNorm1d):
+            assert input.dim() in [2, 3]
+        elif isinstance(self, torch.nn.BatchNorm2d):
+            assert input.dim() == 4
+        elif isinstance(self, torch.nn.BatchNorm3d):
+            assert input.dim() == 5
+
+        # normalization maintain the same shape as the input
+        return input.clone()
+
+except (ImportError, AttributeError):
+    pass
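The new module registers apex's fused normalization layers as tracer leaf modules with a shape-preserving meta implementation. User code can register its own leaf modules the same way; a hedged sketch follows, where `MyFusedNorm` is a hypothetical module and the dotted import path is inferred from the relative import in the file above:

# Hedged sketch: register a custom leaf module and a shape-only impl for tracing.
import torch
import torch.nn as nn

# import path assumed from `from .tracer import ...` above
from colossalai._analyzer.fx.tracer.tracer import register_leaf_module, register_leaf_module_impl

class MyFusedNorm(nn.Module):    # hypothetical user module

    def forward(self, x):
        return x

register_leaf_module(MyFusedNorm)

@register_leaf_module_impl(MyFusedNorm)
def my_fused_norm_meta(self, input: torch.Tensor):
    # the traced graph treats MyFusedNorm as opaque; this only reports the output shape
    return input.clone()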
colossalai/_analyzer/fx/tracer/proxy.py (new file, 112 lines)
@@ -0,0 +1,112 @@
+import operator
+from typing import Any, Callable, Dict, Optional, Set, Union
+
+import torch
+import torch.nn as nn
+from torch.fx import Graph, Node, Proxy, Tracer
+from torch.fx.graph import _Namespace
+from torch.utils._pytree import tree_map
+
+from colossalai._analyzer._subclasses import MetaTensor
+
+Target = Union[Callable[..., Any], str]
+
+
+class ColoProxy(Proxy):
+    _func_dispatch: Dict[Target, Callable[..., Any]] = {}
+
+    def __init__(self, *args, data=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._meta_data = data
+
+    @property
+    def meta_data(self):
+        return self._meta_data
+
+    @meta_data.setter
+    def meta_data(self, args):
+        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
+        self._meta_data = tree_map(wrap_fn, args)
+
+    @classmethod
+    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+        if orig_method in cls._func_dispatch:
+            impl = cls._func_dispatch.pop(orig_method)    # avoid recursion
+            proxy = impl(*args, **kwargs)
+            cls._func_dispatch[orig_method] = impl
+            return proxy
+        else:
+            proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
+            unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
+            if proxy.meta_data is None:
+                proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+            return proxy
+
+    @classmethod
+    def from_torch_proxy(cls, proxy: Proxy):
+        return cls(proxy.node, proxy.tracer)
+
+    def __repr__(self):
+        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
+
+    def __len__(self):
+        return len(self.meta_data)
+
+    def __int__(self):
+        return int(self.meta_data)
+
+    def __index__(self):
+        try:
+            return int(self.meta_data)
+        except:
+            return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
+
+    def __float__(self):
+        return float(self.meta_data)
+
+    def __bool__(self):
+        return self.meta_data
+
+    def __getattr__(self, k):
+        return ColoAttribute(self, k, getattr(self._meta_data, k, None))
+
+    def __setitem__(self, key, value):
+        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+        proxy.meta_data = self._meta_data
+        return proxy
+
+    def __contains__(self, key):
+        if self.node.op == "placeholder":
+            # this is used to handle like
+            # if x in kwargs
+            # we don't handle this case for now
+            return False
+        return super().__contains__(key)
+
+    def __isinstancecheck__(self, type):
+        return isinstance(self.meta_data, type)
+
+
+class ColoAttribute(ColoProxy):
+
+    def __init__(self, root, attr: str, data=None):
+        self.root = root
+        self.attr = attr
+        self.tracer = root.tracer
+        self._meta_data = data
+        self._node: Optional[Node] = None
+
+    @property
+    def node(self):
+        # the node for attributes is added lazily, since most will just be method calls
+        # which do not rely on the getitem call
+        if self._node is None:
+            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+        return self._node
+
+    def __call__(self, *args, **kwargs):
+        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+
+    def __repr__(self):
+        return f"ColoAttribute({self.node.name}, attr={self.attr})"
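`ColoProxy.__torch_function__` above avoids infinite recursion by popping a registered implementation before calling it and putting it back afterwards, so that a nested call to the same function falls through to the default path. A hedged, standalone sketch of that pop/call/restore pattern (names are illustrative; the real code restores the handler by reassignment rather than try/finally):

# Hedged sketch of the "pop to avoid recursion" dispatch pattern.
_dispatch = {}

def register_impl(fn, impl):
    _dispatch[fn] = impl

def dispatch(fn, *args, **kwargs):
    if fn in _dispatch:
        impl = _dispatch.pop(fn)      # remove first so a nested call falls back to `fn` itself
        try:
            return impl(*args, **kwargs)
        finally:
            _dispatch[fn] = impl      # restore the handler afterwards
    return fn(*args, **kwargs)        # default path when nothing is registered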
colossalai/_analyzer/fx/tracer/symbolic_trace.py (new file, 157 lines)
@@ -0,0 +1,157 @@
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+
+import torch
+from torch.fx import Tracer
+from torch.utils._pytree import tree_map
+
+from colossalai._analyzer._subclasses import MetaTensor
+
+try:
+    from ..codegen import ActivationCheckpointCodeGen
+    SUPPORT_ACTIVATION = True
+except:
+    SUPPORT_ACTIVATION = False
+from ..graph_module import ColoGraphModule
+from .tracer import ColoTracer
+
+
+def _default_device():
+    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+
+def _current_device(module: torch.nn.Module):
+    try:
+        return next(module.parameters()).device
+    except:
+        return _default_device()
+
+
+def symbolic_trace(
+    root: Union[torch.nn.Module, Callable[..., Any]],
+    concrete_args: Optional[Dict[str, Any]] = None,
+    meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt: bool = False,
+    bias_addition_split: bool = False,
+) -> ColoGraphModule:
+    """
+    Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
+    attached to the ``Node``s.
+
+    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
+    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).
+
+    This tracer is able to trace basic control flow and for loops.
+
+    It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``.
+    (See ./bias_addition.py for more details).
+
+    Examples:
+    1. Tracing a ``torch.nn.Module`` with control flow.
+
+    .. code-block:: python
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                if x.size(0) > 1:
+                    x = x.sum(dim=0)
+                return self.linear(x)
+
+        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
+
+        # traced code like:
+        # def forward(self, x):
+        #     linear_1 = self.linear(x)
+        #     return linear_1
+
+        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
+
+        # traced code like:
+        # def forward(self, x):
+        #     sum = x.sum(dim=0); x = None
+        #     linear = self.linear(sum); sum = None
+        #     return linear
+
+    2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
+
+    .. code-block:: python
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                def custom_forward(x):
+                    return self.linear(x)
+                return torch.utils.checkpoint.checkpoint(custom_forward, x)
+
+        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
+
+        # traced code like:
+        # def checkpoint_0(self, x):
+        #     linear = self.linear(x); x = None
+        #     return linear
+        #
+        # def forward(self, x):
+        #     linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
+        #     return linear
+
+    3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
+
+    .. code-block:: python
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2, bias=True)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
+
+        # traced code like:
+        # def forward(self, x):
+        #     linear_bias = self.linear.bias
+        #     linear_weight = self.linear.weight
+        #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
+        #     add = linear + linear_bias; linear = linear_bias = None
+        #     return add
+
+    Args:
+        root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
+        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
+            Defaults to {}.
+        meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
+            for tracing control flow. Defaults to {}.
+        trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
+            Defaults to False.
+        bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
+
+    Returns:
+        ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
+
+    Remarks:
+        This part of ``symbolic_trace()`` is maintained by Colossal-AI team. If you encountered
+        any unexpected error during tracing, feel free to raise an issue on Colossal-AI GitHub
+        repo. We welcome any feedback and contributions to enhance the extensibility of
+        Colossal-AI.
+    """
+    if meta_args:
+        device, orig_device = _default_device(), _current_device(root)
+        wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
+        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
+                           bias_addition_split=bias_addition_split).trace(root.to(device),
+                                                                          concrete_args=concrete_args,
+                                                                          meta_args=tree_map(wrap_fn, meta_args))
+        if trace_act_ckpt and SUPPORT_ACTIVATION:
+            graph.set_codegen(ActivationCheckpointCodeGen())
+        root.to(orig_device)
+    else:
+        graph = Tracer().trace(root, concrete_args=concrete_args)
+    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
+    return ColoGraphModule(root, graph, name)
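Since the point of this commit is that the refactored tracer now handles mainstream HuggingFace models, a hedged usage sketch follows. It assumes `transformers` is installed; GPT-2 is chosen only because the commit message mentions GPT and T5 traces, and whether this exact snippet runs depends on the transformers version:

# Hedged usage sketch, not an official ColossalAI example.
import torch
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai._analyzer.fx import symbolic_trace

model = GPT2LMHeadModel(GPT2Config(n_layer=2))
meta_args = {'input_ids': torch.zeros(1, 16, dtype=torch.long)}

gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=True)
print(gm.graph)    # the traced forward, with bias additions split into separate nodes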
@@ -1,28 +1,19 @@
 import functools
 import inspect
-import operator
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 from torch.fx import Graph, Node, Proxy, Tracer
-from torch.fx.graph import _Namespace
 from torch.utils._pytree import tree_map
 
-from colossalai._analyzer._subclasses import MetaTensor, _TensorPropertyMethod, _TorchFactoryMethod
+from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod
 
-from .codegen import ActivationCheckpointCodeGen
-from .graph_module import ColoGraphModule
-from .node_util import MetaInfo
+from ..node_util import MetaInfo
+from .proxy import ColoProxy
 
 Target = Union[Callable[..., Any], str]
-Argument = Optional[Union[Tuple[Any, ...],    # actually Argument, but mypy can't represent recursive types
-                          List[Any],    # actually Argument
-                          Dict[str, Any],    # actually Argument
-                          slice,    # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
-                          'Node',]]
-zeros = torch.zeros
 
 
 def _truncate_suffix(s: str):
@@ -32,17 +23,6 @@ def _truncate_suffix(s: str):
     return re.sub(r'_\d+$', '', s)
 
 
-def _default_device():
-    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-
-
-def _current_device(module):
-    try:
-        return next(module.parameters()).device
-    except:
-        return _default_device()
-
-
 def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
 
     def wrapper(impl):
@@ -70,149 +50,6 @@ def register_non_leaf_module(module: nn.Module):
     ColoTracer._custom_non_leaf_module.add(module)
 
 
-class ColoProxy(Proxy):
-    _func_dispatch: Dict[Target, Callable[..., Any]] = {}
-
-    def __init__(self, *args, data=None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._meta_data = data
-
-    @property
-    def meta_data(self):
-        return self._meta_data
-
-    @meta_data.setter
-    def meta_data(self, args):
-        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
-        self._meta_data = tree_map(wrap_fn, args)
-
-    @classmethod
-    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
-        kwargs = {} if kwargs is None else kwargs
-        if orig_method in cls._func_dispatch:
-            impl = cls._func_dispatch.pop(orig_method)    # avoid recursion
-            proxy = impl(*args, **kwargs)
-            cls._func_dispatch[orig_method] = impl
-            return proxy
-        else:
-            proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
-            unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
-            if proxy.meta_data is None:
-                proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
-            return proxy
-
-    @classmethod
-    def from_torch_proxy(cls, proxy: Proxy):
-        return cls(proxy.node, proxy.tracer)
-
-    def __repr__(self):
-        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
-
-    def __len__(self):
-        return len(self.meta_data)
-
-    def __int__(self):
-        return int(self.meta_data)
-
-    def __index__(self):
-        try:
-            return int(self.meta_data)
-        except:
-            return zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
-
-    def __float__(self):
-        return float(self.meta_data)
-
-    def __bool__(self):
-        return self.meta_data
-
-    def __getattr__(self, k):
-        return ColoAttribute(self, k, getattr(self._meta_data, k, None))
-
-    def __setitem__(self, key, value):
-        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
-        proxy.meta_data = self._meta_data
-        return proxy
-
-    def __contains__(self, key):
-        if self.node.op == "placeholder":
-            # this is used to handle like
-            # if x in kwargs
-            # we don't handle this case for now
-            return False
-        return super().__contains__(key)
-
-    def __isinstancecheck__(self, type):
-        return isinstance(self.meta_data, type)
-
-    def size(self, dim=None):
-        if self._meta_data is None:
-            return self._meta_data.size(*[dim] if dim else [])
-        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
-
-    def dim(self):
-        if self._meta_data is not None:
-            return self._meta_data.dim()
-        return self.tracer.create_proxy('call_method', 'dim', (self,), {})
-
-    @property
-    def shape(self):
-        if self._meta_data is not None:
-            return self._meta_data.shape
-        return self.tracer.create_proxy('call_function', getattr, (self, 'shape'), {})
-
-    @property
-    def ndim(self):
-        if self._meta_data is not None:
-            return self._meta_data.ndim
-        return self.tracer.create_proxy('call_function', getattr, (self, 'ndim'), {})
-
-    @property
-    def device(self):
-        if self._meta_data is not None:
-            return self._meta_data.device
-        return self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
-
-    @property
-    def dtype(self):
-        if self._meta_data is not None:
-            return self._meta_data.dtype
-        return self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
-
-    def to(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
-
-    def cpu(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
-
-    def cuda(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
-
-
-class ColoAttribute(ColoProxy):
-
-    def __init__(self, root, attr: str, data=None):
-        self.root = root
-        self.attr = attr
-        self.tracer = root.tracer
-        self._meta_data = data
-        self._node: Optional[Node] = None
-
-    @property
-    def node(self):
-        # the node for attributes is added lazily, since most will just be method calls
-        # which do not rely on the getitem call
-        if self._node is None:
-            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
-        return self._node
-
-    def __call__(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
-
-    def __repr__(self):
-        return f"ColoAttribute({self.node.name}, attr={self.attr})"
-
-
 class ColoTracer(Tracer):
     _custom_leaf_module: Set[Type[nn.Module]] = set()
     _custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}
@@ -249,7 +86,6 @@ class ColoTracer(Tracer):
         # we will enter the module and split the bias-addition ops
         if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
             return False
-
         # user can specify which modules are leaf modules and which are not
         return (type(m) not in self._custom_non_leaf_module
                 and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
@@ -306,9 +142,13 @@ class ColoTracer(Tracer):
             mod = self.root.get_submodule(target)
             self.disable_module_getattr = True
             try:
-                proxy.meta_data = self._custom_leaf_module_impl.get(type(mod),
-                                                                    mod.forward)(*tree_map(unwrap_fn, args),
-                                                                                 **tree_map(unwrap_fn, kwargs))
+                args = tree_map(unwrap_fn, args)
+                kwargs = tree_map(unwrap_fn, kwargs)
+                if type(mod) in self._custom_leaf_module:
+                    target = self._custom_leaf_module_impl[type(mod)]
+                    proxy.meta_data = target(mod, *args, **kwargs)
+                else:
+                    proxy.meta_data = mod.forward(*args, **kwargs)
             finally:
                 self.disable_module_getattr = False
             return proxy
@@ -320,15 +160,21 @@ class ColoTracer(Tracer):
 
     def trace(self,
               root: torch.nn.Module,
-              concrete_args: Optional[Dict[str, torch.Tensor]] = {},
-              meta_args: Optional[Dict[str, torch.Tensor]] = {}) -> Graph:
+              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+              meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
+
+        if meta_args is None:
+            meta_args = {}
+
+        if concrete_args is None:
+            concrete_args = {}
+
         # check concrete and meta args have valid names
         sig = inspect.signature(root.forward)
         sig_names = set(sig.parameters.keys())
         meta_arg_names = set(meta_args.keys())
         concrete_arg_names = set(concrete_args.keys())
+        non_concrete_arg_names = sig_names - concrete_arg_names
         # update concrete args with default values
         for k, v in sig.parameters.items():
             if k in sig_names - meta_arg_names and \
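The signature change above also removes a classic pitfall: the old defaults `= {}` are evaluated once and shared across every call to `trace()`, so state written into them leaks between traces. A tiny hedged demonstration of the difference:

# Hedged demonstration of the mutable-default-argument pitfall the change above removes.
def buggy(args={}):          # the same dict object is reused for every call
    args.setdefault('seen', 0)
    args['seen'] += 1
    return args['seen']

def fixed(args=None):        # allocate a fresh dict per call instead
    if args is None:
        args = {}
    args.setdefault('seen', 0)
    args['seen'] += 1
    return args['seen']

print(buggy(), buggy())      # 1 2  -- state leaks between calls
print(fixed(), fixed())      # 1 1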
@@ -352,6 +198,34 @@ class ColoTracer(Tracer):
         self.graph = super().trace(root, concrete_args=concrete_args)
         self.mod_dir = ''
         self.graph.lint()
+
+        for node in self.graph.nodes:
+            if node.op == "placeholder":
+                # Removing default values for inputs as the forward pass will fail with them.
+                if node.target in non_concrete_arg_names:
+                    node.args = ()
+                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
+                    # It cannot infer on the attributes and methods the input should have, and fails.
+                    node.type = torch.Tensor
+                # It is a concrete arg so it is not used and should be removed.
+                else:
+                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
+                        # Newer versions of torch.fx emit an assert statement
+                        # for concrete arguments; delete those before we delete
+                        # the concrete arg.
+                        to_delete = []
+                        for user in node.users:
+                            if user.target == torch.fx._symbolic_trace._assert_is_none:
+                                to_delete.append(user)
+                        for user in to_delete:
+                            self.graph.erase_node(user)
+
+                    self.graph.erase_node(node)
+
+            # TODO: solves GraphModule creation.
+            # Without this, return type annotation "Tuple" is causing code execution failure.
+            if node.op == "output":
+                node.type = None
         return self.graph
 
     @contextmanager
@ -487,134 +361,3 @@ class ColoTracer(Tracer):
|
|||||||
return maybe_parameter_proxy
|
return maybe_parameter_proxy
|
||||||
|
|
||||||
return attr_val
|
return attr_val
|
||||||
|
|
||||||
|
|
||||||
-def symbolic_trace(
-    root: Union[torch.nn.Module, Callable[..., Any]],
-    concrete_args: Optional[Dict[str, Any]] = {},
-    meta_args: Optional[Dict[str, Any]] = {},
-    trace_act_ckpt: bool = False,
-    bias_addition_split: bool = False,
-) -> ColoGraphModule:
-    """
-    Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
-    attached to the ``Node``s.
-
-    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
-    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).
-
-    This tracer is able to trace basic control flow and for loops.
-
-    It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``.
-    (See ./bias_addition.py for more details).
-
-    Examples:
-    1. Tracing a ``torch.nn.Module`` with control flow.
-
-    .. code-block:: python
-
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(2, 2)
-
-            def forward(self, x):
-                if x.size(0) > 1:
-                    x = x.sum(dim=0)
-                return self.linear(x)
-
-        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
-
-        # traced code like:
-        # def forward(self, x):
-        #     linear_1 = self.linear(x)
-        #     return linear_1
-
-        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
-
-        # traced code like:
-        # def forward(self, x):
-        #     sum = x.sum(dim=0); x = None
-        #     linear = self.linear(sum); sum = None
-        #     return linear
-
-    2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
-
-    .. code-block:: python
-
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(2, 2)
-
-            def forward(self, x):
-                def custom_forward(x):
-                    return self.linear(x)
-                return torch.utils.checkpoint.checkpoint(custom_forward, x)
-
-        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
-
-        # traced code like:
-        # def checkpoint_0(self, x):
-        #     linear = self.linear(x); x = None
-        #     return linear
-        #
-        # def forward(self, x):
-        #     linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
-        #     return linear
-
-    3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
-
-    .. code-block:: python
-
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(2, 2, bias=True)
-
-            def forward(self, x):
-                return self.linear(x)
-
-        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
-
-        # traced code like:
-        # def forward(self, x):
-        #     linear_bias = self.linear.bias
-        #     linear_weight = self.linear.weight
-        #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
-        #     add = linear + linear_bias; linear = linear_bias = None
-        #     return add
-
-    Args:
-        root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
-        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
-            Defaults to {}.
-        meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
-            for tracing control flow. Defaults to {}.
-        trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
-            Defaults to False.
-        bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
-
-    Returns:
-        ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
-
-    Remarks:
-        This part of ``symbolic_trace()`` is maintained by Colossal-AI team. If you encountered
-        any unexpected error during tracing, feel free to raise an issue on Colossal-AI GitHub
-        repo. We welcome any feedback and contributions to enhance the extensibility of
-        Colossal-AI.
-    """
-    if meta_args:
-        device, orig_device = _default_device(), _current_device(root)
-        wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
-        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
-                           bias_addition_split=bias_addition_split).trace(root.to(device),
-                                                                          concrete_args=concrete_args,
-                                                                          meta_args=tree_map(wrap_fn, meta_args))
-        if trace_act_ckpt:
-            graph.set_codegen(ActivationCheckpointCodeGen())
-        root.to(orig_device)
-    else:
-        graph = Tracer().trace(root, concrete_args=concrete_args)
-    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
-    return ColoGraphModule(root, graph, name)
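For reference, the `if meta_args` branch of the function being removed above wraps every tensor in the meta-argument pytree before tracing so that no real data is touched. `MetaTensor` is Colossal-AI's own wrapper; the same `tree_map` pattern can be sketched with plain meta tensors (an illustrative assumption, not the shipped implementation):

import torch
from torch.utils._pytree import tree_map

meta_args = {'x': torch.randn(2, 3), 'mask': None}
wrap_fn = lambda elem: elem.to('meta') if isinstance(elem, torch.Tensor) else elem
meta_args = tree_map(wrap_fn, meta_args)
print(meta_args['x'].device)    # meta: shapes and dtypes are kept, storage is never allocated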
@ -1,5 +1,4 @@
from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers

from .registry import model_zoo

__all__ = ['model_zoo']
@ -17,6 +17,14 @@ def data_gen():
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


+def seq_classification_data_gen():
+    # batch sizes should be 1 if no padding token is defined.
+    input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
output_transform_fn = lambda x: x

config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
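The batch-size-1 restriction in `seq_classification_data_gen` comes from Hugging Face's `GPT2ForSequenceClassification`, which rejects batches larger than one unless a padding token is defined (it needs the pad token to locate the last real token of each sequence). A hypothetical alternative, not used in this commit, is to give the config a `pad_token_id` and keep a larger batch:

import torch
import transformers

# Same config as above, plus a pad token so batch sizes > 1 are accepted.
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4, pad_token_id=0)
model = transformers.GPT2ForSequenceClassification(config)
batch = dict(
    input_ids=torch.zeros((2, 16), dtype=torch.int64),
    attention_mask=torch.ones((2, 16), dtype=torch.int64),
)
outputs = model(**batch)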
@ -44,6 +52,6 @@ model_zoo.register(name='transformers_gpt_for_token_classification',
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_sequence_classification',
                   model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
-                  data_gen_fn=data_gen,
+                  data_gen_fn=seq_classification_data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
@ -1,5 +1,6 @@
import pytest
import torch
+from packaging import version
from torch.utils.checkpoint import checkpoint

try:
@ -73,7 +74,7 @@ class AddmmModel(torch.nn.Module):
        return x


-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
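The decorator change above is not cosmetic: comparing version strings lexicographically mis-orders releases, while `packaging.version` compares release components numerically and also copes with suffixes such as `'1.12.0+cu113'`. A quick illustration:

import torch
from packaging import version

print('1.9.0' < '1.12.0')                                 # False: '9' > '1' as characters
print(version.parse('1.9.0') < version.parse('1.12.0'))   # True: numeric comparison
print(version.parse(torch.__version__) >= version.parse('1.12.0'))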
@ -3,7 +3,8 @@ from numpy import isin
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten

-from colossalai.fx import symbolic_trace
+# from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace


def trace_model_and_compare_output(model, data_gen):
@ -1,4 +1,7 @@
+import pytest
+import torch
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

from tests.kit.model_zoo import model_zoo
@ -6,6 +9,7 @@ BATCH_SIZE = 2
SEQ_LENGTH = 16


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_albert():
    sub_registry = model_zoo.get_sub_registry('transformers_albert')
@ -1,8 +1,12 @@
+import pytest
+import torch
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

from tests.kit.model_zoo import model_zoo


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_bert():
    sub_registry = model_zoo.get_sub_registry('transformers_bert')
@ -1,16 +1,24 @@
import pytest
+import torch
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

from tests.kit.model_zoo import model_zoo


-# TODO: remove this skip once we handle the latest gpt model
-@pytest.mark.skip
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_gpt():
    sub_registry = model_zoo.get_sub_registry('transformers_gpt')

    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
        model = model_fn()

+        # TODO: support the following models
+        # 1. GPT2DoubleHeadsModel
+        # as they are not supported, let's skip them
+        if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
+            continue

        trace_model_and_compare_output(model, data_gen_fn)
@ -1,8 +1,12 @@
+import pytest
+import torch
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

from tests.kit.model_zoo import model_zoo


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_opt():
    sub_registry = model_zoo.get_sub_registry('transformers_opt')
@ -1,8 +1,12 @@
+import pytest
+import torch
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

from tests.kit.model_zoo import model_zoo


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_t5():
    sub_registry = model_zoo.get_sub_registry('transformers_t5')
@ -1,8 +1,8 @@
import pytest
-import timm.models as tm
import torch
+from packaging import version

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
@ -42,6 +42,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
            f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_timm_models():
    torch.backends.cudnn.deterministic = True
@ -1,20 +1,18 @@
-import re
+import pytest

import torch
+from packaging import version
from torchaudio_utils import trace_and_compare

from tests.kit.model_zoo import model_zoo


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_torchaudio_models():
    torch.backends.cudnn.deterministic = True

    sub_model_zoo = model_zoo.get_sub_registry('torchaudio')

    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
-        # FIXME(ver217): temporarily skip these models
-        if re.search(f'(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
-            continue
        model = model_fn()
        trace_and_compare(model,
                          data_gen_fn,
@ -1,6 +1,6 @@
import torch

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace


def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):
@ -1,7 +1,7 @@
import pytest
import torch

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo

BATCH = 2
@ -1,7 +1,7 @@
import pytest
import torch

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo

BATCH = 2
@ -1,6 +1,6 @@
import torch

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo