diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py index dd35b00b3..59991dc50 100644 --- a/colossalai/_analyzer/_subclasses/flop_tensor.py +++ b/colossalai/_analyzer/_subclasses/flop_tensor.py @@ -235,7 +235,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: # Inputs contains the shapes of two matrices. input_shapes = [v.shape for v in inputs] assert len(input_shapes) == 2, input_shapes - assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes + + # There are three cases: 1) gemm, 2) gemv, 3) dot + if all(len(shape) == 2 for shape in input_shapes): + # gemm + assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes + elif all(len(shape) == 1 for shape in input_shapes): + # dot + assert input_shapes[0][0] == input_shapes[1][0], input_shapes + + # expand shape + input_shapes[0] = torch.Size([1, input_shapes[0][0]]) + input_shapes[1] = torch.Size([input_shapes[1][0], 1]) + else: + # gemv + if len(input_shapes[0]) == 1: + assert input_shapes[0][0] == input_shapes[1][-2], input_shapes + input_shapes.reverse() + else: + assert input_shapes[1][0] == input_shapes[0][-1], input_shapes + + # expand the shape of the vector to [batch size, 1] + input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1]) flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1] return flops diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py index 1117c0103..b768e5900 100644 --- a/colossalai/_analyzer/fx/codegen.py +++ b/colossalai/_analyzer/fx/codegen.py @@ -1,8 +1,12 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple import torch + +try: + from torch.fx.graph import CodeGen +except: + pass from torch.fx.graph import ( - CodeGen, PythonCode, _custom_builtins, _format_target, @@ -48,8 +52,8 @@ def _end_of_ckpt(node: Node, ckpt_level: int) -> bool: """ Check if the node could end the ckpt region at `ckpt_level` """ - if len(node.meta['info'].to_recompute) > ckpt_level: - return node.meta['info'].to_recompute[ckpt_level] is not None + if len(node.meta['info'].activation_checkpoint) > ckpt_level: + return node.meta['info'].activation_checkpoint[ckpt_level] is not None return True @@ -90,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): current_region = None for idx, node in enumerate(node_list): - if len(node.meta['info'].to_recompute) > ckpt_level: - act_ckpt_label = node.meta['info'].to_recompute[ckpt_level] + if len(node.meta['info'].activation_checkpoint) > ckpt_level: + act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -152,12 +156,12 @@ def emit_ckpt_func(body, # label given by each layer, e.g. 
if you are currently at level (0, 1, 1) # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]]) + label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]]) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) ckpt_func.append(f'{ckpt_fn_def}\n') # if there is more level to fetch - if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)): + if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)): ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] @@ -215,7 +219,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ckpt_regions = _find_nested_ckpt_regions(nodes, 0) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] - node_list = list(nodes) node_idx = 0 diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py index 8c8956d8e..fbe8400a4 100644 --- a/colossalai/_analyzer/fx/node_util.py +++ b/colossalai/_analyzer/fx/node_util.py @@ -112,7 +112,7 @@ class MetaInfo: # should keep the same whenever manipulated # ============================= Invariant ================================== - to_recompute: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen + activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen to_offload: Optional[bool] = False sharding_spec: str = 'RR' diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py index b3859e250..23e83013e 100644 --- a/colossalai/_analyzer/fx/passes/shape_prop.py +++ b/colossalai/_analyzer/fx/passes/shape_prop.py @@ -237,7 +237,14 @@ class ShapeProp(torch.fx.Interpreter): Returns: Any: The value returned from executing the Module """ - wrap_fn = lambda elem: MetaTensor(elem, device=device) + + # wrap_fn = lambda elem: MetaTensor(elem, device=device) + def wrap_fn(elem, device=device): + if isinstance(elem, torch.Tensor): + return MetaTensor(elem, device=device) + else: + return elem + with self._mode: return super().run(*tree_map(wrap_fn, args)) diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py index 495678501..1e75b47ca 100644 --- a/colossalai/_analyzer/fx/tracer/bias_addition.py +++ b/colossalai/_analyzer/fx/tracer/bias_addition.py @@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None): @register_tracer_impl(F.conv1d, name='_bias_addition_impl') -def conv1d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1): if bias is None: - return F.conv1d(input, weight, **kwargs) + return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1)) + return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( + (-1, 1)) @register_tracer_impl(F.conv2d, name='_bias_addition_impl') -def conv2d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), 
groups=1): if bias is None: - return F.conv2d(input, weight, **kwargs) + return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1)) + return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( + (-1, 1, 1)) @register_tracer_impl(F.conv3d, name='_bias_addition_impl') -def conv3d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1): if bias is None: - return F.conv3d(input, weight, **kwargs) + return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1)) + return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( + (-1, 1, 1, 1)) @register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl') -def conv_transpose1d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv_transpose1d_impl(input, + weight, + bias=None, + stride=_single(1), + padding=_single(0), + output_padding=_single(0), + groups=1, + dilation=_single(1)): if bias is None: - return F.conv_transpose1d(input, weight, **kwargs) + return F.conv_transpose1d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1)) + return F.conv_transpose1d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + bias.reshape((-1, 1)) @register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl') -def conv_transpose2d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv_transpose2d_impl(input, + weight, + bias=None, + stride=_pair(1), + padding=_pair(0), + output_padding=_pair(0), + groups=1, + dilation=_pair(1)): if bias is None: - return F.conv_transpose2d(input, weight, **kwargs) + return F.conv_transpose2d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1)) + return F.conv_transpose2d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + bias.reshape((-1, 1, 1)) @register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl') -def conv_transpose3d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv_transpose3d_impl(input, + weight, + bias=None, + stride=_triple(1), + padding=_triple(0), + output_padding=_triple(0), + groups=1, + dilation=_triple(1)): if bias is None: - return F.conv_transpose3d(input, weight, **kwargs) + return F.conv_transpose3d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1)) + return F.conv_transpose3d(input, 
+ weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + bias.reshape((-1, 1, 1, 1)) @register_tracer_impl(torch.addmm, name='_bias_addition_impl') diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py index 1a247449f..6958a00a6 100644 --- a/colossalai/_analyzer/fx/tracer/tracer.py +++ b/colossalai/_analyzer/fx/tracer/tracer.py @@ -155,7 +155,7 @@ class ColoTracer(Tracer): def create_node(self, *args, **kwargs) -> Node: node = super().create_node(*args, **kwargs) - n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions)) + n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions)) return node def trace(self, diff --git a/colossalai/auto_parallel/meta_profiler/__init__.py b/colossalai/auto_parallel/meta_profiler/__init__.py index bfd361951..3741d8e5a 100644 --- a/colossalai/auto_parallel/meta_profiler/__init__.py +++ b/colossalai/auto_parallel/meta_profiler/__init__.py @@ -1,3 +1,3 @@ from .meta_registry import * -from .metainfo import * from .registry import meta_register +from .shard_metainfo import * diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index faeed9f29..0f2e9e44f 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -2,9 +2,9 @@ from typing import Callable, List, Tuple import torch +from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter +from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import elementwise_flop_counter from ..registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index 281a92c0d..e45174851 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -2,9 +2,9 @@ from typing import List, Tuple import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION from ..registry import meta_register @@ -17,7 +17,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train """Meta information generator for binary elementwise operations NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`, - they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify + they will be discarded right after add operation is done. 
We create a simple API in `ShardMetaInfo` class to identify this behavior, it is critical for better memory estimation. Returns: diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py index d1bb6e7fa..4336bf683 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( MemoryCost, OperationData, @@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( StrategiesVector, TrainCycleItem, ) -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from colossalai.tensor.sharding_spec import ShardingSpec from ..registry import meta_register @@ -110,18 +110,18 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate memory cost # TODO: use profiler to check conv temp memory # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost( - activation=activation_size([input_tensor, output_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0) - bwd_memory_cost = MemoryCost( - activation=activation_size([input_tensor, weight_tensor, bias_tensor]) - if has_bias else activation_size([input_tensor, weight_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor), - temp=0, - buffer=0) + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) + if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0) # total cost is the sum of forward and backward cost total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py index 2997f31ad..d5d80f5b3 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py @@ -2,9 +2,9 @@ from typing import List, Tuple import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register @@ -34,11 +34,11 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # NOTE: during the backward phase of 
torch.nn.Embedding, it seems when the input is large enough, it will # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0) - bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0) + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0) total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py index 617375721..7697fc6c3 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py @@ -3,6 +3,8 @@ from typing import Callable, Dict, List, Tuple, Union import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( MemoryCost, OperationData, @@ -11,8 +13,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( StrategiesVector, TrainCycleItem, ) -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from colossalai.tensor.sharding_spec import ShardingSpec from ..registry import meta_register @@ -112,14 +112,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), temp=0, buffer=0) # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), temp=0, buffer=0) @@ -148,14 +148,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), - parameter=activation_size(weight_tensor), + fwd_memory_cost = 
MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes(weight_tensor), temp=0, buffer=0) # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]), - parameter=activation_size(weight_tensor), + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes(weight_tensor), temp=0, buffer=0) @@ -210,48 +210,48 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # Check dimension if all(len(tensor.shape) == 1 for tensor in input_tensors): # Dot - fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors) + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2 - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1: # gemv case 1: matrix-vector multiplication # & # batched gemv case 1: batched matrix-vector multiplication - fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default]( + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors) # combine the dimensions of output bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors) + \ - flop_mapping[torch.ops.aten.mv.default]( + flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], output_tensors) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2: # gemv case 2: vector-matrix multiplication - fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors) + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \ - flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + 
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, - temp=activation_size(input_tensors[1]), + temp=compute_size_in_bytes(input_tensors[1]), buffer=0) elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3: # batched gemv case 2: vector-batched matrix multiplication - fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default]( + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]], [output_tensors[0].reshape(-1)]) @@ -260,15 +260,15 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors ) + \ - flop_mapping[torch.ops.aten.mv.default]( + flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)], output_tensors ) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]])) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]), + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]])) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), parameter=0, - temp=activation_size(input_tensors[1]), + temp=compute_size_in_bytes(input_tensors[1]), buffer=0) elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2: @@ -287,8 +287,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])] ) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3: # batched gemm case 2: matrix-batched matrix multiplication @@ -306,11 +306,12 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])] ) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]), - temp=activation_size(output_tensors)) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]), + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) + + compute_size_in_bytes(input_tensors[1]), + temp=compute_size_in_bytes(output_tensors)) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), parameter=0, - temp=activation_size(input_tensors[1]) + activation_size(output_tensors)) + temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors)) elif all(len(tensor.shape) >= 3 for tensor in input_tensors): # Batched matrix-batched matrix multiplication @@ -351,8 +352,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)] ) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors)) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors)) + fwd_mem_cost = 
MemoryCost(activation=compute_size_in_bytes(output_tensors)) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors)) else: # Case 2: batch dimensions are different @@ -381,10 +382,10 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ) fwd_mem_cost = MemoryCost( - activation=activation_size([output_tensors[0], extended_input_0, extended_input_1])) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) - - activation_size([extended_input_0, extended_input_1]), - temp=activation_size([extended_input_0, extended_input_1])) + activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) - + compute_size_in_bytes([extended_input_0, extended_input_1]), + temp=compute_size_in_bytes([extended_input_0, extended_input_1])) # compute cost compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py index 4634d3ccd..12874810b 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py @@ -4,8 +4,6 @@ from typing import List, Tuple import torch from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py index 3a1db396e..b872fdc8b 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py @@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( MemoryCost, OperationData, @@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( StrategiesVector, TrainCycleItem, ) -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from colossalai.tensor.sharding_spec import ShardingSpec from ..registry import meta_register @@ -77,17 +77,18 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt # calculate memory cost # the fwd activation cost is output plus saved mean and saved inv std # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( + [input_tensor, output_tensor, mean_tensor, var_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), temp=0, - buffer=activation_size([mean_tensor, var_tensor])) + buffer=compute_size_in_bytes([mean_tensor, var_tensor])) # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean # and saved inv std during backward phase - 
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), - temp=activation_size([mean_tensor, var_tensor]), - buffer=activation_size([mean_tensor, var_tensor])) + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([mean_tensor, var_tensor]), + buffer=compute_size_in_bytes([mean_tensor, var_tensor])) # total cost is the sum of forward and backward cost total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, @@ -131,15 +132,16 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # memory cost # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, weight_tensor, bias_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( + [input_tensor, output_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), temp=0, - buffer=activation_size([running_mean, running_var])) + buffer=compute_size_in_bytes([running_mean, running_var])) - bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), - temp=activation_size([running_mean, running_var]), - buffer=activation_size([running_mean, running_var])) + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([running_mean, running_var]), + buffer=compute_size_in_bytes([running_mean, running_var])) total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py index 21272ea09..d785dfcca 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -2,9 +2,9 @@ from typing import List, Tuple import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register @@ -52,8 +52,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # calculate memory cost - fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor)) - bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor)) + fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(output_tensor)) + bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(input_tensor)) # total cost total_mem_cost 
= MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation) @@ -114,11 +114,11 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, # calculate memory cost # NOTE: the index matrix will be discarded in backward phase # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix])) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix])) # temp memory for backward is the index matrix to be discarded - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix), - temp=activation_size(index_matrix)) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), + temp=compute_size_in_bytes(index_matrix)) # total cost total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py index 332e649d2..97fe3c619 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py @@ -2,9 +2,9 @@ from typing import Callable, List, Tuple import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register @@ -35,11 +35,11 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f # memory costs # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_mem_cost = MemoryCost(activation=activation_size(outputs) * 2, parameter=0, temp=0, buffer=0) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(outputs) * bwd_mem_out_factor, + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, parameter=0, - temp=activation_size(outputs) * bwd_mem_tmp_factor, + temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, buffer=0) total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py index c67eb40bc..5cba1b5b6 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py @@ -2,9 +2,9 @@ from typing import List, Tuple import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py 
b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py similarity index 94% rename from colossalai/auto_parallel/meta_profiler/metainfo.py rename to colossalai/auto_parallel/meta_profiler/shard_metainfo.py index 44b1882e0..0eee908b4 100644 --- a/colossalai/auto_parallel/meta_profiler/metainfo.py +++ b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py @@ -15,11 +15,11 @@ from colossalai.tensor.sharding_spec import ShardingSpec from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION from .registry import meta_register -__all__ = ['MetaInfo'] +__all__ = ['ShardMetaInfo'] -class MetaInfo: - """MetaInfo class +class ShardMetaInfo: + """ShardMetaInfo class This class is used to store meta info based on sharding strategy and the given target function. """ @@ -46,9 +46,9 @@ class MetaInfo: # target function self._target = target - # compute metainfo if possible + # compute shard_metainfo if possible if self._strategy is not None and self._target is not None: - self.compute_metainfo() + self.compute_shard_metainfo() @property def strategy(self) -> ShardingStrategy: @@ -62,13 +62,13 @@ class MetaInfo: def strategy(self, strategy: ShardingStrategy) -> None: self._strategy = strategy if self._strategy is not None and self._target is not None: - self.compute_metainfo() + self.compute_shard_metainfo() @target.setter def target(self, target: Callable) -> None: self._target = target if self._strategy is not None and self._target is not None: - self.compute_metainfo() + self.compute_shard_metainfo() def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec): """ @@ -93,7 +93,7 @@ class MetaInfo: return op_data - def compute_metainfo(self): + def compute_shard_metainfo(self): """ Compute meta info based on sharding strategy and the given target function. 
""" diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py index ab3acb056..ffda58e06 100644 --- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -4,7 +4,7 @@ import torch from torch.fx import GraphModule from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.meta_profiler import ShardMetaInfo from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.tensor.comm_spec import CommSpec @@ -14,15 +14,15 @@ from colossalai.tensor.sharding_spec import ShardingSpec shape_consistency_manager = ShapeConsistencyManager() -def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, - target_sharding_spec: ShardingSpec) -> MetaInfo: +def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, + target_sharding_spec: ShardingSpec) -> ShardMetaInfo: # get comm_action_sequence and total_cost from shape_consistency_manager _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( origin_sharding_spec, target_sharding_spec) - meta_info = MetaInfo() + meta_info = ShardMetaInfo() # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel - # get mem cost for MetaInfo + # get mem cost for ShardMetaInfo mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) # extract user that has _meta_data and extract element length input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) @@ -36,12 +36,12 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, meta_info.memory_cost = mem_cost - # get computation cost for MetaInfo + # get computation cost for ShardMetaInfo meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, total_cost['backward'] * element_length, total_cost['total'] * element_length) - # get tensor shape for MetaInfo + # get tensor shape for ShardMetaInfo origin_sharding_spec: ShardingSpec target_sharding_spec: ShardingSpec input_shape = origin_sharding_spec.get_sharded_shape_per_device() @@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, return meta_info -def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo: +def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> ShardMetaInfo: """ This method is used to construct `MetaInto` for shape consistency node """ @@ -65,17 +65,17 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) - origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ user_node_index] - return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) -def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo: +def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> ShardMetaInfo: # extract node_index and op_data_name node_index, op_data_name = node.args[2], node.args[3] comm_action = comm_actions_dict[node_index][op_data_name] if isinstance(comm_action.comm_spec, CommSpec): # this case is for all_reduce, there will be no memory cost - meta_info = 
MetaInfo() + meta_info = ShardMetaInfo() meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost) output_node = next(n for n in node.users if hasattr(n, '_meta_data')) element_length = output_node._meta_data.element_size() @@ -93,7 +93,7 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M # this case will be handled by shape consistency manager origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ 'tgt_spec'] - meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) return meta_info @@ -105,9 +105,9 @@ def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_di """ for node in gm.graph.nodes: if node.target == runtime_apply: - setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) + setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) elif node.target == runtime_comm_spec_apply: - setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) + setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) else: pass return gm diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index f7e07ef1e..bc0960483 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -7,7 +7,7 @@ import torch.fx from torch.fx import GraphModule from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.meta_profiler import ShardMetaInfo from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS from colossalai.fx._compatibility import compatibility from colossalai.fx.profiler import GraphInfo @@ -96,12 +96,12 @@ class MetaInfoProp: """ Handle other kind of nodes """ - assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}" + assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}" graph_info = GraphInfo() - meta_info = node.best_metainfo - meta_info: MetaInfo + meta_info = node.best_strategy_info + meta_info: ShardMetaInfo - # set data_ptr for input_tensor in MetaInfo class + # set data_ptr for input_tensor in ShardMetaInfo class input_tensors: List[torch.Tensor] = meta_info.fwd_in buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer output_tensors: List[torch.Tensor] = meta_info.fwd_out diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index 9d83f1057..a473bb6e9 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -4,7 +4,7 @@ from typing import Dict, List import torch from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai._analyzer.fx.node_util import MetaInfo from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommAction, CommType, @@ -128,9 +128,10 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): runtime_apply, args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index)) - if 'activation_checkpoint' in user_node.meta: - 
shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint'] - + if hasattr(user_node.meta['info'], 'activation_checkpoint'): + MetaInfo(shape_consistency_node, + mod_dir=user_node.meta['info'].mod_dir, + activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint)) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) # the origin node may be a positional argument or key word argument of user node @@ -210,9 +211,10 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): # substitute the origin node with comm_spec_apply_node new_kwargs[str(node)] = comm_spec_apply_node user.kwargs = new_kwargs - - if 'activation_checkpoint' in node.meta: - comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] + if hasattr(node.meta['info'], 'activation_checkpoint'): + MetaInfo(comm_spec_apply_node, + mod_dir=node.meta['info'].mod_dir, + activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 3be308422..e1d0c6272 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -6,6 +6,7 @@ import torch from torch.fx import symbolic_trace from torch.fx.node import Node +from colossalai._analyzer.fx.node_util import MetaInfo from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommAction, @@ -74,9 +75,9 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( str(node)) - # attach the corresponding metainfo if node has the attribute `metainfo_vector` - if hasattr(node, 'metainfo_vector'): - setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index]) + # attach the corresponding metainfo if node has the attribute `strategies_info` + if hasattr(node, 'strategies_info'): + setattr(node, 'best_strategy_info', node.strategies_info[strategy_index]) # the dict to get input sharding specs of user node sharding_spec_convert_dict = {} @@ -172,8 +173,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh # It will be used to replace the original node with processing node in slice object node_pairs[node] = size_processing_node size_processing_node._meta_data = node._meta_data - if 'activation_checkpoint' in node.meta: - size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] + + if hasattr(node.meta['info'], 'activation_checkpoint'): + MetaInfo(size_processing_node, + mod_dir=node.meta['info'].mod_dir, + activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) user_list = list(node.users.keys()) for user in user_list: diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index 60472eee5..b406ca6fb 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -6,6 +6,10 @@ import torch.nn as nn from torch.fx import GraphModule from torch.fx.graph import Graph +from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from 
colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference @@ -13,8 +17,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec @@ -126,6 +128,7 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc def transform_to_sharded_model(gm: ColoGraphModule, + meta_args: Dict, solution: List[int], device_mesh: DeviceMesh, strategies_constructor: StrategiesConstructor, @@ -142,6 +145,7 @@ def transform_to_sharded_model(gm: ColoGraphModule, strategies_constructor, overlap=overlap) gm = runtime_apply_pass(gm) + shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict) gm.recompile() sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict) @@ -243,10 +247,13 @@ def initialize_model(model: nn.Module, solution will be used to debug or help to analyze the sharding result. Therefore, we will not just return a series of integers, but return the best strategies. ''' - tracer = ColoTracer(trace_act_ckpt=True) + tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) + graph.set_codegen(ActivationCheckpointCodeGen()) gm = ColoGraphModule(model, graph, model.__class__.__name__) + + shape_prop_pass(gm, *meta_args.values()) gm.recompile() strategies_constructor = build_strategy_constructor(graph, @@ -261,7 +268,9 @@ def initialize_model(model: nn.Module, if save_solver_solution: torch.save(solution, solution_path) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap) + gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor, + overlap) + model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) if return_solution: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py index 57b623b01..cb1bb36b7 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py @@ -2,8 +2,6 @@ from typing import Dict, List import torch -from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo - from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector from .node_handler import MetaInfoModuleHandler, ModuleHandler from .registry import operator_registry diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 136e57c5e..ab391ebfa 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -4,7 +4,7 @@ from typing import Dict, List, 
Tuple, Union import torch from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register +from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, meta_register from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, @@ -258,7 +258,7 @@ class MetaInfoNodeHandler(NodeHandler): def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: """ This method is inherited from NodeHandler. It will register the strategies first, - and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class. + and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class. """ super().register_strategy(compute_resharding_cost=compute_resharding_cost) target = self.get_target_function() @@ -266,15 +266,15 @@ class MetaInfoNodeHandler(NodeHandler): # is not patched, we will use the default cost model to compute the cost. # TODO: patch all torch functions and modules to make it clean if meta_register.has(target.__class__) or meta_register.has(target): - metainfo_vector = [] + strategies_info = [] for strategy in self.strategies_vector: - metainfo = MetaInfo(strategy, target) + metainfo = ShardMetaInfo(strategy, target) strategy.compute_cost = metainfo.compute_cost strategy.memory_cost = metainfo.memory_cost - metainfo_vector.append(metainfo) + strategies_info.append(metainfo) # attach metainfos to the handler - setattr(self, "metainfo_vector", metainfo_vector) + setattr(self, "strategies_info", strategies_info) else: logger = get_dist_logger() @@ -313,7 +313,7 @@ class MetaInfoModuleHandler(ModuleHandler): def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: """ This method is inherited from NodeHandler. It will register the strategies first, - and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class. + and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class. """ super().register_strategy(compute_resharding_cost=compute_resharding_cost) target = self.get_target_function() @@ -321,15 +321,15 @@ class MetaInfoModuleHandler(ModuleHandler): # is not patched, we will use the default cost model to compute the cost. 
# TODO: patch all torch functions and modules to make it clean if meta_register.has(target.__class__) or meta_register.has(target): - metainfo_vector = [] + strategies_info = [] for strategy in self.strategies_vector: - metainfo = MetaInfo(strategy, target) + metainfo = ShardMetaInfo(strategy, target) strategy.compute_cost = metainfo.compute_cost strategy.memory_cost = metainfo.memory_cost - metainfo_vector.append(metainfo) + strategies_info.append(metainfo) # attach metainfos to the handler - setattr(self, "metainfo_vector", metainfo_vector) + setattr(self, "strategies_info", strategies_info) else: logger = get_dist_logger() diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 59ead1ca8..044a8ac84 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -137,9 +137,9 @@ class StrategiesConstructor: shard_option=self.solver_options.shard_option, solver_perference=self.solver_options.solver_perference) handler.register_strategy() - # attach metainfo_vector to node - if hasattr(handler, 'metainfo_vector'): - setattr(node, 'metainfo_vector', handler.metainfo_vector) + # attach strategies_info to node + if hasattr(handler, 'strategies_info'): + setattr(node, 'strategies_info', handler.strategies_info) # call_function node elif node.op == 'call_function': @@ -150,9 +150,9 @@ class StrategiesConstructor: shard_option=self.solver_options.shard_option, solver_perference=self.solver_options.solver_perference) handler.register_strategy() - # attach metainfo_vector to node - if hasattr(handler, 'metainfo_vector'): - setattr(node, 'metainfo_vector', handler.metainfo_vector) + # attach strategies_info to node + if hasattr(handler, 'strategies_info'): + setattr(node, 'strategies_info', handler.strategies_info) # call_method node elif node.op == 'call_method': @@ -163,9 +163,9 @@ class StrategiesConstructor: shard_option=self.solver_options.shard_option, solver_perference=self.solver_options.solver_perference) handler.register_strategy() - # attach metainfo_vector to node - if hasattr(handler, 'metainfo_vector'): - setattr(node, 'metainfo_vector', handler.metainfo_vector) + # attach strategies_info to node + if hasattr(handler, 'strategies_info'): + setattr(node, 'strategies_info', handler.strategies_info) # output node elif node.op == 'output': diff --git a/tests/test_analyzer/__init__.py b/tests/test_analyzer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py index 349483008..7d4fd844a 100644 --- a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -1,10 +1,12 @@ +import pytest import torch import torch.nn.functional as F +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec @@ -33,6 +35,7 @@ def 
recover_narrow(gm, narrow_node): return gm +@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') def test_size_value_converting_pass(): model = TestModule() physical_mesh_id = torch.arange(0, 4) @@ -40,14 +43,14 @@ def test_size_value_converting_pass(): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) meta_args = {'x': torch.rand(4, 8).to('meta')} input = torch.rand(4, 8) - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) - x_node = list(graph.nodes)[0] x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) setattr(x_node, 'sharding_spec', x_sharding_spec) gm = ColoGraphModule(model, graph) gm = insert_narrow(gm, x_node) + shape_prop_pass(gm, *meta_args.values()) gm.recompile() size = gm(input) assert size == torch.Size([2, 8]) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py index f43885a6a..6d1b28912 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -4,7 +4,12 @@ import pytest import torch import torch.multiprocessing as mp -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -77,6 +82,7 @@ def check_conv_module(rank, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') @pytest.mark.dist @rerun_if_address_is_in_use() def test_bias_addition_module(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py index 0b42722fe..7a4c8d32e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -8,13 +8,15 @@ import torch.nn as nn from torch.utils.checkpoint import checkpoint from transformers.pytorch_utils import Conv1D -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.tracer import ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.utils import free_port @@ -43,6 +45,7 @@ def check_act_ckpt(rank, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE) + input = torch.rand(1, 64, HIDDEN_SIZE) input_sample = { 'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'), } @@ -54,10 +57,11 @@ def check_act_ckpt(rank, world_size, port): gm = 
initialize_model(model, input_sample, device_mesh) code = gm.module.graph.python_code('self').src assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code - assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code + assert "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" in code @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') @pytest.mark.dist @rerun_if_address_is_in_use() def test_mlp_layer(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py index e4982a5d7..7c3277c69 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -6,7 +6,12 @@ import torch import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -93,6 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') @pytest.mark.dist @rerun_if_address_is_in_use() def test_compatibility_with_ddp(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index 9879ae461..e4435a049 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -6,7 +6,12 @@ import torch import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -101,6 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') @pytest.mark.dist @rerun_if_address_is_in_use() def test_auto_parallel_with_gemini(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py index 90301521f..e7fccad36 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py @@ -5,8 +5,11 @@ import torch.nn as nn from torch.fx import GraphModule from transformers.pytorch_utils import Conv1D +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass 
+# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.testing import parameterize from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -83,11 +86,12 @@ def test_repeat_blocks(model_cls): model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS) - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')} graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) gm.recompile() node_list = list(graph.nodes) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index ebeef9870..8688890ef 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -10,15 +10,23 @@ import torch.multiprocessing as mp import transformers from torch.fx import GraphModule -from colossalai.auto_parallel.tensor_shard.initialize import ( - ModuleWrapper, - build_strategy_constructor, - solve_solution, - transform_to_sharded_model, -) +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer + +try: + from colossalai.auto_parallel.tensor_shard.initialize import ( + ModuleWrapper, + build_strategy_constructor, + solve_solution, + transform_to_sharded_model, + ) + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.shape_consistency import to_global @@ -52,9 +60,8 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor param_sharding_spec = best_sharding_spec_dict[new_name] grad_to_compare = copy.deepcopy(param_grad) param_grad_global = to_global(grad_to_compare, param_sharding_spec) - try: - assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-03) + assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-05) except: difference = param_grad_global - origin_param_grad avg_diff = difference.abs().sum() / difference.numel() @@ -66,7 +73,7 @@ def check_attention_layer(rank, model_cls, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - config = transformers.GPT2Config(n_position=64, n_layer=1, n_head=16, n_embd=HIDDEN_DIM) + config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda') @@ -111,15 +118,17 @@ def check_attention_layer(rank, model_cls, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - tracer = ColoTracer() + tracer = 
ColoTracer(bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_input_sample) gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *meta_input_sample.values()) gm.recompile() strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard') solution = solve_solution(gm, strategies_constructor, memory_budget=-1) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor) + gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh, + strategies_constructor) gm = ModuleWrapper(gm, *sharding_spec_dicts) nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -176,6 +185,7 @@ def check_attention_layer(rank, model_cls, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason="no codegen module") @pytest.mark.dist @parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) @rerun_if_address_is_in_use() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py index 4adb4fbaf..5f0688d5f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -3,11 +3,12 @@ import torch.nn as nn import transformers from torch.fx import GraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing import parameterize from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -21,7 +22,7 @@ HIDDEN_DIM = 384 @run_on_environment_flag(name='AUTO_PARALLEL') @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) def test_self_attention_block(model_cls): - config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM) + config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: model = model_cls(intermediate_size=4 * config.hidden_size, config=config) else: @@ -33,7 +34,7 @@ def test_self_attention_block(model_cls): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) shape_consistency_manager = ShapeConsistencyManager() - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) if model_cls == GPT2MLP: input_sample = { 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), @@ -52,6 +53,7 @@ def test_self_attention_block(model_cls): graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) print(gm.graph) gm.recompile() solver_options = SolverOptions() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py 
index f5de7bf70..8d4212438 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -1,8 +1,11 @@ +import pytest import torch import torch.nn as nn +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser -from colossalai.fx import ColoGraphModule, ColoTracer class LinearModel(nn.Module): @@ -22,15 +25,14 @@ class LinearModel(nn.Module): return out +@pytest.mark.skip('meta tensor has some bugs in 1.11') def test_liveness_analysis(): model = LinearModel() - tracer = ColoTracer() - graph = tracer.trace(model, - meta_args={ - 'x1': torch.rand(4, 4, device='meta'), - 'x2': torch.rand(4, 4, device='meta') - }) + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'x1': torch.rand(4, 4, device='meta'), 'x2': torch.rand(4, 4, device='meta')} + graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) + shape_prop_pass(gm, *meta_args.values()) graph_analyser = GraphAnalyser(gm) liveness_list = graph_analyser.liveness_analysis() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py index 2fb130654..5f3d2df50 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py @@ -24,7 +24,7 @@ from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index e9c0601eb..ddc8e3c6a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -17,7 +17,7 @@ from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register class MyModule(nn.Module): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py index fd29c63fb..1242b9db0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py @@ -24,7 +24,7 @@ from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler 
import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py index 9d3ab9c82..d3342d310 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py @@ -23,7 +23,7 @@ from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register def _batchnorm_module_mem_test(rank, world_size, port): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py index a0ab66fdc..a544e9a3c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py @@ -24,7 +24,7 @@ from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register class SplitModule(nn.Module): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py index 20156f9ab..2ae13ea2b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py @@ -22,7 +22,7 @@ from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py index 60ecd1dd9..4ca85d34d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -5,16 +5,19 @@ from typing import Dict, List import torch from torch.fx import GraphModule +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.tensor_shard.options import SolverOptions from 
colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo def mem_test_for_node_strategy(rank: int, @@ -30,14 +33,16 @@ def mem_test_for_node_strategy(rank: int, model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy( input_kwargs) - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) input_sample = {} for input_arg, meta_arg_name in zip(input_args, meta_arg_names): input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') for meta_kwarg_name, input_kwarg in input_kwargs.items(): input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') graph = tracer.trace(root=model_to_shard, meta_args=input_sample) - gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) + gm.recompile() solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() @@ -108,10 +113,10 @@ def mem_test_for_node_strategy(rank: int, # estimated memory if target_node.op == "call_module": - metainfo = MetaInfo(target_node.strategies_vector[strategy_index], - target_node.graph.owning_module.get_submodule(target_node.target)) + metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], + target_node.graph.owning_module.get_submodule(target_node.target)) else: - metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target) + metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], target_node.target) print("estimated memory:") print( diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py deleted file mode 100644 index 92f011ba3..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py +++ /dev/null @@ -1,126 +0,0 @@ -import torch - -from colossalai.auto_parallel.tensor_shard.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.testing.pytest_wrapper import run_on_environment_flag - - -def _param_resharding_cost_assertion(node): - for strategy in node.strategies_vector: - for prev_node, resharding_cost in strategy.resharding_costs.items(): - if strategy.get_op_data_by_name(str(prev_node)).type == OperationDataType.PARAM: - for cost in resharding_cost: - assert cost.fwd == 0 - assert cost.bwd == 0 - assert cost.total == 0 - - -class LinearModel(torch.nn.Module): - - def __init__(self, in_features, out_features): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features) - - def forward(self, x): - x = self.linear(x) - x = x * 2 - - return x - - -class ConvModel(torch.nn.Module): - - def __init__(self, 
in_channels, out_channels, kernel_size, bias=True): - super().__init__() - self.conv = torch.nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - bias=bias) - - def forward(self, x): - x = self.conv(x) - x = x * 2 - - return x - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_linear_module(): - model = LinearModel(4, 8) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - tracer = ColoTracer() - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %linear_weight : [#users=1] = get_attr[target=linear.weight] - # %linear_bias : [#users=1] = get_attr[target=linear.bias] - # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {}) - # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) - # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) - # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')}) - # def forward(self, x : torch.Tensor): - # linear_weight = self.linear.weight - # linear_bias = self.linear.bias - # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None - # add = linear + linear_bias; linear = linear_bias = None - # mul = add * 2; add = None - # return mul - gm = ColoGraphModule(model, graph) - gm.recompile() - node_list = list(graph.nodes) - - solver_options = SolverOptions() - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - strategies_constructor.build_strategies_and_cost() - linear_node = node_list[3] - _param_resharding_cost_assertion(linear_node) - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_conv_module(): - model = ConvModel(3, 6, 2) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - tracer = ColoTracer() - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %conv_weight : [#users=1] = get_attr[target=conv.weight] - # %conv_bias : [#users=1] = get_attr[target=conv.bias] - # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) - # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) - # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) - # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) - # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) - # def forward(self, x : torch.Tensor): - # conv_weight = self.conv.weight - # conv_bias = self.conv.bias - # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None - # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None - # add = conv2d + view; conv2d = view = None - # mul = add * 2; add = None - # return mul - gm = ColoGraphModule(model, graph) - - gm.recompile() - node_list = list(graph.nodes) - conv_node = node_list[3] - solver_options = SolverOptions() - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - strategies_constructor.build_strategies_and_cost() - _param_resharding_cost_assertion(conv_node) - - -if __name__ == '__main__': - test_linear_module() - test_conv_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py 
b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py deleted file mode 100644 index 24a3ae5b4..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py +++ /dev/null @@ -1,86 +0,0 @@ -import copy -from functools import partial - -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn - -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model -from colossalai.device.device_mesh import DeviceMesh -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port - - -class ConvModel(nn.Module): - - def __init__(self, c_in, c_out): - super().__init__() - self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False) - - def forward(self, x): - x = self.conv(x) - x = torch.flatten(x) - return x - - -def check_apply(rank, world_size, port): - disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - input = torch.rand(4, 4, 4, 4).cuda() - test_input = copy.deepcopy(input) - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) - # return conv - model = ConvModel(4, 4).cuda() - test_model = copy.deepcopy(model) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4, 4, 4).to('meta')} - gm = initialize_model(model, meta_args, device_mesh) - - output = gm(input) - origin_output = test_model(test_input) - assert output.equal(origin_output) - origin_loss = origin_output.sum() - loss = output.sum() - - origin_loss.backward() - loss.backward() - - grad_0 = test_model.conv.weight.grad.narrow(0, 0, 1) - grad_1 = test_model.conv.weight.grad.narrow(0, 1, 1) - grad_2 = test_model.conv.weight.grad.narrow(0, 2, 1) - grad_3 = test_model.conv.weight.grad.narrow(0, 3, 1) - - if rank == 0: - assert_close(gm.module.conv.weight.grad.data, grad_0.data) - elif rank == 1: - assert_close(gm.module.conv.weight.grad.data, grad_1.data) - elif rank == 2: - assert_close(gm.module.conv.weight.grad.data, grad_2.data) - elif rank == 3: - assert_close(gm.module.conv.weight.grad.data, grad_3.data) - else: - raise ValueError(f'rank {rank} does not exist.') - - -# skip this test due to pulp not installed in CI environment -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_apply(): - world_size = 4 - run_func = partial(check_apply, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_apply() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index bbfc3e1fc..fb47baab9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -2,11 +2,13 @@ import torch from torch.fx import GraphModule from torchvision.models import resnet50 +from colossalai._analyzer.fx.passes import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from 
colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -20,7 +22,7 @@ def test_cost_graph(): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) shape_consistency_manager = ShapeConsistencyManager() - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) model = resnet50(num_classes=100000) input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')} @@ -50,6 +52,7 @@ def test_cost_graph(): # %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {}) # return fc gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) gm.recompile() solver_options = SolverOptions()
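Reviewer note: the test migrations in this patch all apply the same setup sequence, namely trace with the analyzer tracer using bias_addition_split=True, wrap the graph in a (Colo)GraphModule, run shape_prop_pass over the meta inputs, recompile, and only then build the sharding strategies. The following is a minimal sketch of that sequence, assembled from the calls that appear in the updated tests; the TinyModel class, tensor shapes, and mesh size are illustrative placeholders and are not taken from any file in this diff.

import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh


class TinyModel(nn.Module):
    # illustrative toy model, not part of the patched tests
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x) * 2


model = TinyModel()

# trace with bias-addition splitting so bias adds become explicit graph nodes
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x': torch.rand(4, 8, device='meta')}
graph = tracer.trace(root=model, meta_args=meta_args)

# build the graph module and propagate meta shapes before recompiling,
# mirroring the shape_prop_pass(gm, *meta_args.values()) calls added above
gm = ColoGraphModule(model, graph, model.__class__.__name__)
shape_prop_pass(gm, *meta_args.values())
gm.recompile()

# a 2x2 device mesh, as used throughout the updated tests
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# strategy construction; handlers whose target is registered in meta_register
# now attach their per-strategy ShardMetaInfo list to the node as `strategies_info`
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()

This is only a sketch under the assumptions stated above; the real tests additionally launch a process group, select a solution with the solver, and transform the module before execution.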