mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[fx] added timm model tracing testing (#1221)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from curses import meta
|
||||
import operator
|
||||
import torch
|
||||
from .registry import meta_patched_function
|
||||
@@ -99,7 +100,6 @@ def torch_abs(input, *, out=None):
|
||||
|
||||
@meta_patched_function.register(torch.nn.functional.relu)
|
||||
def torch_nn_func_relu(input, inplace=False):
|
||||
assert not inplace, 'inplace is not supported yet'
|
||||
return torch.empty(input.shape, device='meta')
|
||||
|
||||
|
||||
@@ -178,3 +178,43 @@ def torch_unsqueeze(input, dim):
|
||||
@meta_patched_function.register(torch.Tensor.unsqueeze)
|
||||
def torch_tensor_unsqueeze(self, dim):
|
||||
return torch_unsqueeze(self, dim)
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.nn.functional.layer_norm)
|
||||
def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
return torch.empty(input.shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.nn.functional.batch_norm)
|
||||
def torch_nn_func_batchnorm(input,
|
||||
running_mean,
|
||||
running_var,
|
||||
weight=None,
|
||||
bias=None,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=1e-05):
|
||||
return torch.empty(input.shape, device='meta')
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.var_mean)
|
||||
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
|
||||
assert out is None, 'saving to out is not supported yet'
|
||||
var = torch.empty(1).squeeze(0).to('meta')
|
||||
mean = torch.empty(1).squeeze(0).to('meta')
|
||||
return var, mean
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.cat)
|
||||
def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
||||
if dim is None and axis is None:
|
||||
dim = 0
|
||||
if dim is None and axis is not None:
|
||||
dim = axis
|
||||
if dim < 0:
|
||||
dim = tensors[0].dim() + dim
|
||||
shapes = [t.shape for t in tensors]
|
||||
shape = list(shapes[0])
|
||||
concatenated_dim = sum(shape[dim] for shape in shapes)
|
||||
final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
|
||||
return torch.empty(final_shape, device="meta")
|
||||
|
@@ -250,6 +250,6 @@ def torch_nn_maxpool3d(self, input):
|
||||
|
||||
|
||||
@meta_patched_module.register(torch.nn.ReLU)
|
||||
@meta_patched_module.register(torch.nn.ReLU6)
|
||||
def torch_nn_func_relu(self, input):
|
||||
assert not self.inplace, 'inplace is not supported yet'
|
||||
return input.clone()
|
||||
return torch.empty(input.shape, device='meta')
|
||||
|
Reference in New Issue
Block a user