From 5458da5c3c5454fc22c389eb462d9d10423bc73b Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 5 Jul 2023 12:07:07 +0800
Subject: [PATCH] [autoparallel] integrate 2-stage solver (#3476)

* integrate 2-stage solver

* [autoparallel] integrate 2-stage solver

* polish
---
 colossalai/_analyzer/fx/codegen.py            |   2 +
 .../checkpoint/ckpt_solver_base.py            |   2 +-
 .../checkpoint/ckpt_solver_rotor.py           |  16 +-
 .../auto_parallel/passes/meta_info_prop.py    |  34 +++-
 .../tensor_shard/node_handler/node_handler.py |  13 ++
 .../solver/strategies_constructor.py          |   1 -
 .../test_2stage_solver_on_gpt.py              | 174 ++++++++++++++++++
 .../test_2stage_solver_on_resnet.py           | 138 ++++++++++++++
 .../test_bias_linear_function_node.py         |   2 +-
 .../test_bias_linear_module_node.py           |   2 +-
 .../test_node_handler/test_getitem_handler.py |   2 +-
 11 files changed, 367 insertions(+), 19 deletions(-)
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_2stage_solver_on_gpt.py
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_2stage_solver_on_resnet.py

diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py
index 41d74f2e3..b49c501be 100644
--- a/colossalai/_analyzer/fx/codegen.py
+++ b/colossalai/_analyzer/fx/codegen.py
@@ -94,6 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
     current_region = None
 
     for idx, node in enumerate(node_list):
+        if 'info' not in node.meta:
+            continue
         if len(node.meta['info'].activation_checkpoint) > ckpt_level:
             act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
 
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
index b388d00ac..1811b9cc0 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
@@ -5,12 +5,12 @@ from typing import Any, List
 import torch
 from torch.fx import Graph, Node
 
+from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
 from colossalai.auto_parallel.passes.runtime_apply_pass import (
     runtime_apply,
     runtime_apply_for_iterable_object,
     runtime_comm_spec_apply,
 )
-from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
 
 __all___ = ['CheckpointSolverBase']
 
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
index 21c3bf0da..f71e8d5d1 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -369,7 +369,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
                 in_ckpt = False
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'] = [ckpt_idx]
+                        n.meta['info'].activation_checkpoint = [ckpt_idx]
 
                 ckpt_idx += 1
                 ckpt_region = []
@@ -377,7 +377,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             elif isinstance(op, ForwardCheck):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'] = [ckpt_idx]
+                        n.meta['info'].activation_checkpoint = [ckpt_idx]
 
                 ckpt_idx += 1
                 ckpt_region = [idx]
@@ -397,7 +397,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             elif isinstance(op, ForwardEnable):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'].append(ckpt_idx)
+                        n.meta['info'].activation_checkpoint.append(ckpt_idx)
 
                 ckpt_idx += 1
                 ckpt_region = []
@@ -405,7 +405,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             elif isinstance(op, ForwardCheck):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'].append(ckpt_idx)
+                        n.meta['info'].activation_checkpoint.append(ckpt_idx)
 
                 ckpt_idx += 1
                 ckpt_region = [op.index]
@@ -413,7 +413,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             elif isinstance(op, Backward):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        n.meta['activation_checkpoint'].append(ckpt_idx)
+                        n.meta['info'].activation_checkpoint.append(ckpt_idx)
 
                 in_recompute = False
 
@@ -433,7 +433,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         ckpt_regions = _find_nested_ckpt_regions(op_list)
         for (start_idx, end_idx) in ckpt_regions:
             nested_length = max(
-                len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
+                len(op_list[idx].meta['info'].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
             for idx in range(start_idx, end_idx + 1):
-                op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
-                                                                        len(op_list[idx].meta['activation_checkpoint']))
+                op_list[idx].meta['info'].activation_checkpoint += [None] * (
+                    nested_length - len(op_list[idx].meta['info'].activation_checkpoint))
diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py
index 0673b767d..3a0bd83d7 100644
--- a/colossalai/auto_parallel/passes/meta_info_prop.py
+++ b/colossalai/auto_parallel/passes/meta_info_prop.py
@@ -68,7 +68,7 @@ class MetaInfoProp:
         graph_info = GraphInfo()
         out = _normalize_tuple(getattr(node, '_meta_data', None))
         graph_info.fwd_out = list(out) if out[0] is not None else []
-        node.meta = {**asdict(graph_info)}
+        node.meta.update(asdict(graph_info))
 
     @compatibility(is_backward_compatible=False)
     def get_attr_handler(self, node: Node) -> None:
@@ -76,7 +76,7 @@ class MetaInfoProp:
         """
         Handle the get_attr node.
""" graph_info = GraphInfo() - node.meta = {**asdict(graph_info)} + node.meta.update(asdict(graph_info)) @compatibility(is_backward_compatible=False) def output_handler(self, node: Node) -> None: @@ -89,15 +89,36 @@ class MetaInfoProp: if par.meta: output_tensors += par.meta["fwd_out"] graph_info.fwd_in = output_tensors - node.meta = {**asdict(graph_info)} + node.meta.update(asdict(graph_info)) @compatibility(is_backward_compatible=False) def node_handler(self, node: Node) -> None: """ Handle other kind of nodes """ - assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}" graph_info = GraphInfo() + if not hasattr(node, 'best_strategy_info'): + # attach them to graph_info + graph_info.fwd_in = [] + graph_info.fwd_tmp = [] + graph_info.fwd_out = [] + + # fetch other memory informations + graph_info.fwd_mem_tmp = 10 + graph_info.fwd_mem_out = 10 + graph_info.bwd_mem_tmp = 10 + graph_info.bwd_mem_out = 10 + + # fetch flop information + # here we use fwd_time and bwd_time to deal with the case that + # communication cost is a float + graph_info.fwd_time = 10 + graph_info.bwd_time = 10 + node.meta.update(asdict(graph_info)) + # print(node.name, [isinstance(arg, torch.Tensor) for arg in node.args], isinstance(node._meta_data, torch.Tensor)) + return + assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}" + meta_info = node.best_strategy_info meta_info: ShardMetaInfo @@ -124,6 +145,8 @@ class MetaInfoProp: for par in node._input_nodes: # set data_ptr for the input_tensor of current node from the output_tensor of its parent node for tensor in par.meta.get("fwd_out", []): + if not isinstance(tensor, torch.Tensor): + continue tensor: torch.Tensor target_input_tensor = next( (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None) @@ -161,5 +184,4 @@ class MetaInfoProp: compute_cost = meta_info.compute_cost graph_info.fwd_time = compute_cost.fwd graph_info.bwd_time = compute_cost.bwd - - node.meta = {**asdict(graph_info)} + node.meta.update(asdict(graph_info)) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index b4b7b0e79..61563a296 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -157,9 +157,22 @@ class NodeHandler(ABC): Register different sharding strategies for the current node. 
""" strategy_generators = self.get_strategy_generator() + strategies_info = [] for generator in strategy_generators: strategies = generator.generate() + for strategy in strategies: + shard_metainfo = ShardMetaInfo() + shard_metainfo.compute_cost = strategy.compute_cost + shard_metainfo.memory_cost = strategy.memory_cost + shard_metainfo.fwd_in = [] + if isinstance(self.node._meta_data, torch.Tensor): + shard_metainfo.fwd_out = [self.node._meta_data] + else: + shard_metainfo.fwd_out = self.node._meta_data + shard_metainfo.fwd_buffer = [] + strategies_info.append(shard_metainfo) + # postprocess a strategy # postprocess can produce one strategy or multiple strategies post_processed_strategies_map = map(self.post_process, strategies) diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 044a8ac84..5f4d8de72 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -102,7 +102,6 @@ class StrategiesConstructor: if _check_no_strategy_for_node(node): self.no_strategy_nodes.append(node) - pass # placeholder node elif node.op == 'placeholder': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_2stage_solver_on_gpt.py b/tests/test_auto_parallel/test_tensor_shard/test_2stage_solver_on_gpt.py new file mode 100644 index 000000000..566364efc --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_2stage_solver_on_gpt.py @@ -0,0 +1,174 @@ +from time import time + +import psutil +import pytest +import torch +import transformers + +try: + from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen + NON_CODEGEN = False +except: + NON_CODEGEN = True + +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer +from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor +from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass +from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp +from colossalai.auto_parallel.tensor_shard.initialize import ( + ModuleWrapper, + build_strategy_constructor, + solve_solution, + transform_to_sharded_model, +) +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch, launch_from_torch +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model + +BATCH_SIZE = 16 +SEQ_LENGTH = 1024 +HIDDEN_DIM = 2048 +NUM_HEADS = 16 +NUM_LAYERS = 2 +VOCAB_SIZE = 50257 +NUM_STEPS = 10 + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def data_gen_mlp(model_cls): + """ + Generate random data for resnet benchmarking + """ + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32) + + if model_cls == GPT2MLP: + input_sample = 
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+        }
+    elif model_cls in (GPT2Attention, GPT2Block):
+        attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+        input_sample = (
+            hidden_states.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }
+    else:
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        input_sample = (
+            input_ids.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        meta_input_sample = {
+            'input_ids': input_ids.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }
+    return input_sample, meta_input_sample
+
+
+def check_2stage_solver_on_gpt(rank, world_size, port, model_cls):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
+    logger = get_dist_logger()
+    config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
+    if model_cls == GPT2MLP:
+        model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
+    else:
+        model = model_cls(config=config).to('cuda')
+
+    input_sample, meta_input_sample = data_gen_mlp(model_cls)
+
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    tracer = ColoTracer(bias_addition_split=True)
+
+    graph = tracer.trace(root=model, meta_args=meta_input_sample)
+    graph.set_codegen(ActivationCheckpointCodeGen())
+    gm = ColoGraphModule(model, graph, model.__class__.__name__)
+    shape_prop_pass(gm, *meta_input_sample.values())
+    gm.recompile()
+
+    strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
+    solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
+                                                         strategies_constructor)
+    comm_metainfo_pass(gm, *sharding_spec_dicts)
+
+    MetaInfoProp(gm).run()
+
+    gm = ModuleWrapper(gm, *sharding_spec_dicts)
+    ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
+    gm.module.graph = ckpt_solver.solve()
+    ckpt_solver.print_sequence()
+    gm.module.recompile()
+    print(gm.module)
+    logger.info("*******************strategy selected*******************", ranks=[0])
+    strategies_list = solution
+
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    for index, node in enumerate(nodes):
+        logger.info(node.name, ranks=[0])
+        logger.info(node.strategies_vector[strategies_list[index]].name, ranks=[0])
+
+    optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
+    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
+    torch.cuda.synchronize()
+    gm.train()
+
+    for n in range(NUM_STEPS):
+        # we just use randomly generated data here
+        input_sample, _ = data_gen_mlp(model_cls)
+        mem_stamp0 = torch.cuda.memory_allocated(device='cuda:0') / 1024**2
+        optimizer.zero_grad()
+        start = time()
+        loss = gm(*input_sample)
+        loss.sum().backward()
+        optimizer.step()
+        # prof.step()
+        torch.cuda.synchronize()
+        step_time = time() - start
+        logger.info(f"===============Round {n}===============", ranks=[0])
+        logger.info(
+            f"Peak Memory: {torch.cuda.max_memory_allocated(device='cuda:0') / 1024**2 - mem_stamp0} MB, Step Time: {step_time:.3f}s",
+            ranks=[0])
+        torch.cuda.synchronize()
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.skipif(NON_CODEGEN, reason='codegen is not available')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@parameterize('model_cls', [GPT2MLP, GPT2Attention, GPT2Block, GPT2Model])
+def test_2stage_solver_on_gpt(model_cls):
+    spawn(
+        check_2stage_solver_on_gpt,
+        4,
+        model_cls=model_cls,
+    )
+
+
+if __name__ == '__main__':
+    test_2stage_solver_on_gpt()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_2stage_solver_on_resnet.py b/tests/test_auto_parallel/test_tensor_shard/test_2stage_solver_on_resnet.py
new file mode 100644
index 000000000..11e70f831
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_2stage_solver_on_resnet.py
@@ -0,0 +1,138 @@
+from time import time
+
+import psutil
+import pytest
+import torch
+import torchvision.models as tm
+
+try:
+    from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
+    NON_CODEGEN = False
+except ImportError:
+    NON_CODEGEN = True
+
+from colossalai._analyzer.fx.graph_module import ColoGraphModule
+from colossalai._analyzer.fx.passes import shape_prop_pass
+from colossalai._analyzer.fx.tracer.tracer import ColoTracer
+from colossalai.auto_parallel.checkpoint.ckpt_solver_rotor import CheckpointSolverRotor
+from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass
+from colossalai.auto_parallel.passes.meta_info_prop import MetaInfoProp
+from colossalai.auto_parallel.tensor_shard.initialize import (
+    ModuleWrapper,
+    build_strategy_constructor,
+    solve_solution,
+    transform_to_sharded_model,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch, launch_from_torch
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+
+BATCH_SIZE = 256
+
+
+def get_cpu_mem():
+    return psutil.Process().memory_info().rss / 1024**2
+
+
+def get_gpu_mem():
+    return torch.cuda.memory_allocated() / 1024**2
+
+
+def get_mem_info(prefix=''):
+    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
+
+
+def data_gen_resnet(batch_size, shape):
+    """
+    Generate random data for resnet benchmarking
+    """
+    data = torch.empty(batch_size, *shape, device=torch.cuda.current_device())
+    label = torch.empty(batch_size, dtype=torch.long, device=torch.cuda.current_device()).random_(1000)
+    return data, label
+
+
+def check_2stage_solver_on_resnet(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    logger = get_dist_logger()
+    model = tm.resnet50().cuda()
+
+    meta_input_sample = {
+        'x': torch.randn(BATCH_SIZE * 4, 3, 224, 224, device='meta'),
+    }
+
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    tracer = ColoTracer(bias_addition_split=True, trace_act_ckpt=True)
+
+    graph = tracer.trace(root=model, meta_args=meta_input_sample)
+    graph.set_codegen(ActivationCheckpointCodeGen())
+    gm = ColoGraphModule(model, graph, model.__class__.__name__)
+    shape_prop_pass(gm, *meta_input_sample.values())
+    gm.recompile()
+
+    strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
+    solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh,
+                                                         strategies_constructor)
+    comm_metainfo_pass(gm, *sharding_spec_dicts)
+    MetaInfoProp(gm).run()
+    gm = ModuleWrapper(gm, *sharding_spec_dicts)
+
+    ckpt_solver = CheckpointSolverRotor(gm.module.graph, 8 * 1024**3)
+    gm.module.graph = ckpt_solver.solve()
+    ckpt_solver.print_sequence()
+
+    logger.info("*******************strategy selected*******************", ranks=[0])
+    strategies_list = solution
+
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    for index, node in enumerate(nodes):
+        logger.info(node.name, ranks=[0])
+        logger.info(node.strategies_vector[strategies_list[index]].name, ranks=[0])
+
+    # build criterion
+    criterion = torch.nn.CrossEntropyLoss().cuda()
+
+    optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
+    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
+    torch.cuda.synchronize()
+    model.train()
+    for n in range(10):
+        # we just use randomly generated data here
+        data, label = data_gen_resnet(BATCH_SIZE * 4, (3, 224, 224))
+        mem_stamp0 = torch.cuda.memory_allocated(device='cuda:0') / 1024**2
+        optimizer.zero_grad()
+        start = time()
+        outputs = gm(data)
+        loss = criterion(outputs, label)
+        loss.backward()
+        optimizer.step()
+        torch.cuda.synchronize()
+        step_time = time() - start
+        logger.info(f"===============Round {n}===============", ranks=[0])
+        logger.info(
+            f"Peak Memory: {torch.cuda.max_memory_allocated(device='cuda:0') / 1024**2 - mem_stamp0} MB, Step Time: {step_time:.3f}s",
+            ranks=[0])
+        torch.cuda.synchronize()
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.skipif(NON_CODEGEN, reason='codegen is not available')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_2stage_solver_on_resnet():
+    spawn(
+        check_2stage_solver_on_resnet,
+        4,
+    )
+
+
+if __name__ == '__main__':
+    test_2stage_solver_on_resnet()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py
index 800bc11a5..2ba897aef 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py
@@ -162,7 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler():
-    spawn(check_linear_module_handler)
+    spawn(check_linear_module_handler, 4)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
index c29a065d1..402c98335 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
@@ -151,7 +151,7 @@ def check_linear_module_handler(rank, world_size, port, bias):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler(bias=True):
-    spawn(check_linear_module_handler, bias=bias)
+    spawn(check_linear_module_handler, 4, bias=bias)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
index a2e0968b1..314cea861 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
@@ -100,7 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
 # @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
 @parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
 def test_getitem_from_tensor_handler(getitem_index):
-    spawn(check_getitem_from_tensor_handler, 4)
+    spawn(check_getitem_from_tensor_handler, 4, getitem_index=getitem_index)
 
 
 class GetItemFromTupleModel(nn.Module):