mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[fx] added torchvision model tracing testing (#1216)
* [fx] added torchvision model tracing testing * remove unused imports
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from curses import meta
|
||||
import operator
|
||||
import torch
|
||||
from .registry import meta_patched_function
|
||||
@@ -142,3 +141,40 @@ def torch_bmm(input, mat2, *, out=None):
|
||||
batch_size, n, m = input.shape
|
||||
_, _, p = mat2.shape
|
||||
return torch.empty(batch_size, n, p, device="meta")
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.squeeze)
|
||||
def torch_squeeze(input, dim=None):
|
||||
shape = list(input.shape)
|
||||
if dim is not None:
|
||||
if dim < 0:
|
||||
dim = input.dim() + dim
|
||||
if shape[dim] == 1:
|
||||
shape.pop(dim)
|
||||
else:
|
||||
new_shape = []
|
||||
for dim_value in shape:
|
||||
if dim_value == 1:
|
||||
continue
|
||||
new_shape.append(dim_value)
|
||||
shape = new_shape
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.Tensor.squeeze)
|
||||
def torch_tensor_squeeze(self, dim=None):
|
||||
return torch_squeeze(self, dim)
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.unsqueeze)
|
||||
def torch_unsqueeze(input, dim):
|
||||
shape = list(input.shape)
|
||||
if dim < 0:
|
||||
dim = input.dim() + 1 + dim
|
||||
shape.insert(dim, 1)
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.Tensor.unsqueeze)
|
||||
def torch_tensor_unsqueeze(self, dim):
|
||||
return torch_unsqueeze(self, dim)
|
||||
|
@@ -88,6 +88,137 @@ def torch_nn_conv3d(self, input):
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.AvgPool1d)
|
||||
def torch_nn_avgpool1d(self, input):
|
||||
num_dim = input.dim()
|
||||
assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
|
||||
|
||||
l_in = input.shape[-1]
|
||||
|
||||
def _convert_int_to_list(item):
|
||||
if isinstance(item, int):
|
||||
return [item] * 1
|
||||
else:
|
||||
return item
|
||||
|
||||
padding = _convert_int_to_list(self.padding)
|
||||
kernel_size = _convert_int_to_list(self.kernel_size)
|
||||
stride = _convert_int_to_list(self.stride)
|
||||
|
||||
l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
|
||||
|
||||
result_shape = input.shape[:-1] + (l_out,)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.AvgPool2d)
|
||||
def torch_nn_avgpool2d(self, input):
|
||||
num_dim = input.dim()
|
||||
assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
|
||||
|
||||
h_in, w_in = input.shape[-2:]
|
||||
|
||||
def _convert_int_to_list(item):
|
||||
if isinstance(item, int):
|
||||
return [item] * 2
|
||||
else:
|
||||
return item
|
||||
|
||||
padding = _convert_int_to_list(self.padding)
|
||||
kernel_size = _convert_int_to_list(self.kernel_size)
|
||||
stride = _convert_int_to_list(self.stride)
|
||||
|
||||
h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
|
||||
w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
|
||||
|
||||
result_shape = input.shape[:-2] + (
|
||||
h_out,
|
||||
w_out,
|
||||
)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.AvgPool3d)
|
||||
def torch_nn_avgpool3d(self, input):
|
||||
num_dim = input.dim()
|
||||
assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'
|
||||
|
||||
d_in, h_in, w_in = input.shape[-3:]
|
||||
|
||||
def _convert_int_to_list(item):
|
||||
if isinstance(item, int):
|
||||
return [item] * 3
|
||||
else:
|
||||
return item
|
||||
|
||||
padding = _convert_int_to_list(self.padding)
|
||||
kernel_size = _convert_int_to_list(self.kernel_size)
|
||||
stride = _convert_int_to_list(self.stride)
|
||||
|
||||
d_out = math.floor((d_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
|
||||
h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
|
||||
w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1)
|
||||
|
||||
result_shape = input.shape[:-3] + (
|
||||
d_out,
|
||||
h_out,
|
||||
w_out,
|
||||
)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.MaxPool1d)
|
||||
def torch_nn_maxpool1d(self, input):
|
||||
num_dim = input.dim()
|
||||
assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
|
||||
|
||||
l_in = input.shape[-1]
|
||||
|
||||
def _convert_int_to_list(item):
|
||||
if isinstance(item, int):
|
||||
return [item] * 1
|
||||
else:
|
||||
return item
|
||||
|
||||
padding = _convert_int_to_list(self.padding)
|
||||
dilation = _convert_int_to_list(self.dilation)
|
||||
kernel_size = _convert_int_to_list(self.kernel_size)
|
||||
stride = _convert_int_to_list(self.stride)
|
||||
|
||||
l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
|
||||
|
||||
result_shape = input.shape[:-1] + (l_out,)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.MaxPool2d)
|
||||
def torch_nn_maxpool2d(self, input):
|
||||
num_dim = input.dim()
|
||||
assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
|
||||
|
||||
h_in, w_in = input.shape[-2:]
|
||||
|
||||
def _convert_int_to_list(item):
|
||||
if isinstance(item, int):
|
||||
return [item] * 2
|
||||
else:
|
||||
return item
|
||||
|
||||
padding = _convert_int_to_list(self.padding)
|
||||
dilation = _convert_int_to_list(self.dilation)
|
||||
kernel_size = _convert_int_to_list(self.kernel_size)
|
||||
stride = _convert_int_to_list(self.stride)
|
||||
|
||||
h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
|
||||
w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
|
||||
|
||||
result_shape = input.shape[:-2] + (
|
||||
h_out,
|
||||
w_out,
|
||||
)
|
||||
return torch.empty(result_shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.MaxPool3d)
|
||||
def torch_nn_maxpool3d(self, input):
|
||||
num_dim = input.dim()
|
||||
|
@@ -4,9 +4,8 @@ tracer.py:
|
||||
Implemented a tracer which supports control flow and user-defined meta arguments.
|
||||
The implementation is partly inspired HuggingFace's fx tracer
|
||||
"""
|
||||
|
||||
import enum
|
||||
import inspect
|
||||
import math
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -22,6 +21,11 @@ from .meta_patch import meta_patched_function, meta_patched_module
|
||||
__all__ = ['ColoTracer']
|
||||
|
||||
|
||||
class TracerType(enum.Enum):
|
||||
DEFAULT = 1
|
||||
META = 2
|
||||
|
||||
|
||||
class ColoTracer(Tracer):
|
||||
"""
|
||||
ColoTracer is a symbolic tracer designed to support dynamic control flow by using meta tensors for the `colossalai.fx` module.
|
||||
@@ -48,6 +52,11 @@ class ColoTracer(Tracer):
|
||||
graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tracer_type = TracerType.META
|
||||
self.proxy_cls = ColoProxy
|
||||
|
||||
# Feature flag for proxying accesses to buffer values
|
||||
proxy_buffer_attributes: bool = True
|
||||
|
||||
@@ -58,6 +67,12 @@ class ColoTracer(Tracer):
|
||||
Create a proxy for different kinds of operations.
|
||||
"""
|
||||
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
|
||||
|
||||
if self.tracer_type == TracerType.DEFAULT:
|
||||
# since meta_args is not given
|
||||
# we just fall back to the original torch.fx.Tracer
|
||||
return proxy
|
||||
|
||||
proxy: ColoProxy
|
||||
|
||||
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
|
||||
@@ -168,11 +183,21 @@ class ColoTracer(Tracer):
|
||||
self.orig_forward = forward
|
||||
return super().call_module(m, forward, args, kwargs)
|
||||
|
||||
def proxy(self, node) -> ColoProxy:
|
||||
def proxy(self, node) -> Proxy:
|
||||
"""
|
||||
Returns a ColoProxy object.
|
||||
"""
|
||||
return ColoProxy(node, self)
|
||||
return self.proxy_cls(node, self)
|
||||
|
||||
def _configure_tracer_type(self, tracer_type: TracerType):
|
||||
if tracer_type == TracerType.DEFAULT:
|
||||
self.proxy_cls = Proxy
|
||||
self.tracer_type = TracerType.DEFAULT
|
||||
elif tracer_type == TracerType.META:
|
||||
self.proxy_cls = ColoProxy
|
||||
self.tracer_type = TracerType.META
|
||||
else:
|
||||
raise ValueError(f"Unrecognised tracer type {tracer_type}")
|
||||
|
||||
def trace(self,
|
||||
root: nn.Module,
|
||||
@@ -193,6 +218,11 @@ class ColoTracer(Tracer):
|
||||
if concrete_args is None:
|
||||
concrete_args = {}
|
||||
|
||||
if len(meta_args) == 0:
|
||||
self._configure_tracer_type(TracerType.DEFAULT)
|
||||
else:
|
||||
self._configure_tracer_type(TracerType.META)
|
||||
|
||||
# check concrete and meta args have valid names
|
||||
sig = inspect.signature(root.forward)
|
||||
sig_names = set(sig.parameters.keys())
|
||||
@@ -235,18 +265,21 @@ class ColoTracer(Tracer):
|
||||
self.concrete_args = concrete_args
|
||||
self.meta_args = meta_args
|
||||
|
||||
# wrap the torch tensor constructing methods so that they are captured in the graph
|
||||
self.patched_torch_tensor_methods = {
|
||||
target: wrap_tensor_constructor_method(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
|
||||
}
|
||||
self.patched_torch_tensor_methods = {}
|
||||
if self.tracer_type == TracerType.META:
|
||||
# wrap the torch tensor constructing methods so that they are captured in the graph
|
||||
self.patched_torch_tensor_methods = {
|
||||
target: wrap_tensor_constructor_method(getattr(torch, target))
|
||||
for target in self._TORCH_METHODS_TO_PATCH
|
||||
}
|
||||
|
||||
# patch these methods to replace their original use
|
||||
for name, (wrapper, orig) in self.patched_torch_tensor_methods.items():
|
||||
setattr(torch, name, wrapper)
|
||||
# patch these methods to replace their original use
|
||||
for name, (wrapper, orig) in self.patched_torch_tensor_methods.items():
|
||||
setattr(torch, name, wrapper)
|
||||
|
||||
# cache these methods so that we can detect whether a method call
|
||||
# should be patched during tracing
|
||||
self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
|
||||
# cache these methods so that we can detect whether a method call
|
||||
# should be patched during tracing
|
||||
self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
|
||||
|
||||
try:
|
||||
self.graph = super().trace(root, concrete_args=concrete_args)
|
||||
@@ -255,6 +288,9 @@ class ColoTracer(Tracer):
|
||||
for name, (_, orig) in self.patched_torch_tensor_methods.items():
|
||||
setattr(torch, name, orig)
|
||||
|
||||
if self.tracer_type == TracerType.DEFAULT:
|
||||
return self.graph
|
||||
|
||||
# This is necessary because concrete args are added as input to the traced module since
|
||||
# https://github.com/pytorch/pytorch/pull/55888.
|
||||
for node in self.graph.nodes:
|
||||
|
Reference in New Issue
Block a user