diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py
index 4b1fd28e9..4049be79c 100644
--- a/colossalai/_analyzer/_subclasses/_meta_registration.py
+++ b/colossalai/_analyzer/_subclasses/_meta_registration.py
@@ -274,11 +274,15 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.prelu.default,
         aten.hardswish.default,
         aten.hardtanh.default,
-        aten.prelu_backward.default,
         aten.hardswish_backward.default,
         aten.hardtanh_backward.default,
     ]
 
+    if version.parse(torch.__version__) < version.parse('2.0.0'):
+        _unregistered_ewise += [
+            aten.prelu_backward.default,
+        ]
+
     @register_meta(_unregistered_ewise)
     def meta_unregistered_ewise(input: torch.Tensor, *args):
         return new_like(input)
@@ -331,11 +335,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
         return new_like(input)
 
-    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
-    @register_meta(aten.eye.m_out)
-    def meta_eye(n: int, m: int, out: torch.Tensor):
-        return out
-
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
     @register_meta(aten.roll.default)
     def meta_roll(input: torch.Tensor, shifts, dims):
@@ -352,97 +351,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         result_type = torch.result_type(self, other)
         return new_like(condition + self + other, dtype=result_type)
 
-    @register_meta(aten.index.Tensor)
-    def meta_index_Tensor(self, indices):
-        assert indices, "at least one index must be provided"
-        # aten::index is the internal advanced indexing implementation
-        # checkIndexTensorTypes and expandTensors
-        result: List[Optional[torch.Tensor]] = []
-        for i, index in enumerate(indices):
-            if index is not None:
-                assert index.dtype in [torch.long, torch.int8, torch.bool],\
-                    "tensors used as indices must be long, byte or bool tensors"
-                if index.dtype in [torch.int8, torch.bool]:
-                    nonzero = index.nonzero()
-                    k = len(result)
-                    assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
-                    for j in range(index.ndim):
-                        assert index.shape[j] == self.shape[
-                            k +
-                            j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
-                        result.append(nonzero.select(1, j))
-                else:
-                    result.append(index)
-            else:
-                result.append(index)
-        indices = result
-        assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
-        # expand_outplace
-        import torch._refs as refs
-
-        indices = list(refs._maybe_broadcast(*indices))
-        # add missing null tensors
-        while len(indices) < self.ndim:
-            indices.append(None)
-
-        # hasContiguousSubspace
-        # true if all non-null tensors are adjacent
-        # See:
-        # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
-        # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
-        state = 0
-        has_contiguous_subspace = False
-        for index in indices:
-            if state == 0:
-                if index is not None:
-                    state = 1
-            elif state == 1:
-                if index is None:
-                    state = 2
-            else:
-                if index is not None:
-                    break
-        else:
-            has_contiguous_subspace = True
-
-        # transposeToFront
-        # This is the logic that causes the newly inserted dimensions to show up
-        # at the beginning of the tensor, if they're not contiguous
-        if not has_contiguous_subspace:
-            dims = []
-            transposed_indices = []
-            for i, index in enumerate(indices):
-                if index is not None:
-                    dims.append(i)
-                    transposed_indices.append(index)
-            for i, index in enumerate(indices):
-                if index is None:
-                    dims.append(i)
-                    transposed_indices.append(index)
-            self = self.permute(dims)
-            indices = transposed_indices
-
-        # AdvancedIndex::AdvancedIndex
-        # Now we can assume the indices have contiguous subspace
-        # This is simplified from AdvancedIndex which goes to more effort
-        # to put the input and indices in a form so that TensorIterator can
-        # take them. If we write a ref for this, probably that logic should
-        # get implemented
-        before_shape: List[int] = []
-        after_shape: List[int] = []
-        replacement_shape: List[int] = []
-        for dim, index in enumerate(indices):
-            if index is None:
-                if replacement_shape:
-                    after_shape.append(self.shape[dim])
-                else:
-                    before_shape.append(self.shape[dim])
-            else:
-                replacement_shape = list(index.shape)
-        return self.new_empty(before_shape + replacement_shape + after_shape)
-
     # ============================== Embedding =========================================
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
+
     @register_meta(aten.embedding_dense_backward.default)
     def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                       scale_grad_by_freq):
@@ -459,3 +370,99 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     @register_meta(aten.native_dropout_backward.default)
     def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
         return new_like(grad)    # (grad_in)
+
+    if version.parse(torch.__version__) < version.parse('1.13.0'):
+        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+        @register_meta(aten.eye.m_out)
+        def meta_eye(n: int, m: int, out: torch.Tensor):
+            return out
+
+        @register_meta(aten.index.Tensor)
+        def meta_index_Tensor(self, indices):
+            assert indices, "at least one index must be provided"
+            # aten::index is the internal advanced indexing implementation
+            # checkIndexTensorTypes and expandTensors
+            result: List[Optional[torch.Tensor]] = []
+            for i, index in enumerate(indices):
+                if index is not None:
+                    assert index.dtype in [torch.long, torch.int8, torch.bool],\
+                        "tensors used as indices must be long, byte or bool tensors"
+                    if index.dtype in [torch.int8, torch.bool]:
+                        nonzero = index.nonzero()
+                        k = len(result)
+                        assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
+                        for j in range(index.ndim):
+                            assert index.shape[j] == self.shape[
+                                k +
+                                j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                            result.append(nonzero.select(1, j))
+                    else:
+                        result.append(index)
+                else:
+                    result.append(index)
+            indices = result
+            assert len(
+                indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+            # expand_outplace
+            import torch._refs as refs
+
+            indices = list(refs._maybe_broadcast(*indices))
+            # add missing null tensors
+            while len(indices) < self.ndim:
+                indices.append(None)
+
+            # hasContiguousSubspace
+            # true if all non-null tensors are adjacent
+            # See:
+            # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
+            # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
+            state = 0
+            has_contiguous_subspace = False
+            for index in indices:
+                if state == 0:
+                    if index is not None:
+                        state = 1
+                elif state == 1:
+                    if index is None:
+                        state = 2
+                else:
+                    if index is not None:
+                        break
+            else:
+                has_contiguous_subspace = True
+
+            # transposeToFront
+            # This is the logic that causes the newly inserted dimensions to show up
+            # at the beginning of the tensor, if they're not contiguous
+            if not has_contiguous_subspace:
+                dims = []
+                transposed_indices = []
+                for i, index in enumerate(indices):
+                    if index is not None:
+                        dims.append(i)
+                        transposed_indices.append(index)
+                for i, index in enumerate(indices):
+                    if index is None:
+                        dims.append(i)
+                        transposed_indices.append(index)
+                self = self.permute(dims)
+                indices = transposed_indices
+
+            # AdvancedIndex::AdvancedIndex
+            # Now we can assume the indices have contiguous subspace
+            # This is simplified from AdvancedIndex which goes to more effort
+            # to put the input and indices in a form so that TensorIterator can
+            # take them. If we write a ref for this, probably that logic should
+            # get implemented
+            before_shape: List[int] = []
+            after_shape: List[int] = []
+            replacement_shape: List[int] = []
+            for dim, index in enumerate(indices):
+                if index is None:
+                    if replacement_shape:
+                        after_shape.append(self.shape[dim])
+                    else:
+                        before_shape.append(self.shape[dim])
+                else:
+                    replacement_shape = list(index.shape)
+            return self.new_empty(before_shape + replacement_shape + after_shape)
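
Illustrative note (not part of the patch): the new version guards keep these custom meta kernels only on the torch releases that still need them; newer PyTorch versions ship their own meta implementations for these ops (and torch 2.0 changes how prelu_backward is exposed), so registering them again would be redundant or conflicting. The snippet below is a minimal sketch of the shape-only inference such meta kernels provide; it runs against stock PyTorch >= 1.13 rather than the register_meta table in this file.

    # Shape/dtype inference on the 'meta' device: no real storage is allocated.
    import torch

    x = torch.empty(4, 8, device='meta')
    idx = torch.zeros(3, dtype=torch.long, device='meta')

    # Advanced indexing dispatches to aten.index.Tensor; on torch >= 1.13 the
    # built-in meta kernel computes the result shape, while on older releases
    # the meta_index_Tensor registered above fills the same role.
    out = x[idx]
    assert out.shape == (3, 8) and out.device.type == 'meta'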