[fx] fix meta tensor registration (#3589)

* [meta] fix torch 1.13.1
* [meta] fix torch 2.0.0
* [meta] fix torch 1.13.0
* [meta] polish code

parent 36a519b49f
commit dac127d0ee
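The whole patch leans on two helpers defined earlier in the same file, register_meta and new_like, which the diff uses but does not show. A rough, hypothetical reconstruction of what such helpers typically look like (not the file's actual code):

import torch

# Hedged sketch only: the real definitions live above the diffed region.
_meta_lib = torch.library.Library('aten', 'IMPL', 'Meta')

def register_meta(ops):
    # Accept one overload or a list, and install the decorated function
    # as the Meta-device kernel for each of them.
    ops = ops if isinstance(ops, (list, tuple)) else [ops]

    def wrapper(fn):
        for op in ops:
            name = op._schema.name.split('::')[-1]    # e.g. aten::roll -> "roll"
            if op._schema.overload_name:
                name = f'{name}.{op._schema.overload_name}'    # e.g. "eye.m_out"
            _meta_lib.impl(name, fn)
        return fn

    return wrapper

def new_like(t: torch.Tensor, dtype=None):
    # Shape/dtype propagation only: nothing is allocated on the meta device.
    return torch.empty_like(t, dtype=dtype or t.dtype, device='meta')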
@@ -274,11 +274,15 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.prelu.default,
         aten.hardswish.default,
         aten.hardtanh.default,
-        aten.prelu_backward.default,
         aten.hardswish_backward.default,
         aten.hardtanh_backward.default,
     ]
 
+    if version.parse(torch.__version__) < version.parse('2.0.0'):
+        _unregistered_ewise += [
+            aten.prelu_backward.default,
+        ]
+
     @register_meta(_unregistered_ewise)
     def meta_unregistered_ewise(input: torch.Tensor, *args):
         return new_like(input)
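The new guard matters because torch 2.0 reworked prelu's backward dispatch, so aten.prelu_backward is not an overload this file should keep registering against there (and in some builds it is not exposed at all). An illustrative probe, not part of the patch:

import torch

# Illustration only: check whether this torch build still exposes the
# overload before trying to register a meta kernel for it.
try:
    op = torch.ops.aten.prelu_backward.default
    print('found', op)
except (AttributeError, RuntimeError):
    print('aten.prelu_backward is not exposed by this torch build')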
@@ -331,11 +335,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
         return new_like(input)
 
-    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
-    @register_meta(aten.eye.m_out)
-    def meta_eye(n: int, m: int, out: torch.Tensor):
-        return out
-
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
     @register_meta(aten.roll.default)
     def meta_roll(input: torch.Tensor, shifts, dims):
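The eye.m_out registration is not deleted for good: the last hunk below re-adds it behind a torch < 1.13 guard, consistent with torch 1.13 shipping its own meta kernel for that overload. A shape-only check, assuming a build where either kernel is in place:

import torch

# eye with an explicit out= on the meta device: the kernel only has to hand
# back `out`, so just the shape survives and nothing is materialized.
out = torch.empty(3, 4, device='meta')
torch.eye(3, 4, out=out)
print(out.shape)    # torch.Size([3, 4])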
@@ -352,97 +351,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         result_type = torch.result_type(self, other)
         return new_like(condition + self + other, dtype=result_type)
 
-    @register_meta(aten.index.Tensor)
-    def meta_index_Tensor(self, indices):
-        assert indices, "at least one index must be provided"
-        # aten::index is the internal advanced indexing implementation
-        # checkIndexTensorTypes and expandTensors
-        result: List[Optional[torch.Tensor]] = []
-        for i, index in enumerate(indices):
-            if index is not None:
-                assert index.dtype in [torch.long, torch.int8, torch.bool],\
-                    "tensors used as indices must be long, byte or bool tensors"
-                if index.dtype in [torch.int8, torch.bool]:
-                    nonzero = index.nonzero()
-                    k = len(result)
-                    assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
-                    for j in range(index.ndim):
-                        assert index.shape[j] == self.shape[
-                            k +
-                            j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
-                        result.append(nonzero.select(1, j))
-                else:
-                    result.append(index)
-            else:
-                result.append(index)
-        indices = result
-        assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
-        # expand_outplace
-        import torch._refs as refs
-
-        indices = list(refs._maybe_broadcast(*indices))
-        # add missing null tensors
-        while len(indices) < self.ndim:
-            indices.append(None)
-
-        # hasContiguousSubspace
-        # true if all non-null tensors are adjacent
-        # See:
-        # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
-        # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
-        state = 0
-        has_contiguous_subspace = False
-        for index in indices:
-            if state == 0:
-                if index is not None:
-                    state = 1
-            elif state == 1:
-                if index is None:
-                    state = 2
-            else:
-                if index is not None:
-                    break
-        else:
-            has_contiguous_subspace = True
-
-        # transposeToFront
-        # This is the logic that causes the newly inserted dimensions to show up
-        # at the beginning of the tensor, if they're not contiguous
-        if not has_contiguous_subspace:
-            dims = []
-            transposed_indices = []
-            for i, index in enumerate(indices):
-                if index is not None:
-                    dims.append(i)
-                    transposed_indices.append(index)
-            for i, index in enumerate(indices):
-                if index is None:
-                    dims.append(i)
-                    transposed_indices.append(index)
-            self = self.permute(dims)
-            indices = transposed_indices
-
-        # AdvancedIndex::AdvancedIndex
-        # Now we can assume the indices have contiguous subspace
-        # This is simplified from AdvancedIndex which goes to more effort
-        # to put the input and indices in a form so that TensorIterator can
-        # take them. If we write a ref for this, probably that logic should
-        # get implemented
-        before_shape: List[int] = []
-        after_shape: List[int] = []
-        replacement_shape: List[int] = []
-        for dim, index in enumerate(indices):
-            if index is None:
-                if replacement_shape:
-                    after_shape.append(self.shape[dim])
-                else:
-                    before_shape.append(self.shape[dim])
-            else:
-                replacement_shape = list(index.shape)
-        return self.new_empty(before_shape + replacement_shape + after_shape)
-
     # ============================== Embedding =========================================
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
 
     @register_meta(aten.embedding_dense_backward.default)
     def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                       scale_grad_by_freq):
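meta_index_Tensor, removed here and re-added under a version guard in the next hunk, reproduces aten::index's advanced-indexing shape rule: the broadcast index shape replaces the indexed dims in place when the advanced indices are adjacent, and moves to the front when they are not. The rule is easy to see on ordinary tensors:

import torch

x = torch.empty(5, 6, 7)
i = torch.zeros(3, dtype=torch.long)

print(x[i].shape)          # torch.Size([3, 6, 7]): index shape replaces dim 0
print(x[:, i, i].shape)    # torch.Size([5, 3]): adjacent advanced indices stay in place
print(x[i, :, i].shape)    # torch.Size([3, 6]): non-adjacent, result dims move to front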
@@ -459,3 +370,99 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     @register_meta(aten.native_dropout_backward.default)
     def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
         return new_like(grad)    # (grad_in)
+
+    if version.parse(torch.__version__) < version.parse('1.13.0'):
+        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+        @register_meta(aten.eye.m_out)
+        def meta_eye(n: int, m: int, out: torch.Tensor):
+            return out
+
+        @register_meta(aten.index.Tensor)
+        def meta_index_Tensor(self, indices):
+            assert indices, "at least one index must be provided"
+            # aten::index is the internal advanced indexing implementation
+            # checkIndexTensorTypes and expandTensors
+            result: List[Optional[torch.Tensor]] = []
+            for i, index in enumerate(indices):
+                if index is not None:
+                    assert index.dtype in [torch.long, torch.int8, torch.bool],\
+                        "tensors used as indices must be long, byte or bool tensors"
+                    if index.dtype in [torch.int8, torch.bool]:
+                        nonzero = index.nonzero()
+                        k = len(result)
+                        assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
+                        for j in range(index.ndim):
+                            assert index.shape[j] == self.shape[
+                                k +
+                                j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                            result.append(nonzero.select(1, j))
+                    else:
+                        result.append(index)
+                else:
+                    result.append(index)
+            indices = result
+            assert len(
+                indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+            # expand_outplace
+            import torch._refs as refs
+
+            indices = list(refs._maybe_broadcast(*indices))
+            # add missing null tensors
+            while len(indices) < self.ndim:
+                indices.append(None)
+
+            # hasContiguousSubspace
+            # true if all non-null tensors are adjacent
+            # See:
+            # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
+            # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
+            state = 0
+            has_contiguous_subspace = False
+            for index in indices:
+                if state == 0:
+                    if index is not None:
+                        state = 1
+                elif state == 1:
+                    if index is None:
+                        state = 2
+                else:
+                    if index is not None:
+                        break
+            else:
+                has_contiguous_subspace = True
+
+            # transposeToFront
+            # This is the logic that causes the newly inserted dimensions to show up
+            # at the beginning of the tensor, if they're not contiguous
+            if not has_contiguous_subspace:
+                dims = []
+                transposed_indices = []
+                for i, index in enumerate(indices):
+                    if index is not None:
+                        dims.append(i)
+                        transposed_indices.append(index)
+                for i, index in enumerate(indices):
+                    if index is None:
+                        dims.append(i)
+                        transposed_indices.append(index)
+                self = self.permute(dims)
+                indices = transposed_indices
+
+            # AdvancedIndex::AdvancedIndex
+            # Now we can assume the indices have contiguous subspace
+            # This is simplified from AdvancedIndex which goes to more effort
+            # to put the input and indices in a form so that TensorIterator can
+            # take them. If we write a ref for this, probably that logic should
+            # get implemented
+            before_shape: List[int] = []
+            after_shape: List[int] = []
+            replacement_shape: List[int] = []
+            for dim, index in enumerate(indices):
+                if index is None:
+                    if replacement_shape:
+                        after_shape.append(self.shape[dim])
+                    else:
+                        before_shape.append(self.shape[dim])
+                else:
+                    replacement_shape = list(index.shape)
+            return self.new_empty(before_shape + replacement_shape + after_shape)
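A minimal smoke test for the registrations touched here, assuming the patched module has been imported (on torch < 1.13) or a torch build that ships these meta kernels natively:

import torch

# Shapes propagate through the meta device without allocating real memory.
grad = torch.empty(8, 16, device='meta')
idx = torch.zeros(8, dtype=torch.long, device='meta')
gw = torch.ops.aten.embedding_dense_backward(grad, idx, 100, -1, False)
print(gw.shape)    # torch.Size([100, 16]): (num_weights, embedding_dim)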