[fx] refactored the file structure of patched function and module (#1238)
* [fx] refactored the file structure of patched function and module
* polish code
@@ -0,0 +1,63 @@
import torch

from ..registry import meta_patched_function

@meta_patched_function.register(torch.matmul)
def torch_matmul(input, other, *, out=None):
    # copied from huggingface.utils.fx
    d1 = input.dim()
    d2 = other.dim()
    shape = None
    if d1 == 1 and d2 == 1:
        # vector dot product yields a scalar
        shape = None
    elif d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d2 == 1:
        # matrix-vector product
        shape = (input.size(0),)
    else:
        # batched case: broadcast the leading dimensions
        max_length = max(input.dim(), other.dim())
        shape1 = list(input.shape)
        shape2 = list(other.shape)
        if d1 == 1:
            shape1 = [1] + shape1
        if d2 == 1:
            shape2.append(1)
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = []
        for i in range(max_length):
            shape.append(max(shape1[i], shape2[i]))
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    if shape is None:
        return torch.tensor(0.0, device="meta")
    return torch.empty(*shape, device="meta")
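# Illustration (an assumption, not part of the original file): with a rank-3
# and a rank-2 meta input, the broadcasting branch above is taken, so
#   torch_matmul(torch.empty(2, 3, 4, device="meta"),
#                torch.empty(4, 5, device="meta"))
# returns a meta tensor of shape (2, 3, 5).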
@meta_patched_function.register(torch.abs)
def torch_abs(input, *, out=None):
    assert out is None, 'out is not supported yet'
    return torch.empty(input.shape, device='meta')
@meta_patched_function.register(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch_size, n, m = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")
@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
    assert out is None, 'saving to out is not supported yet'
    # both outputs are modeled as scalar meta tensors
    var = torch.empty(1).squeeze(0).to('meta')
    mean = torch.empty(1).squeeze(0).to('meta')
    return var, mean
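For context, a minimal usage sketch, assuming the three functions above are in scope; in ColossalAI they would normally be dispatched through the meta_patched_function registry during tracing rather than called directly. With every argument on the meta device, the patched functions infer output shapes without allocating real memory:

import torch

a = torch.empty(4, 5, device="meta")
b = torch.empty(5, 3, device="meta")
print(torch_matmul(a, b).shape)    # torch.Size([4, 3])

x = torch.empty(8, 4, 5, device="meta")
y = torch.empty(8, 5, 3, device="meta")
print(torch_bmm(x, y).shape)       # torch.Size([8, 4, 3])

var, mean = torch_var_mean(torch.empty(10, device="meta"), dim=0)
print(var.shape, mean.shape)       # torch.Size([]) torch.Size([])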