diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py
index 20ab46054..2af7e0539 100644
--- a/colossalai/_analyzer/_subclasses/_meta_registration.py
+++ b/colossalai/_analyzer/_subclasses/_meta_registration.py
@@ -6,11 +6,15 @@ from typing import Callable, List, Optional, Tuple, Union
 
 import torch
+from packaging import version
 from torch.utils._pytree import tree_map
 
 aten = torch.ops.aten
 
-meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+try:
+    meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+except AttributeError:
+    meta_lib = None
 
 meta_table = {}
 
@@ -50,432 +54,411 @@ def register_meta(op, register_dispatcher=True):
     return wrapper
 
 
-# ============================== Convolutions ======================================
-# https://github.com/pytorch/pytorch/pull/79834
-@register_meta(aten.convolution.default)
-def meta_conv(
-    input_tensor: torch.Tensor,
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    stride: List[int],
-    padding: List[int],
-    dilation: List[int],
-    is_transposed: bool,
-    output_padding: List[int],
-    groups: int,
-):
-
-    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
-        """
-        Formula to apply to calculate the length of some dimension of the output
-        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
-        Args:
-            ln: length of the dimension
-            p: padding in that dim
-            d: dilation in that dim
-            k: kernel size in that dim
-            s: stride in that dim
-        Returns:
-            The output length
-        """
-        return (ln + 2 * p - d * (k - 1) - 1) // s + 1
-
-    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
-        """
-        Formula to apply to calculate the length of some dimension of the output
-        if transposed convolution is used.
-        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
-        Args:
-            ln: length of the dimension
-            p: padding in that dim
-            d: dilation in that dim
-            k: kernel size in that dim
-            s: stride in that dim
-            op: output padding in that dim
-        Returns:
-            The output length
-        """
-        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
-
-    def calc_conv_nd_return_shape(
-        dims: torch.Size,
-        kernel_size: torch.Size,
-        stride: Union[List[int], int],
-        padding: Union[List[int], int],
-        dilation: Union[List[int], int],
-        output_padding: Optional[Union[List[int], int]] = None,
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    # ============================== Convolutions ======================================
+    # https://github.com/pytorch/pytorch/pull/79834
+    @register_meta(aten.convolution.default)
+    def meta_conv(
+        input_tensor: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        stride: List[int],
+        padding: List[int],
+        dilation: List[int],
+        is_transposed: bool,
+        output_padding: List[int],
+        groups: int,
     ):
-        ret_shape = []
-        if isinstance(stride, int):
-            stride = [stride] * len(dims)
-        elif len(stride) == 1:
-            stride = [stride[0]] * len(dims)
 
-        if isinstance(padding, int):
-            padding = [padding] * len(dims)
-        elif len(padding) == 1:
-            padding = [padding[0]] * len(dims)
+        def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
+            """
+            Formula to apply to calculate the length of some dimension of the output
+            See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+            Args:
+                ln: length of the dimension
+                p: padding in that dim
+                d: dilation in that dim
+                k: kernel size in that dim
+                s: stride in that dim
+            Returns:
+                The output length
+            """
+            return (ln + 2 * p - d * (k - 1) - 1) // s + 1
 
-        if isinstance(dilation, int):
-            dilation = [dilation] * len(dims)
-        elif len(dilation) == 1:
-            dilation = [dilation[0]] * len(dims)
+        def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
+            """
+            Formula to apply to calculate the length of some dimension of the output
+            if transposed convolution is used.
+            See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
+            Args:
+                ln: length of the dimension
+                p: padding in that dim
+                d: dilation in that dim
+                k: kernel size in that dim
+                s: stride in that dim
+                op: output padding in that dim
+            Returns:
+                The output length
+            """
+            return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
 
-    output_padding_list: Optional[List[int]] = None
-    if output_padding:
-        if isinstance(output_padding, int):
-            output_padding_list = [output_padding] * len(dims)
-        elif len(output_padding) == 1:
-            output_padding_list = [output_padding[0]] * len(dims)
-        else:
-            output_padding_list = output_padding
+        def calc_conv_nd_return_shape(
+            dims: torch.Size,
+            kernel_size: torch.Size,
+            stride: Union[List[int], int],
+            padding: Union[List[int], int],
+            dilation: Union[List[int], int],
+            output_padding: Optional[Union[List[int], int]] = None,
+        ):
+            ret_shape = []
+            if isinstance(stride, int):
+                stride = [stride] * len(dims)
+            elif len(stride) == 1:
+                stride = [stride[0]] * len(dims)
 
-    for i in range(len(dims)):
-        # If output_padding is present, we are dealing with a transposed convolution
-        if output_padding_list:
-            ret_shape.append(
-                _formula_transposed(
-                    dims[i],
-                    padding[i],
-                    dilation[i],
-                    kernel_size[i],
-                    stride[i],
-                    output_padding_list[i],
-                ))
-        else:
-            ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
-    return ret_shape
+            if isinstance(padding, int):
+                padding = [padding] * len(dims)
+            elif len(padding) == 1:
+                padding = [padding[0]] * len(dims)
 
-    def pick_memory_format():
-        if input_tensor.is_contiguous(memory_format=torch.channels_last):
-            return torch.channels_last
-        elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
-            return torch.contiguous_format
-        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
-            return torch.preserve_format
+            if isinstance(dilation, int):
+                dilation = [dilation] * len(dims)
+            elif len(dilation) == 1:
+                dilation = [dilation[0]] * len(dims)
 
-    kernel_size = weight.shape[2:]
-    dims = input_tensor.shape[2:]
-    if is_transposed:
-        out_channels = groups * weight.shape[1]
+            output_padding_list: Optional[List[int]] = None
+            if output_padding:
+                if isinstance(output_padding, int):
+                    output_padding_list = [output_padding] * len(dims)
+                elif len(output_padding) == 1:
+                    output_padding_list = [output_padding[0]] * len(dims)
+                else:
+                    output_padding_list = output_padding
 
-        shape_out = calc_conv_nd_return_shape(
-            dims,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-            output_padding,
-        )
+            for i in range(len(dims)):
+                # If output_padding is present, we are dealing with a transposed convolution
+                if output_padding_list:
+                    ret_shape.append(
+                        _formula_transposed(
+                            dims[i],
+                            padding[i],
+                            dilation[i],
+                            kernel_size[i],
+                            stride[i],
+                            output_padding_list[i],
+                        ))
+                else:
+                    ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
+            return ret_shape
 
-    else:
-        out_channels = weight.shape[0]
-        if weight.shape[1] != input_tensor.shape[1] / groups:
-            raise RuntimeError("Invalid channel dimensions")
-        shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
-    out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
-    mem_fmt = pick_memory_format()
-    out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
-    return out
+        def pick_memory_format():
+            if input_tensor.is_contiguous(memory_format=torch.channels_last):
+                return torch.channels_last
+            elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+                return torch.contiguous_format
+            elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
+                return torch.preserve_format
 
+        kernel_size = weight.shape[2:]
+        dims = input_tensor.shape[2:]
+        if is_transposed:
+            out_channels = groups * weight.shape[1]
 
-@register_meta(aten._convolution.default)
-def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
-               padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
-               *extra_args):
-    out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
-    return out
+            shape_out = calc_conv_nd_return_shape(
+                dims,
+                kernel_size,
+                stride,
+                padding,
+                dilation,
+                output_padding,
+            )
+        else:
+            out_channels = weight.shape[0]
+            if weight.shape[1] != input_tensor.shape[1] / groups:
+                raise RuntimeError("Invalid channel dimensions")
+            shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
+        out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
+        mem_fmt = pick_memory_format()
+        out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
+        return out
 
-@register_meta(aten.convolution_backward.default)
-def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
-                       padding, dilation, transposed, output_padding, groups, output_mask):
-    return new_like(input), new_like(weight), new((bias_sizes))
+    @register_meta(aten._convolution.default)
+    def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
+                   padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int],
+                   groups: int, *extra_args):
+        out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
+        return out
 
+    @register_meta(aten.convolution_backward.default)
+    def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
+                           padding, dilation, transposed, output_padding, groups, output_mask):
+        return new_like(input), new_like(weight), new((bias_sizes))
 
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
-@register_meta(aten._adaptive_avg_pool2d_backward.default)
-def meta_adaptive_avg_pool2d_backward(
-    grad_output: torch.Tensor,
-    input: torch.Tensor,
-):
-    return new_like(input)
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
+    @register_meta(aten._adaptive_avg_pool2d_backward.default)
+    def meta_adaptive_avg_pool2d_backward(
+        grad_output: torch.Tensor,
+        input: torch.Tensor,
+    ):
+        return new_like(input)
 
+    # ================================ RNN =============================================
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
+    @register_meta(aten._cudnn_rnn.default)
+    def meta_cuda_rnn(
+        input,
+        weight,
+        weight_stride0,
+        weight_buf,
+        hx,
+        cx,
+        mode,
+        hidden_size,
+        proj_size,
+        num_layers,
+        batch_first,
+        dropout,
+        train,
+        bidirectional,
+        batch_sizes,
+        dropout_state,
+    ):
 
-# ================================ RNN =============================================
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
-@register_meta(aten._cudnn_rnn.default)
-def meta_cuda_rnn(
-    input,
-    weight,
-    weight_stride0,
-    weight_buf,
-    hx,
-    cx,
-    mode,
-    hidden_size,
-    proj_size,
-    num_layers,
-    batch_first,
-    dropout,
-    train,
-    bidirectional,
-    batch_sizes,
-    dropout_state,
-):
+        is_input_packed = len(batch_sizes) != 0
+        if is_input_packed:
+            seq_length = len(batch_sizes)
+            mini_batch = batch_sizes[0]
+            batch_sizes_sum = input.shape[0]
+        else:
+            seq_length = input.shape[1] if batch_first else input.shape[0]
+            mini_batch = input.shape[0] if batch_first else input.shape[1]
+            batch_sizes_sum = -1
 
-    is_input_packed = len(batch_sizes) != 0
-    if is_input_packed:
-        seq_length = len(batch_sizes)
-        mini_batch = batch_sizes[0]
-        batch_sizes_sum = input.shape[0]
-    else:
-        seq_length = input.shape[1] if batch_first else input.shape[0]
-        mini_batch = input.shape[0] if batch_first else input.shape[1]
-        batch_sizes_sum = -1
+        num_directions = 2 if bidirectional else 1
+        out_size = proj_size if proj_size != 0 else hidden_size
+        if is_input_packed:
+            out_shape = [batch_sizes_sum, out_size * num_directions]
+        else:
+            out_shape = ([mini_batch, seq_length, out_size *
+                          num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+        output = input.new_empty(out_shape)
 
-    num_directions = 2 if bidirectional else 1
-    out_size = proj_size if proj_size != 0 else hidden_size
-    if is_input_packed:
-        out_shape = [batch_sizes_sum, out_size * num_directions]
-    else:
-        out_shape = ([mini_batch, seq_length, out_size *
-                      num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
-    output = input.new_empty(out_shape)
+        cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
+        cy = new(0) if cx is None else cx.new_empty(cell_shape)
 
-    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
-    cy = new(0) if cx is None else cx.new_empty(cell_shape)
+        hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
 
-    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
+        # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
+        reserve_shape = 0 if train else 0
+        reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
 
-    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
-    reserve_shape = 0 if train else 0
-    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
+        return output, hy, cy, reserve, weight_buf
 
-    return output, hy, cy, reserve, weight_buf
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
+    @register_meta(aten._cudnn_rnn_backward.default)
+    def meta_cudnn_rnn_backward(input: torch.Tensor,
+                                weight: torch.Tensor,
+                                weight_stride0: int,
+                                hx: torch.Tensor,
+                                cx: Optional[torch.Tensor] = None,
+                                *args,
+                                **kwargs):
+        return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
+            ())    # (grad_input, grad_weight, grad_hx, grad_cx)
 
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
+    # ============================== Activations =======================================
+    _unregistered_ewise = [
+        aten.relu.default,
+        aten.prelu.default,
+        aten.hardswish.default,
+        aten.hardtanh.default,
+        aten.prelu_backward.default,
+        aten.hardswish_backward.default,
+        aten.hardtanh_backward.default,
+    ]
 
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
-@register_meta(aten._cudnn_rnn_backward.default)
-def meta_cudnn_rnn_backward(input: torch.Tensor,
-                            weight: torch.Tensor,
-                            weight_stride0: int,
-                            hx: torch.Tensor,
-                            cx: Optional[torch.Tensor] = None,
-                            *args,
-                            **kwargs):
-    return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
-        ())    # (grad_input, grad_weight, grad_hx, grad_cx)
+    @register_meta(_unregistered_ewise)
+    def meta_unregistered_ewise(input: torch.Tensor, *args):
+        return new_like(input)
 
+    # ============================== Normalization =====================================
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+    @register_meta(aten.native_batch_norm.default)
+    def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
+        n_input = input.size(1)
+        return new_like(input), new((n_input)), new((n_input))
 
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
-# ============================== Activations =======================================
-_unregistered_ewise = [
-    aten.relu.default,
-    aten.prelu.default,
-    aten.hardswish.default,
-    aten.hardtanh.default,
-    aten.prelu_backward.default,
-    aten.hardswish_backward.default,
-    aten.hardtanh_backward.default,
-]
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+    @register_meta(aten.native_batch_norm_backward.default)
+    def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
+                         save_mean, save_invstd, train, eps, output_mask):
+        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)
 
-@register_meta(_unregistered_ewise)
-def meta_unregistered_ewise(input: torch.Tensor, *args):
-    return new_like(input)
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+    @register_meta(aten.cudnn_batch_norm.default)
+    def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
+        n_input = input.size(1)
+        return new_like(input), new((n_input)), new((n_input)), new(
+            (0), dtype=torch.uint8)    # (output, running_mean, running_var, reserve)
 
-# ============================== Normalization =====================================
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
-@register_meta(aten.native_batch_norm.default)
-def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
-    n_input = input.size(1)
-    return new_like(input), new((n_input)), new((n_input))
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+    # NB: CuDNN only implements the backward algorithm for batchnorm
+    # in training mode (evaluation mode batchnorm has a different algorithm),
+    # which is why this doesn't accept a 'training' parameter.
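+    # Shape-wise this mirrors native_batch_norm_backward above: the meta kernel
+    # returns (dX, dgamma, dbeta) shaped like (input, weight, weight); the cuDNN
+    # reserve buffer is consumed but does not influence the gradient shapes.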
+    @register_meta(aten.cudnn_batch_norm_backward.default)
+    def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
+                               save_mean, save_invstd, eps, reserve):
+        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)

+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
+    @register_meta(aten.native_layer_norm.default)
+    def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
+        bs, n_input = input.size(0), input.size(1)
+        return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))    # (output, mean, rstd)

+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
+    @register_meta(aten.native_layer_norm_backward.default)
+    def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
+                         grad_input_mask):
+        return new_like(input), new_like(weight), new_like(bias)    # (dX, dgamma, dbeta)

+    # ================================== Misc ==========================================
+    # Maybe incorrect
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp
+    @register_meta(aten.im2col.default)
+    def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
+        return new_like(input)

+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+    @register_meta(aten.eye.m_out)
+    def meta_eye(n: int, m: int, out: torch.Tensor):
+        return out

+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+    @register_meta(aten.roll.default)
+    def meta_roll(input: torch.Tensor, shifts, dims):
+        return input

+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
+    @register_meta(aten._local_scalar_dense.default)
+    def meta_local_scalar_dense(self: torch.Tensor):
+        return 0

+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
+    @register_meta(aten.where.self)
+    def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
+        result_type = torch.result_type(self, other)
+        return new_like(condition + self + other, dtype=result_type)

-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
-# NB: CuDNN only implements the backward algorithm for batchnorm
-# in training mode (evaluation mode batchnorm has a different algorithm),
-# which is why this doesn't accept a 'training' parameter.
-@register_meta(aten.cudnn_batch_norm_backward.default)
-def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
-                           save_mean, save_invstd, eps, reserve):
-    return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)
-
-
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
-@register_meta(aten.native_layer_norm.default)
-def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
-    bs, n_input = input.size(0), input.size(1)
-    return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))    # (output, running_mean, running_var)
-
-
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
-@register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
-                     grad_input_mask):
-    return new_like(input), new_like(weight), new_like(bias)    # (dX, dgamma, dbeta)
-
-
-# ================================== Misc ==========================================
-# Maybe incorrect
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp
-@register_meta(aten.im2col.default)
-def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
-    return new_like(input)
-
-
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
-@register_meta(aten.eye.m_out)
-def meta_eye(n: int, m: int, out: torch.Tensor):
-    return out
-
-
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
-@register_meta(aten.roll.default)
-def meta_roll(input: torch.Tensor, shifts, dims):
-    return input
-
-
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
-@register_meta(aten._local_scalar_dense.default)
-def meta_local_scalar_dense(self: torch.Tensor):
-    return 0
-
-
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
-@register_meta(aten.where.self)
-def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
-    result_type = torch.result_type(self, other)
-    return new_like(condition + self + other, dtype=result_type)
-
-
-@register_meta(aten.index.Tensor)
-def meta_index_Tensor(self, indices):
-    assert indices, "at least one index must be provided"
-    # aten::index is the internal advanced indexing implementation
-    # checkIndexTensorTypes and expandTensors
-    result: List[Optional[torch.Tensor]] = []
-    for i, index in enumerate(indices):
-        if index is not None:
-            assert index.dtype in [torch.long, torch.int8, torch.bool],\
-                "tensors used as indices must be long, byte or bool tensors"
-            if index.dtype in [torch.int8, torch.bool]:
-                nonzero = index.nonzero()
-                k = len(result)
-                assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
-                for j in range(index.ndim):
-                    assert index.shape[j] == self.shape[
-                        k +
-                        j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
-                    result.append(nonzero.select(1, j))
+    @register_meta(aten.index.Tensor)
+    def meta_index_Tensor(self, indices):
+        assert indices, "at least one index must be provided"
+        # aten::index is the internal advanced indexing implementation
+        # checkIndexTensorTypes and expandTensors
+        result: List[Optional[torch.Tensor]] = []
+        for i, index in enumerate(indices):
+            if index is not None:
+                assert index.dtype in [torch.long, torch.int8, torch.bool],\
+                    "tensors used as indices must be long, byte or bool tensors"
+                if index.dtype in [torch.int8, torch.bool]:
+                    nonzero = index.nonzero()
+                    k = len(result)
+                    assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
+                    for j in range(index.ndim):
+                        assert index.shape[j] == self.shape[
+                            k +
+                            j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                        result.append(nonzero.select(1, j))
+                else:
+                    result.append(index)
             else:
                 result.append(index)
-        else:
-            result.append(index)
-    indices = result
-    assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
-    # expand_outplace
-    import torch._refs as refs
+        indices = result
+        assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+        # expand_outplace
+        import torch._refs as refs
 
-    indices = list(refs._maybe_broadcast(*indices))
-    # add missing null tensors
-    while len(indices) < self.ndim:
-        indices.append(None)
+        indices = list(refs._maybe_broadcast(*indices))
+        # add missing null tensors
+        while len(indices) < self.ndim:
+            indices.append(None)
 
-    # hasContiguousSubspace
-    # true if all non-null tensors are adjacent
-    # See:
-    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
-    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
-    state = 0
-    has_contiguous_subspace = False
-    for index in indices:
-        if state == 0:
-            if index is not None:
-                state = 1
-        elif state == 1:
-            if index is None:
-                state = 2
-        else:
-            if index is not None:
-                break
-    else:
-        has_contiguous_subspace = True
-
-    # transposeToFront
-    # This is the logic that causes the newly inserted dimensions to show up
-    # at the beginning of the tensor, if they're not contiguous
-    if not has_contiguous_subspace:
-        dims = []
-        transposed_indices = []
-        for i, index in enumerate(indices):
-            if index is not None:
-                dims.append(i)
-                transposed_indices.append(index)
-        for i, index in enumerate(indices):
-            if index is None:
-                dims.append(i)
-                transposed_indices.append(index)
-        self = self.permute(dims)
-        indices = transposed_indices
-
-    # AdvancedIndex::AdvancedIndex
-    # Now we can assume the indices have contiguous subspace
-    # This is simplified from AdvancedIndex which goes to more effort
-    # to put the input and indices in a form so that TensorIterator can
-    # take them. If we write a ref for this, probably that logic should
-    # get implemented
-    before_shape: List[int] = []
-    after_shape: List[int] = []
-    replacement_shape: List[int] = []
-    for dim, index in enumerate(indices):
-        if index is None:
-            if replacement_shape:
-                after_shape.append(self.shape[dim])
-            else:
-                before_shape.append(self.shape[dim])
-        else:
-            replacement_shape = list(index.shape)
-    return self.new_empty(before_shape + replacement_shape + after_shape)
-
-
-# ============================== Embedding =========================================
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
-@register_meta(aten.embedding_dense_backward.default)
-def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
-                                  scale_grad_by_freq):
-    return new((num_weights, grad_output.size(-1)),
-               dtype=grad_output.dtype,
-               device=grad_output.device,
-               layout=grad_output.layout)
+        # hasContiguousSubspace
+        # true if all non-null tensors are adjacent
+        # See:
+        # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
+        # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
+        state = 0
+        has_contiguous_subspace = False
+        for index in indices:
+            if state == 0:
+                if index is not None:
+                    state = 1
+            elif state == 1:
+                if index is None:
+                    state = 2
+            else:
+                if index is not None:
+                    break
+        else:
+            has_contiguous_subspace = True
 
+        # transposeToFront
+        # This is the logic that causes the newly inserted dimensions to show up
+        # at the beginning of the tensor, if they're not contiguous
+        if not has_contiguous_subspace:
+            dims = []
+            transposed_indices = []
+            for i, index in enumerate(indices):
+                if index is not None:
+                    dims.append(i)
+                    transposed_indices.append(index)
+            for i, index in enumerate(indices):
+                if index is None:
+                    dims.append(i)
+                    transposed_indices.append(index)
+            self = self.permute(dims)
+            indices = transposed_indices
 
-# ============================== Dropout ===========================================
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
-@register_meta(aten.native_dropout.default)
-def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
-    # notice that mask is bool
-    return new_like(input), new_like(input, dtype=torch.bool)    # (output, mask)
+        # AdvancedIndex::AdvancedIndex
+        # Now we can assume the indices have contiguous subspace
+        # This is simplified from AdvancedIndex which goes to more effort
+        # to put the input and indices in a form so that TensorIterator can
+        # take them. If we write a ref for this, probably that logic should
+        # get implemented
+        before_shape: List[int] = []
+        after_shape: List[int] = []
+        replacement_shape: List[int] = []
+        for dim, index in enumerate(indices):
+            if index is None:
+                if replacement_shape:
+                    after_shape.append(self.shape[dim])
+                else:
+                    before_shape.append(self.shape[dim])
+            else:
+                replacement_shape = list(index.shape)
+        return self.new_empty(before_shape + replacement_shape + after_shape)
 
+    # ============================== Embedding =========================================
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
+    @register_meta(aten.embedding_dense_backward.default)
+    def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
+                                      scale_grad_by_freq):
+        return new((num_weights, grad_output.size(-1)),
+                   dtype=grad_output.dtype,
+                   device=grad_output.device,
+                   layout=grad_output.layout)
 
-# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
-@register_meta(aten.native_dropout_backward.default)
-def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
-    return new_like(grad)    # (grad_in)
+    # ============================== Dropout ===========================================
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+    @register_meta(aten.native_dropout.default)
+    def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
+        # notice that mask is bool
+        return new_like(input), new_like(input, dtype=torch.bool)    # (output, mask)
+
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+    @register_meta(aten.native_dropout_backward.default)
+    def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
+        return new_like(grad)    # (grad_in)
diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py
index 1c7b972ab..7c1c3d3d8 100644
--- a/colossalai/_analyzer/_subclasses/_monkey_patch.py
+++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py
@@ -1,5 +1,6 @@
 import torch
 import torch.distributed as dist
+from packaging import version
 
 aten = torch.ops.aten
 
@@ -49,40 +50,45 @@ _DistCommMethod = [
     "scatter",
 ]
 
-# TODO: dive deep here
-# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
-_AliasATen = [
-    aten.detach.default,
-    aten.detach_.default,
-    aten.t.default,
-    aten.transpose.int,
-    aten.view.default,
-    aten._unsafe_view.default,
-    aten._reshape_alias.default,
-]
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    # TODO: dive deep here
+    # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
+    _AliasATen = [
+        aten.detach.default,
+        aten.detach_.default,
+        aten.t.default,
+        aten.transpose.int,
+        aten.view.default,
+        aten._unsafe_view.default,
+        aten._reshape_alias.default,
+    ]
 
-_InplaceATen = [
-    aten.add_.Tensor,
-    aten.add_.Scalar,
-    aten.sub_.Tensor,
-    aten.sub_.Scalar,
-    aten.mul_.Tensor,
-    aten.mul_.Scalar,
-    aten.div_.Tensor,
-    aten.div_.Scalar,
-    aten.pow_.Tensor,
-    aten.pow_.Scalar,
-]
+    _InplaceATen = [
+        aten.add_.Tensor,
+        aten.add_.Scalar,
+        aten.sub_.Tensor,
+        aten.sub_.Scalar,
+        aten.mul_.Tensor,
+        aten.mul_.Scalar,
+        aten.div_.Tensor,
+        aten.div_.Scalar,
+        aten.pow_.Tensor,
+        aten.pow_.Scalar,
+    ]
 
-# use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
-_MaybeInplaceATen = [
-    aten.diagonal.default,
-    aten.expand.default,
-    aten.select.int,
-    aten.slice.Tensor,
-    aten.split.Tensor,
-    aten.squeeze.default,
-    aten.permute.default,
-    aten.unsqueeze.default,
-    aten.as_strided.default,
-]
+    # use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
+    _MaybeInplaceATen = [
+        aten.diagonal.default,
+        aten.expand.default,
+        aten.select.int,
+        aten.slice.Tensor,
+        aten.split.Tensor,
+        aten.squeeze.default,
+        aten.permute.default,
+        aten.unsqueeze.default,
+        aten.as_strided.default,
+    ]
+else:
+    _AliasATen = []
+    _InplaceATen = []
+    _MaybeInplaceATen = []
diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py
index ab9355146..dd35b00b3 100644
--- a/colossalai/_analyzer/_subclasses/flop_tensor.py
+++ b/colossalai/_analyzer/_subclasses/flop_tensor.py
@@ -11,6 +11,7 @@ from numbers import Number
 from typing import Any, Callable, List, Optional, Union
 
 import torch
+from packaging import version
 from torch.utils._pytree import tree_map
 
 from .meta_tensor import MetaTensor
@@ -403,134 +404,139 @@ def zero_flop_jit(*args):
     return 0
 
 
-flop_mapping = {
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    flop_mapping = {
     # gemm
-    aten.mm.default: matmul_flop_jit,
-    aten.matmul.default: matmul_flop_jit,
-    aten.addmm.default: addmm_flop_jit,
-    aten.bmm.default: bmm_flop_jit,
+        aten.mm.default: matmul_flop_jit,
+        aten.matmul.default: matmul_flop_jit,
+        aten.addmm.default: addmm_flop_jit,
+        aten.bmm.default: bmm_flop_jit,
    # convolution
-    aten.convolution.default: conv_flop_jit,
-    aten._convolution.default: conv_flop_jit,
-    aten.convolution_backward.default: conv_backward_flop_jit,
+        aten.convolution.default: conv_flop_jit,
+        aten._convolution.default: conv_flop_jit,
+        aten.convolution_backward.default: conv_backward_flop_jit,
    # normalization
-    aten.native_batch_norm.default: batchnorm_flop_jit,
-    aten.native_batch_norm_backward.default: batchnorm_flop_jit,
-    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
-    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
-    aten.native_layer_norm.default: norm_flop_counter(2, 0),
-    aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
+        aten.native_batch_norm.default: batchnorm_flop_jit,
+        aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+        aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+        aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
+        aten.native_layer_norm.default: norm_flop_counter(2, 0),
+        aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
    # pooling
-    aten.avg_pool1d.default: ewise_flop_counter(1, 0),
-    aten.avg_pool2d.default: ewise_flop_counter(1, 0),
-    aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
-    aten.avg_pool3d.default: ewise_flop_counter(1, 0),
-    aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
-    aten.max_pool1d.default: ewise_flop_counter(1, 0),
-    aten.max_pool2d.default: ewise_flop_counter(1, 0),
-    aten.max_pool3d.default: ewise_flop_counter(1, 0),
-    aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
-    aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
-    aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
-    aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
-    aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
-    aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
-    aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
-    aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
-    aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
-    aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
-    aten.embedding.default: ewise_flop_counter(1, 0),
-}
+        aten.avg_pool1d.default: ewise_flop_counter(1, 0),
+        aten.avg_pool2d.default: ewise_flop_counter(1, 0),
+        aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
+        aten.avg_pool3d.default: ewise_flop_counter(1, 0),
+        aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
+        aten.max_pool1d.default: ewise_flop_counter(1, 0),
+        aten.max_pool2d.default: ewise_flop_counter(1, 0),
+        aten.max_pool3d.default: ewise_flop_counter(1, 0),
+        aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
+        aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
+        aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
+        aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
+        aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
+        aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
+        aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
+        aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
+        aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
+        aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
+        aten.embedding.default: ewise_flop_counter(1, 0),
+    }
 
-ewise_flop_aten = [
+    ewise_flop_aten = [
    # basic op
-    aten.add.Tensor,
-    aten.add_.Tensor,
-    aten.div.Tensor,
-    aten.div_.Tensor,
-    aten.div.Scalar,
-    aten.div_.Scalar,
-    aten.mul.Tensor,
-    aten.mul.Scalar,
-    aten.mul_.Tensor,
-    aten.neg.default,
-    aten.pow.Tensor_Scalar,
-    aten.rsub.Scalar,
-    aten.sum.default,
-    aten.sum.dim_IntList,
-    aten.mean.dim,
+        aten.add.Tensor,
+        aten.add_.Tensor,
+        aten.div.Tensor,
+        aten.div_.Tensor,
+        aten.div.Scalar,
+        aten.div_.Scalar,
+        aten.mul.Tensor,
+        aten.mul.Scalar,
+        aten.mul_.Tensor,
+        aten.neg.default,
+        aten.pow.Tensor_Scalar,
+        aten.rsub.Scalar,
+        aten.sum.default,
+        aten.sum.dim_IntList,
+        aten.mean.dim,
    # activation op
-    aten.hardswish.default,
-    aten.hardswish_.default,
-    aten.hardswish_backward.default,
-    aten.hardtanh.default,
-    aten.hardtanh_.default,
-    aten.hardtanh_backward.default,
-    aten.hardsigmoid_backward.default,
-    aten.hardsigmoid.default,
-    aten.gelu.default,
-    aten.gelu_backward.default,
-    aten.silu.default,
-    aten.silu_.default,
-    aten.silu_backward.default,
-    aten.sigmoid.default,
-    aten.sigmoid_backward.default,
-    aten._softmax.default,
-    aten._softmax_backward_data.default,
-    aten.relu_.default,
-    aten.relu.default,
-    aten.tanh.default,
-    aten.tanh_backward.default,
-    aten.threshold_backward.default,
+        aten.hardswish.default,
+        aten.hardswish_.default,
+        aten.hardswish_backward.default,
+        aten.hardtanh.default,
+        aten.hardtanh_.default,
+        aten.hardtanh_backward.default,
+        aten.hardsigmoid_backward.default,
+        aten.hardsigmoid.default,
+        aten.gelu.default,
+        aten.gelu_backward.default,
+        aten.silu.default,
+        aten.silu_.default,
+        aten.silu_backward.default,
+        aten.sigmoid.default,
+        aten.sigmoid_backward.default,
+        aten._softmax.default,
+        aten._softmax_backward_data.default,
+        aten.relu_.default,
+        aten.relu.default,
+        aten.tanh.default,
+        aten.tanh_backward.default,
+        aten.threshold_backward.default,
    # dropout
-    aten.native_dropout.default,
-    aten.native_dropout_backward.default,
+        aten.native_dropout.default,
+        aten.native_dropout_backward.default,
    # distribution
-    aten.bernoulli_.float,
+        aten.bernoulli_.float,
    # where
-    aten.where.self,
-]
-for op in ewise_flop_aten:
-    flop_mapping[op] = ewise_flop_counter(1, 0)
+        aten.where.self,
+    ]
+    for op in ewise_flop_aten:
+        flop_mapping[op] = ewise_flop_counter(1, 0)
 
-# fix-me: this will be removed in future
-zero_flop_aten = [
-    aten.as_strided.default,
-    aten.as_strided_.default,
-    aten.cat.default,
-    aten.clone.default,
-    aten.copy_.default,
-    aten.detach.default,
-    aten.expand.default,
-    aten.empty_like.default,
-    aten.new_empty.default,
-    aten.new_empty_strided.default,
-    aten.ones_like.default,
-    aten._reshape_alias.default,
-    aten.select.int,
-    aten.select_backward.default,
-    aten.squeeze.dim,
-    aten.slice.Tensor,
-    aten.slice_backward.default,
-    aten.split.Tensor,
-    aten.permute.default,
-    aten.t.default,
-    aten.transpose.int,
-    aten._to_copy.default,
-    aten.unsqueeze.default,
-    aten.unbind.int,
-    aten._unsafe_view.default,
-    aten.view.default,
-    aten.zero_.default,
-    aten.zeros_like.default,
-]
+    # fix-me: this will be removed in future
+    zero_flop_aten = [
+        aten.as_strided.default,
+        aten.as_strided_.default,
+        aten.cat.default,
+        aten.clone.default,
+        aten.copy_.default,
+        aten.detach.default,
+        aten.expand.default,
+        aten.empty_like.default,
+        aten.new_empty.default,
+        aten.new_empty_strided.default,
+        aten.ones_like.default,
+        aten._reshape_alias.default,
+        aten.select.int,
+        aten.select_backward.default,
+        aten.squeeze.dim,
+        aten.slice.Tensor,
+        aten.slice_backward.default,
+        aten.split.Tensor,
+        aten.permute.default,
+        aten.t.default,
+        aten.transpose.int,
+        aten._to_copy.default,
+        aten.unsqueeze.default,
+        aten.unbind.int,
+        aten._unsafe_view.default,
+        aten.view.default,
+        aten.zero_.default,
+        aten.zeros_like.default,
+    ]
 
-for op in zero_flop_aten:
-    flop_mapping[op] = zero_flop_jit
+    for op in zero_flop_aten:
+        flop_mapping[op] = zero_flop_jit
+else:
+    # torch < 1.12: the aten overloads above are unavailable, so expose empty
+    # tables (same names and types as the branch above) to keep imports working
+    flop_mapping = {}
+    ewise_flop_aten = []
+    zero_flop_aten = []
diff --git a/colossalai/_analyzer/fx/__init__.py b/colossalai/_analyzer/fx/__init__.py
index 2e857b1b0..aa01de0bb 100644
--- a/colossalai/_analyzer/fx/__init__.py
+++ b/colossalai/_analyzer/fx/__init__.py
@@ -1,4 +1,3 @@
-from .bias_addition import *
 from .node_util import MetaInfo
 from .symbolic_profile import symbolic_profile
-from .symbolic_trace import symbolic_trace
+from .tracer.symbolic_trace import symbolic_trace
diff --git a/colossalai/_analyzer/fx/graph_module.py b/colossalai/_analyzer/fx/graph_module.py
index 779b42eba..1fdedd758 100644
--- a/colossalai/_analyzer/fx/graph_module.py
+++ b/colossalai/_analyzer/fx/graph_module.py
@@ -1,4 +1,7 @@
+import linecache
 import os
+import sys
+import traceback
 import warnings
 from pathlib import Path
 from typing import Any, Dict, Optional, Union
@@ -6,11 +9,74 @@ from typing import Any, Dict, Optional, Union
 import torch
 import torch.fx
 import torch.nn as nn
-from torch.fx.graph import PythonCode, _PyTreeCodeGen
-from torch.fx.graph_module import _exec_with_source, _forward_from_src, _WrappedCall
+from torch.fx.graph import PythonCode
+
+try:
+    from torch.fx.graph import _PyTreeCodeGen
+    SUPPORT_PT_CODEGEN = True
+except ImportError:
+    SUPPORT_PT_CODEGEN = False
+
+from torch.fx.graph_module import _exec_with_source, _forward_from_src
 from torch.nn.modules.module import _addindent
 
 
+# This is a copy of torch.fx.graph_module._WrappedCall.
+# It should be removed when we stop supporting torch < 1.12.0.
+class _WrappedCall:
+
+    def __init__(self, cls, cls_call):
+        self.cls = cls
+        self.cls_call = cls_call
+
+    # Previously, if an error occurred when valid
+    # symbolically-traced code was run with an invalid input, the
+    # user would see the source of the error as coming from
+    # `File "<eval_with_key_N>"`, where N is some number. We use
+    # this function to generate a more informative error message. We
+    # return the traceback itself, a message explaining that the
+    # error occurred in a traced Module's generated forward
+    # function, and five lines of context surrounding the faulty
+    # line
+    @staticmethod
+    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
+        # auxiliary variables (for readability)
+        err_lineno = frame_summary.lineno
+        assert err_lineno is not None
+        line = frame_summary.line
+        assert line is not None
+        err_line_len = len(line)
+        all_src_lines = linecache.getlines(frame_summary.filename)
+
+        # constituent substrings of the error message
+        tb_repr = traceback.format_exc()
+        custom_msg = ("Call using an FX-traced Module, "
+                      f"line {err_lineno} of the traced Module's "
+                      "generated forward function:")
+        before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+        marker = "~" * err_line_len + "~~~ <--- HERE"
+        err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+
+        # joined message
+        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
+
+    def __call__(self, obj, *args, **kwargs):
+        try:
+            if self.cls_call is not None:
+                return self.cls_call(obj, *args, **kwargs)
+            else:
+                return super(self.cls, obj).__call__(*args, **kwargs)    # type: ignore[misc]
+        except Exception as e:
+            assert e.__traceback__
+            topmost_framesummary: traceback.FrameSummary = \
+                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]    # type: ignore[arg-type]
+            if "eval_with_key" in topmost_framesummary.filename:
+                print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
+                raise e.with_traceback(None)
+            else:
+                raise e
+
+
 class ColoGraphModule(torch.fx.GraphModule):
     """
     ColoGraphGraphModule is an nn.Module generated from an fx.Graph.
@@ -65,7 +131,7 @@ class ColoGraphModule(torch.fx.GraphModule):
         called after editing the contained ``graph``, otherwise the generated
         code of this ``GraphModule`` will be out of date.
         """
-        if isinstance(self._graph._codegen, _PyTreeCodeGen):
+        if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
             self._in_spec = self._graph._codegen.pytree_info.in_spec
             self._out_spec = self._graph._codegen.pytree_info.out_spec
         python_code = self._graph.python_code(root_module='self')
diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py
index d06fa8b93..8c8956d8e 100644
--- a/colossalai/_analyzer/fx/node_util.py
+++ b/colossalai/_analyzer/fx/node_util.py
@@ -20,7 +20,7 @@ def union(a, b):
     return {**a, **b}
 
 
-def compute_size_in_bytes(elem: torch.Tensor | Dict | List | Tuple | int) -> int:
+def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Compute the size of a tensor or a collection of tensors in bytes.
 
     Args:
@@ -195,8 +195,8 @@ class MetaInfo:
             s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
         if self.output_size:
             s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
-        if self.total_size:
-            s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
+        # if self.total_size:
+        #     s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
         if self.temp_size:
             s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
         if self.backward_size:
diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py
index 3691497ed..ab3e1a4d6 100644
--- a/colossalai/_analyzer/fx/passes/shape_prop.py
+++ b/colossalai/_analyzer/fx/passes/shape_prop.py
@@ -111,7 +111,24 @@ class ShapeProp(torch.fx.Interpreter):
             with self.global_hook:
                 r = getattr(self, n.op)(n.target, args, kwargs)
 
-        unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
+        def unwrap_fn(elem):
+
+            def _convert_meta(t: torch.Tensor):
+                if t.device.type == 'meta':
+                    return t
+                else:
+                    return t.to('meta')
+
+            if isinstance(elem, MetaTensor):
+                return _convert_meta(elem._tensor)
+
+            elif isinstance(elem, torch.Tensor):
+                return _convert_meta(elem)
+
+            else:
+                return elem
+
+        # unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
         is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
         n_info = MetaInfo(n)
         n_info.outputs = _normalize_tuple(r)
diff --git a/colossalai/_analyzer/fx/tracer/__init__.py b/colossalai/_analyzer/fx/tracer/__init__.py
new file mode 100644
index 000000000..6b1b2256a
--- /dev/null
+++ b/colossalai/_analyzer/fx/tracer/__init__.py
@@ -0,0 +1,2 @@
+from .bias_addition import *
+from .custom_leaf_module import *
diff --git a/colossalai/_analyzer/fx/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py
similarity index 98%
rename from colossalai/_analyzer/fx/bias_addition.py
rename to colossalai/_analyzer/fx/tracer/bias_addition.py
index 5359752d4..1e75b47ca 100644
--- a/colossalai/_analyzer/fx/bias_addition.py
+++ b/colossalai/_analyzer/fx/tracer/bias_addition.py
@@ -4,11 +4,10 @@ graph construction to deal with the compatibility between bias-addition and all-
 """
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.utils import _pair, _single, _triple
 
-from .symbolic_trace import register_tracer_impl
+from .tracer import register_tracer_impl
 
 __all__ = []
diff --git a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
new file mode 100644
index 000000000..112c7c963
--- /dev/null
+++ b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
@@ -0,0 +1,29 @@
+import torch
+
+from .tracer import register_leaf_module, register_leaf_module_impl
+
+try:
+    import apex
+    register_leaf_module(apex.normalization.FusedLayerNorm)
+    register_leaf_module(apex.normalization.FusedRMSNorm)
+    register_leaf_module(apex.normalization.MixedFusedLayerNorm)
+    register_leaf_module(apex.normalization.MixedFusedRMSNorm)
+
+    @register_leaf_module_impl(apex.normalization.FusedLayerNorm)
+    @register_leaf_module_impl(apex.normalization.FusedRMSNorm)
+    @register_leaf_module_impl(apex.normalization.MixedFusedLayerNorm)
+    @register_leaf_module_impl(apex.normalization.MixedFusedRMSNorm)
+    def torch_nn_normalize(self, input: torch.Tensor):
+        # check shape
+        if isinstance(self, torch.nn.BatchNorm1d):
+            assert input.dim() in [2, 3]
+        elif isinstance(self, torch.nn.BatchNorm2d):
isinstance(self, torch.nn.BatchNorm2d): + assert input.dim() == 4 + elif isinstance(self, torch.nn.BatchNorm3d): + assert input.dim() == 5 + + # normalization maintain the same shape as the input + return input.clone() + +except (ImportError, AttributeError): + pass diff --git a/colossalai/_analyzer/fx/tracer/proxy.py b/colossalai/_analyzer/fx/tracer/proxy.py new file mode 100644 index 000000000..ce379efdc --- /dev/null +++ b/colossalai/_analyzer/fx/tracer/proxy.py @@ -0,0 +1,112 @@ +import operator +from typing import Any, Callable, Dict, Optional, Set, Union + +import torch +import torch.nn as nn +from torch.fx import Graph, Node, Proxy, Tracer +from torch.fx.graph import _Namespace +from torch.utils._pytree import tree_map + +from colossalai._analyzer._subclasses import MetaTensor + +Target = Union[Callable[..., Any], str] + + +class ColoProxy(Proxy): + _func_dispatch: Dict[Target, Callable[..., Any]] = {} + + def __init__(self, *args, data=None, **kwargs): + super().__init__(*args, **kwargs) + self._meta_data = data + + @property + def meta_data(self): + return self._meta_data + + @meta_data.setter + def meta_data(self, args): + wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x + self._meta_data = tree_map(wrap_fn, args) + + @classmethod + def __torch_function__(cls, orig_method, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + if orig_method in cls._func_dispatch: + impl = cls._func_dispatch.pop(orig_method) # avoid recursion + proxy = impl(*args, **kwargs) + cls._func_dispatch[orig_method] = impl + return proxy + else: + proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs)) + unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p + if proxy.meta_data is None: + proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) + return proxy + + @classmethod + def from_torch_proxy(cls, proxy: Proxy): + return cls(proxy.node, proxy.tracer) + + def __repr__(self): + return f"ColoProxy({self.node.name}, meta_data={self.meta_data})" + + def __len__(self): + return len(self.meta_data) + + def __int__(self): + return int(self.meta_data) + + def __index__(self): + try: + return int(self.meta_data) + except: + return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__() + + def __float__(self): + return float(self.meta_data) + + def __bool__(self): + return self.meta_data + + def __getattr__(self, k): + return ColoAttribute(self, k, getattr(self._meta_data, k, None)) + + def __setitem__(self, key, value): + proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy.meta_data = self._meta_data + return proxy + + def __contains__(self, key): + if self.node.op == "placeholder": + # this is used to handle like + # if x in kwargs + # we don't handle this case for now + return False + return super().__contains__(key) + + def __isinstancecheck__(self, type): + return isinstance(self.meta_data, type) + + +class ColoAttribute(ColoProxy): + + def __init__(self, root, attr: str, data=None): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._meta_data = data + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def 
__call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + + def __repr__(self): + return f"ColoAttribute({self.node.name}, attr={self.attr})" diff --git a/colossalai/_analyzer/fx/tracer/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/symbolic_trace.py new file mode 100644 index 000000000..2018863f6 --- /dev/null +++ b/colossalai/_analyzer/fx/tracer/symbolic_trace.py @@ -0,0 +1,157 @@ +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union + +import torch +from torch.fx import Tracer +from torch.utils._pytree import tree_map + +from colossalai._analyzer._subclasses import MetaTensor + +try: + from ..codegen import ActivationCheckpointCodeGen + SUPPORT_ACTIVATION = True +except: + SUPPORT_ACTIVATION = False +from ..graph_module import ColoGraphModule +from .tracer import ColoTracer + + +def _default_device(): + return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + + +def _current_device(module: torch.nn.Module): + try: + return next(module.parameters()).device + except: + return _default_device() + + +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + meta_args: Optional[Dict[str, Any]] = None, + trace_act_ckpt: bool = False, + bias_addition_split: bool = False, +) -> ColoGraphModule: + """ + Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo`` + attached to the ``Node``s. + + Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module + (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py). + + This tracer is able to trace basic control flow and for loops. + + It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``. + (See ./bias_addition.py for more details). + + Examples: + 1. Tracing a ``torch.nn.Module`` with control flow. + + .. code-block:: python + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + if x.size(0) > 1: + x = x.sum(dim=0) + return self.linear(x) + + traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}) + + # traced code like: + # def forward(self, x): + # linear_1 = self.linear(x) + # return linear_1 + + traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)}) + + # traced code like: + # def forward(self, x): + # sum = x.sum(dim=0); x = None + # linear = self.linear(sum); sum = None + # return linear + + 2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``. + + .. code-block:: python + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + def custom_forward(x): + return self.linear(x) + return torch.utils.checkpoint.checkpoint(custom_forward, x) + + traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True) + + # traced code like: + # def checkpoint_0(self, x): + # linear = self.linear(x); x = None + # return linear + # + # def forward(self, x): + # linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None + # return linear + + 3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``. + + .. 
diff --git a/colossalai/_analyzer/fx/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/tracer.py
similarity index 53%
rename from colossalai/_analyzer/fx/symbolic_trace.py
rename to colossalai/_analyzer/fx/tracer/tracer.py
index 5d858c87a..1a247449f 100644
--- a/colossalai/_analyzer/fx/symbolic_trace.py
+++ b/colossalai/_analyzer/fx/tracer/tracer.py
@@ -1,28 +1,19 @@
 import functools
 import inspect
-import operator
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 from torch.fx import Graph, Node, Proxy, Tracer
-from torch.fx.graph import _Namespace
 from torch.utils._pytree import tree_map
 
-from colossalai._analyzer._subclasses import MetaTensor, _TensorPropertyMethod, _TorchFactoryMethod
+from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod
 
-from .codegen import ActivationCheckpointCodeGen
-from .graph_module import ColoGraphModule
-from .node_util import MetaInfo
+from ..node_util import MetaInfo
+from .proxy import ColoProxy
 
 Target = Union[Callable[..., Any], str]
-Argument = Optional[Union[Tuple[Any, ...],    # actually Argument, but mypy can't represent recursive types
-                          List[Any],    # actually Argument
-                          Dict[str, Any],    # actually Argument
-                          slice,    # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
-                          'Node',]]
-zeros = torch.zeros
 
 
 def _truncate_suffix(s: str):
@@ -32,17 +23,6 @@ def _truncate_suffix(s: str):
     return re.sub(r'_\d+$', '', s)
 
 
-def _default_device():
-    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-
-
-def _current_device(module):
-    try:
-        return next(module.parameters()).device
-    except:
-        return _default_device()
-
-
 def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
 
     def wrapper(impl):
@@ -70,149 +50,6 @@ def register_non_leaf_module(module: nn.Module):
     ColoTracer._custom_non_leaf_module.add(module)
 
 
-class ColoProxy(Proxy):
-    _func_dispatch: Dict[Target, Callable[..., Any]] = {}
-
-    def __init__(self, *args, data=None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._meta_data = data
-
-    @property
-    def meta_data(self):
-        return self._meta_data
-
-    @meta_data.setter
-    def meta_data(self, args):
-        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
-        self._meta_data = tree_map(wrap_fn, args)
-
-    @classmethod
-    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
-        kwargs = {} if kwargs is None else kwargs
-        if orig_method in cls._func_dispatch:
-            impl = cls._func_dispatch.pop(orig_method)    # avoid recursion
-            proxy = impl(*args, **kwargs)
-            cls._func_dispatch[orig_method] = impl
-            return proxy
-        else:
-            proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
-            unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
-            if proxy.meta_data is None:
-                proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
-            return proxy
-
-    @classmethod
-    def from_torch_proxy(cls, proxy: Proxy):
-        return cls(proxy.node, proxy.tracer)
-
-    def __repr__(self):
-        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
-
-    def __len__(self):
-        return len(self.meta_data)
-
-    def __int__(self):
-        return int(self.meta_data)
-
-    def __index__(self):
-        try:
-            return int(self.meta_data)
-        except:
-            return zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
-
-    def __float__(self):
-        return float(self.meta_data)
-
-    def __bool__(self):
-        return self.meta_data
-
-    def __getattr__(self, k):
-        return ColoAttribute(self, k, getattr(self._meta_data, k, None))
-
-    def __setitem__(self, key, value):
-        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
-        proxy.meta_data = self._meta_data
-        return proxy
-
-    def __contains__(self, key):
-        if self.node.op == "placeholder":
-            # this is used to handle like
-            # if x in kwargs
-            # we don't handle this case for now
-            return False
-        return super().__contains__(key)
-
-    def __isinstancecheck__(self, type):
-        return isinstance(self.meta_data, type)
-
-    def size(self, dim=None):
-        if self._meta_data is None:
-            return self._meta_data.size(*[dim] if dim else [])
-        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
-
-    def dim(self):
-        if self._meta_data is not None:
-            return self._meta_data.dim()
-        return self.tracer.create_proxy('call_method', 'dim', (self,), {})
-
-    @property
-    def shape(self):
-        if self._meta_data is not None:
-            return self._meta_data.shape
-        return self.tracer.create_proxy('call_function', getattr, (self, 'shape'), {})
-
-    @property
-    def ndim(self):
-        if self._meta_data is not None:
-            return self._meta_data.ndim
-        return self.tracer.create_proxy('call_function', getattr, (self, 'ndim'), {})
-
-    @property
-    def device(self):
-        if self._meta_data is not None:
-            return self._meta_data.device
-        return self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
-
-    @property
-    def dtype(self):
-        if self._meta_data is not None:
-            return self._meta_data.dtype
-        return self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
-
-    def to(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
-
-    def cpu(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
-
-    def cuda(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
-
-
-class ColoAttribute(ColoProxy):
-
-    def __init__(self, root, attr: str, data=None):
-        self.root = root
-        self.attr = attr
-        self.tracer = root.tracer
-        self._meta_data = data
-        self._node: Optional[Node] = None
-
-    @property
-    def node(self):
-        # the node for attributes is added lazily, since most will just be method calls
-        # which do not rely on the getitem call
-        if self._node is None:
-            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
-        return self._node
-
-    def __call__(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
-
-    def __repr__(self):
-        return f"ColoAttribute({self.node.name}, attr={self.attr})"
-
-
 class ColoTracer(Tracer):
     _custom_leaf_module: Set[Type[nn.Module]] = set()
     _custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}
@@ -249,7 +86,6 @@ class ColoTracer(Tracer):
         # we will enter the module and split the bias-addition ops
         if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
             return False
-
         # user can specify which modules are leaf modules and which are not
         return (type(m) not in self._custom_non_leaf_module
                 and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
@@ -306,9 +142,13 @@ class ColoTracer(Tracer):
                 mod = self.root.get_submodule(target)
                 self.disable_module_getattr = True
                 try:
-                    proxy.meta_data = self._custom_leaf_module_impl.get(type(mod),
-                                                                        mod.forward)(*tree_map(unwrap_fn, args),
-                                                                                     **tree_map(unwrap_fn, kwargs))
+                    args = tree_map(unwrap_fn, args)
+                    kwargs = tree_map(unwrap_fn, kwargs)
+                    if type(mod) in self._custom_leaf_module:
+                        target = self._custom_leaf_module_impl[type(mod)]
+                        proxy.meta_data = target(mod, *args, **kwargs)
+                    else:
+                        proxy.meta_data = mod.forward(*args, **kwargs)
                 finally:
                     self.disable_module_getattr = False
                 return proxy
@@ -320,15 +160,21 @@ class ColoTracer(Tracer):
 
     def trace(self,
               root: torch.nn.Module,
-              concrete_args: Optional[Dict[str, torch.Tensor]] = {},
-              meta_args: Optional[Dict[str, torch.Tensor]] = {}) -> Graph:
+              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+              meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
+
+        if meta_args is None:
+            meta_args = {}
+
+        if concrete_args is None:
+            concrete_args = {}
 
         # check concrete and meta args have valid names
         sig = inspect.signature(root.forward)
         sig_names = set(sig.parameters.keys())
         meta_arg_names = set(meta_args.keys())
         concrete_arg_names = set(concrete_args.keys())
-
+        non_concrete_arg_names = sig_names - concrete_arg_names
 
         # update concrete args with default values
         for k, v in sig.parameters.items():
             if k in sig_names - meta_arg_names and \
@@ -352,6 +198,34 @@ class ColoTracer(Tracer):
         self.graph = super().trace(root, concrete_args=concrete_args)
         self.mod_dir = ''
         self.graph.lint()
+
+        for node in self.graph.nodes:
+            if node.op == "placeholder":
+                # Remove default values for inputs, as the forward pass would fail with them.
+                if node.target in non_concrete_arg_names:
+                    node.args = ()
+                    # Without this, torch.jit.script fails because the input's type is Optional[torch.Tensor];
+                    # it cannot infer which attributes and methods the input should have, and fails.
+                    node.type = torch.Tensor
+                # It is a concrete arg, so it is not used and should be removed.
+                else:
+                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
+                        # Newer versions of torch.fx emit an assert statement
+                        # for concrete arguments; delete those before we delete
+                        # the concrete arg.
+                        to_delete = []
+                        for user in node.users:
+                            if user.target == torch.fx._symbolic_trace._assert_is_none:
+                                to_delete.append(user)
+                        for user in to_delete:
+                            self.graph.erase_node(user)
+
+                    self.graph.erase_node(node)
+
+            # TODO: fix GraphModule creation.
+            # Without this, the return type annotation "Tuple" causes code execution to fail.
+            if node.op == "output":
+                node.type = None
         return self.graph
 
     @contextmanager
@@ -454,7 +328,7 @@ class ColoTracer(Tracer):
             if node.op == "output":
                 node.type = None
         self.graph.lint()
-
+
     def getattr(self, attr, attr_val, parameter_proxy_cache):
         return self._module_getattr(attr, attr_val, parameter_proxy_cache)
@@ -487,134 +361,3 @@ class ColoTracer(Tracer):
             return maybe_parameter_proxy
 
         return attr_val
-
-
-def symbolic_trace(
-    root: Union[torch.nn.Module, Callable[..., Any]],
-    concrete_args: Optional[Dict[str, Any]] = {},
-    meta_args: Optional[Dict[str, Any]] = {},
-    trace_act_ckpt: bool = False,
-    bias_addition_split: bool = False,
-) -> ColoGraphModule:
-    """
-    Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
-    attached to the ``Node``s.
-
-    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
-    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).
-
-    This tracer is able to trace basic control flow and for loops.
-
-    It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``.
-    (See ./bias_addition.py for more details).
-
-    Examples:
-        1. Tracing a ``torch.nn.Module`` with control flow.
-
-        .. code-block:: python
-
-            class MyModule(torch.nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.linear = torch.nn.Linear(2, 2)
-
-                def forward(self, x):
-                    if x.size(0) > 1:
-                        x = x.sum(dim=0)
-                    return self.linear(x)
-
-            traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
-
-            # traced code like:
-            # def forward(self, x):
-            #     linear_1 = self.linear(x)
-            #     return linear_1
-
-            traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
-
-            # traced code like:
-            # def forward(self, x):
-            #     sum = x.sum(dim=0); x = None
-            #     linear = self.linear(sum); sum = None
-            #     return linear
-
-        2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
-
-        .. code-block:: python
-
-            class MyModule(torch.nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.linear = torch.nn.Linear(2, 2)
-
-                def forward(self, x):
-                    def custom_forward(x):
-                        return self.linear(x)
-                    return torch.utils.checkpoint.checkpoint(custom_forward, x)
-
-            traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
-
-            # traced code like:
-            # def checkpoint_0(self, x):
-            #     linear = self.linear(x); x = None
-            #     return linear
-            #
-            # def forward(self, x):
-            #     linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
-            #     return linear
-
-        3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
-
-        .. code-block:: python
-
-            class MyModule(torch.nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.linear = torch.nn.Linear(2, 2, bias=True)
-
-                def forward(self, x):
-                    return self.linear(x)
-
-            traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
-
-            # traced code like:
-            # def forward(self, x):
-            #     linear_bias = self.linear.bias
-            #     linear_weight = self.linear.weight
-            #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
-            #     add = linear + linear_bias; linear = linear_bias = None
-            #     return add
-
-    Args:
-        root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
-        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
-            Defaults to {}.
-        meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
-            for tracing control flow. Defaults to {}.
-        trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
-            Defaults to False.
-        bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
-
-    Returns:
-        ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
-
-    Remarks:
-        This part of ``symbolic_trace()`` is maintained by Colossal-AI team. If you encountered
-        any unexpected error during tracing, feel free to raise an issue on Colossal-AI GitHub
-        repo. We welcome any feedback and contributions to enhance the extensibility of
-        Colossal-AI.
-    """
-    if meta_args:
-        device, orig_device = _default_device(), _current_device(root)
-        wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
-        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
-                           bias_addition_split=bias_addition_split).trace(root.to(device),
-                                                                          concrete_args=concrete_args,
-                                                                          meta_args=tree_map(wrap_fn, meta_args))
-    if trace_act_ckpt:
-        graph.set_codegen(ActivationCheckpointCodeGen())
-        root.to(orig_device)
-    else:
-        graph = Tracer().trace(root, concrete_args=concrete_args)
-    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
-    return ColoGraphModule(root, graph, name)
diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py
index 710038ffa..466a2a558 100644
--- a/tests/kit/model_zoo/__init__.py
+++ b/tests/kit/model_zoo/__init__.py
@@ -1,5 +1,4 @@
 from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
-
 from .registry import model_zoo
 
 __all__ = ['model_zoo']
diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py
index 2a100c981..5ed4fbe70 100644
--- a/tests/kit/model_zoo/transformers/gpt.py
+++ b/tests/kit/model_zoo/transformers/gpt.py
@@ -17,6 +17,14 @@ def data_gen():
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
 
 
+def seq_classification_data_gen():
+    # batch size must be 1 if no padding token is defined
+    input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
 output_transform_fn = lambda x: x
 
 config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
@@ -44,6 +52,6 @@ model_zoo.register(name='transformers_gpt_for_token_classification',
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_for_sequence_classification',
                    model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=seq_classification_data_gen,
                    output_transform_fn=output_transform_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/test_analyzer/test_fx/test_bias_addition.py b/tests/test_analyzer/test_fx/test_bias_addition.py
index 5c9ec7cc3..044a464be 100644
--- a/tests/test_analyzer/test_fx/test_bias_addition.py
+++ b/tests/test_analyzer/test_fx/test_bias_addition.py
@@ -1,5 +1,6 @@
 import pytest
 import torch
+from packaging import version
 from torch.utils.checkpoint import checkpoint
 
 try:
@@ -73,7 +74,7 @@ class AddmmModel(torch.nn.Module):
         return x
 
 
-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("bias_addition_split", [True, False])
 @pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
index 6d93fe040..7a4bf131a 100644
--- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
+++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
@@ -3,7 +3,8 @@ from numpy import isin
 from torch.fx import GraphModule
 from torch.utils._pytree import tree_flatten
 
-from colossalai.fx import symbolic_trace
+# from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 
 
 def trace_model_and_compare_output(model, data_gen):
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
index b1c9c211a..31ba2290e 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
@@ -1,4 +1,7 @@
+import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
 
 from tests.kit.model_zoo import model_zoo
 
@@ -6,6 +9,7 @@ BATCH_SIZE = 2
 SEQ_LENGTH = 16
 
 
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
 def test_albert():
     sub_registry = model_zoo.get_sub_registry('transformers_albert')
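The new ``seq_classification_data_gen`` uses a batch size of 1 because Hugging Face's ``GPT2ForSequenceClassification`` pools the last non-padding token and rejects batch sizes greater than 1 when no ``pad_token_id`` is defined. A hedged sketch of both ways around that constraint (config values are illustrative):

    import torch
    import transformers

    config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
    model = transformers.GPT2ForSequenceClassification(config)

    # Option 1: keep the batch size at 1, as seq_classification_data_gen does.
    out = model(input_ids=torch.zeros((1, 16), dtype=torch.int64))

    # Option 2: define a padding token so batches larger than 1 are accepted.
    config.pad_token_id = config.eos_token_id    # a common choice for GPT-2
    model = transformers.GPT2ForSequenceClassification(config)
    out = model(input_ids=torch.zeros((4, 16), dtype=torch.int64))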
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
index 1bf4947c3..8db6817c6 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
@@ -1,8 +1,12 @@
+import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
 
 from tests.kit.model_zoo import model_zoo
 
 
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
 def test_bert():
     sub_registry = model_zoo.get_sub_registry('transformers_bert')
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
index 67a3178fa..796c17e39 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
@@ -1,16 +1,24 @@
 import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
 
 from tests.kit.model_zoo import model_zoo
 
 
-# TODO: remove this skip once we handle the latest gpt model
-@pytest.mark.skip
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
 def test_gpt():
     sub_registry = model_zoo.get_sub_registry('transformers_gpt')
 
     for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
         model = model_fn()
+
+        # TODO: support the following models:
+        #   1. GPT2DoubleHeadsModel
+        # they are not supported yet, so skip them for now
+        if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
+            continue
+
         trace_model_and_compare_output(model, data_gen_fn)
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
index 740f5a9f0..e7bfa6070 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -1,8 +1,12 @@
+import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
 
 from tests.kit.model_zoo import model_zoo
 
 
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
 def test_opt():
     sub_registry = model_zoo.get_sub_registry('transformers_opt')
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
index 7073fd634..5f7e4f81c 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
@@ -1,8 +1,12 @@
+import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
 
 from tests.kit.model_zoo import model_zoo
 
 
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
 def test_t5():
     sub_registry = model_zoo.get_sub_registry('transformers_t5')
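These skip conditions were switched from plain string comparison to ``packaging.version`` because comparing version strings lexicographically goes wrong as soon as a component reaches two digits: ``'1.9.0' < '1.12.0'`` is ``False`` as a string comparison, so the old guard would not have skipped on torch 1.9. A quick demonstration:

    from packaging import version

    # Lexicographic comparison: '9' > '1', so 1.9.0 appears NOT to be older than 1.12.0.
    assert ('1.9.0' < '1.12.0') is False                        # wrong answer from string comparison
    assert version.parse('1.9.0') < version.parse('1.12.0')     # correct semantic comparison

    # Real-world torch versions often carry local suffixes, which packaging also handles.
    assert version.parse('1.13.0+cu117') >= version.parse('1.12.0')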
diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
index 31baa3e89..b175d8b10 100644
--- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
@@ -1,8 +1,8 @@
 import pytest
-import timm.models as tm
 import torch
+from packaging import version
 
-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 from tests.kit.model_zoo import model_zoo
 
 
@@ -42,6 +42,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
             f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
 
 
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
 def test_timm_models():
     torch.backends.cudnn.deterministic = True
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
index bf6c7ae55..65f9f5149 100644
--- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
+++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
@@ -1,20 +1,18 @@
-import re
-
+import pytest
 import torch
+from packaging import version
 from torchaudio_utils import trace_and_compare
 
 from tests.kit.model_zoo import model_zoo
 
 
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
 def test_torchaudio_models():
     torch.backends.cudnn.deterministic = True
 
     sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
-        # FIXME(ver217): temporarily skip these models
-        if re.search(f'(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
-            continue
         model = model_fn()
         trace_and_compare(model,
                           data_gen_fn,
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
index 18d86fc05..239f38680 100644
--- a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
+++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
@@ -1,6 +1,6 @@
 import torch
 
-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 
 
 def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):
diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
index a4e847dbc..40f83d47a 100644
--- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
+++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 from tests.kit.model_zoo import model_zoo
 
 BATCH = 2
diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
index ac377ff1d..6d4b6ab81 100644
--- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 from tests.kit.model_zoo import model_zoo
 
 BATCH = 2
diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
index 455638818..8dbbf9f5a 100644
--- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
+++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
@@ -1,6 +1,6 @@
 import torch
 
-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 from tests.kit.model_zoo import model_zoo
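To close the loop, here is a hedged sketch of what these test helpers do end to end: trace a model with the new ``symbolic_trace``, run both the traced and the eager module, and compare the flattened outputs. The model choice and input shape are illustrative, not taken from the tests:

    import torch
    import torchvision.models as tv
    from torch.utils._pytree import tree_flatten

    from colossalai._analyzer.fx import symbolic_trace

    model = tv.resnet18()
    data = torch.rand(2, 3, 224, 224)

    # trace with meta_args so no real forward pass is needed during tracing
    gm = symbolic_trace(model, meta_args={'x': data})
    model.eval()
    gm.eval()

    with torch.no_grad():
        fx_out = gm(data)
        eager_out = model(data)

    # flatten arbitrarily nested outputs and compare element-wise
    flat_fx, _ = tree_flatten(fx_out)
    flat_eager, _ = tree_flatten(eager_out)
    for a, b in zip(flat_fx, flat_eager):
        assert torch.allclose(a, b), f'{model.__class__.__name__} has inconsistent outputs'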