[autoparallel] gpt2lp runtime test (#2113)
This commit is contained in:
parent 9214d1fe28
commit cd0af9f7f6
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationDataType,
     ShardingStrategy,
 )
+from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -19,13 +20,23 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 shape_consistency_manager = ShapeConsistencyManager()


-def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
+def _solution_annotatation(gm: torch.fx.GraphModule,
+                           solution: List[int],
+                           strategies_constructor: StrategiesConstructor = None):
     """
     This method is used to stick the solution strategy to the nodes and add the information
     required in runtime into graph as placeholder nodes.
     """
     mod_graph = gm.graph
-    nodes = tuple(mod_graph.nodes)
+    # TODO: In future PR, strategies_constructor should be a required argument,
+    # instead of optional argument. This is because we don't need to consider nodes with
+    # no strategy in runtime preparation pass.
+    if strategies_constructor is not None:
+        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+        no_strategy_nodes = strategies_constructor.no_strategy_nodes
+    else:
+        nodes = tuple(mod_graph.nodes)
+        no_strategy_nodes = []

     # the dict to get origin sharding spec of node
     origin_node_sharding_spec_dict = {}
@@ -44,6 +55,9 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
     for index, node in enumerate(nodes):
         target_sharding_specs = []
         for user_node in node.strategies_vector.successor_nodes:
-            target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
+            if user_node in no_strategy_nodes:
+                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name))
+            else:
+                target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
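Note on the hunk above: when a consumer node has no strategy of its own, the producer's best-strategy spec is reused for that edge, so no layout conversion is requested. Below is a minimal standalone sketch of that selection rule, using hypothetical FakeNode/FakeStrategy stand-ins rather than ColossalAI classes:

# Minimal sketch (not ColossalAI code) of the sharding-spec selection rule above:
# a consumer without its own strategy simply inherits the producer's spec.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeStrategy:
    # maps an operand name to its sharding spec (a plain string here)
    sharding_specs: Dict[str, str] = field(default_factory=dict)

    def get_sharding_spec_by_name(self, name: str) -> str:
        return self.sharding_specs[name]


@dataclass
class FakeNode:
    name: str
    best_strategy: FakeStrategy
    successor_nodes: List["FakeNode"] = field(default_factory=list)


producer = FakeNode("matmul_out", FakeStrategy({"matmul_out": "S0R"}))
consumer_with_strategy = FakeNode("relu", FakeStrategy({"matmul_out": "RR"}))
consumer_without_strategy = FakeNode("size", FakeStrategy())
producer.successor_nodes = [consumer_with_strategy, consumer_without_strategy]
no_strategy_nodes = [consumer_without_strategy]

target_sharding_specs = []
for user_node in producer.successor_nodes:
    if user_node in no_strategy_nodes:
        # keep the producer's own spec: nothing to convert on this edge
        target_sharding_specs.append(producer.best_strategy.get_sharding_spec_by_name(producer.name))
    else:
        target_sharding_specs.append(user_node.best_strategy.get_sharding_spec_by_name(producer.name))

print(target_sharding_specs)    # ['RR', 'S0R']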
@@ -136,13 +150,17 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     new_args.append(arg)

             for dim, shard_dims in output_dim_partition_dict.items():
-                # we will skip the dim with -1 value
-                if new_args[dim + 1] == -1:
-                    continue
                 total_shard_size = 1
                 for shard_dim in shard_dims:
                     total_shard_size *= device_mesh.shape[shard_dim]
-                new_args[dim + 1] //= total_shard_size
+                # There are two ways to use torch.view:
+                # 1. torch.view(input, *shape)
+                # 2. torch.view(input, shape)
+                if isinstance(new_args[1], int):
+                    new_args[dim + 1] //= total_shard_size
+                else:
+                    new_args[1] = list(new_args[1])
+                    new_args[1][dim] //= total_shard_size
             node.args = tuple(new_args)

         elif node.op == 'call_function':
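The new comments distinguish the two calling conventions of Tensor.view. A small sketch of why the pass checks isinstance(new_args[1], int): in the traced graph the node arguments are either (input, *shape) or (input, shape), so the sharded dimension is divided at a different position in each case. The loop over shard dimensions is collapsed to a single 2-way shard along dim 0 here for brevity:

# Sketch of the two view() call forms and how the shard-size division differs.
import torch

x = torch.arange(24)

a = x.view(2, 3, 4)        # shape passed as separate int arguments
b = x.view((2, 3, 4))      # shape passed as a single tuple/list

assert a.shape == b.shape == torch.Size([2, 3, 4])

# Emulating the adjustment for a tensor sharded 2-way along dim 0:
total_shard_size = 2
args_unpacked = [x, 2, 3, 4]          # (input, *shape) style
args_packed = [x, (2, 3, 4)]          # (input, shape) style

if isinstance(args_unpacked[1], int):
    args_unpacked[0 + 1] //= total_shard_size    # dim 0 lives at index dim + 1

new_shape = list(args_packed[1])
new_shape[0] //= total_shard_size                # dim 0 lives inside the sequence
args_packed[1] = new_shape

print(args_unpacked[1:], args_packed[1])    # [1, 3, 4] [1, 3, 4]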
@@ -193,12 +211,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
             setattr(param, 'sharding_spec', origin_sharding_spec)
             # TODO: build a ColoParamter class to manager the distributed parameters
-            param_sharded = torch.nn.Parameter(
-                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                         target_sharding_spec).detach().clone())
-        else:
-            param_sharded = param
-        setattr(target_module, name, param_sharded)
+            # we could use .data here, because all the operations just happen before the real training
+            # loop, so we don't need to track these operations in the autograd graph.
+            param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
+                param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+
+        setattr(target_module, name, param)
         comm_actions = node.best_strategy.communication_actions
         for operation_data, comm_action in comm_actions.items():
             comm_spec_to_use = comm_action.comm_spec
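The replacement of the torch.nn.Parameter(...) rewrap with an in-place write through param.data follows the PyTorch convention noted in the new comments: assignments to .data are not recorded by autograd, so resharding a parameter before the training loop keeps it a leaf tensor that still accumulates gradients normally. A minimal sketch (the shapes and the narrow-based "reshard" are illustrative only, not the real shape-consistency transform):

# Sketch: writing through .data reshapes/reshards the parameter without touching autograd.
import torch

param = torch.nn.Parameter(torch.randn(4, 4))

shard = param.data.narrow(0, 0, 2).detach().clone()    # stand-in for the resharded payload
param.data = shard

assert param.is_leaf and param.requires_grad
out = (param * 3).sum()
out.backward()
print(param.grad.shape)    # torch.Size([2, 4]) -- gradients follow the new (sharded) shape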
@@ -212,7 +230,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):

                 param.register_hook(hook_fn)

-            wrapper(param_sharded, comm_spec_to_use)
+            wrapper(param, comm_spec_to_use)

     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
@@ -242,12 +260,13 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
             setattr(target, 'sharding_spec', origin_sharding_spec)
             # TODO: build a ColoParamter class to manager the distributed parameters
-            target_sharded = torch.nn.Parameter(
-                shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
-                                                                         target_sharding_spec).detach().clone())
-        else:
-            target_sharded = target
-        setattr(target_module, atoms[-1], target_sharded)
+            # we could use .data here, because all the operations just happen before the real training
+            # loop, so we don't need to track these operations in the autograd graph.
+            target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
+                target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+
+        assert hasattr(target_module, atoms[-1])
+        setattr(target_module, atoms[-1], target)

         comm_actions = node.best_strategy.communication_actions
         for operation_data, comm_action in comm_actions.items():
@@ -262,7 +281,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):

                 param.register_hook(hook_fn)

-            wrapper(target_sharded, comm_spec_to_use)
+            wrapper(target, comm_spec_to_use)
     return gm


@@ -273,9 +292,12 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
     pass


-def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
+def runtime_preparation_pass(gm: torch.fx.GraphModule,
+                             solution: List[int],
+                             device_mesh: DeviceMesh,
+                             strategies_constructor: StrategiesConstructor = None):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
-        gm, solution)
+        gm, solution, strategies_constructor)
     gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
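The signature change keeps strategies_constructor optional so existing call sites continue to work until the TODO above makes it a required argument. A hedged sketch of that pattern with stand-in names (not the real pass):

# Sketch only: a trailing parameter that defaults to None preserves old call sites
# while letting new callers opt in to the leaf-strategy node list.
from typing import List, Optional


def runtime_preparation_pass_sketch(gm: object,
                                    solution: List[int],
                                    device_mesh: object,
                                    strategies_constructor: Optional[object] = None):
    # the real pass would pick nodes from the constructor when it is provided
    return "leaf-strategy nodes only" if strategies_constructor is not None else "all graph nodes"


print(runtime_preparation_pass_sketch("gm", [0, 1], "mesh"))                                    # old call style
print(runtime_preparation_pass_sketch("gm", [0, 1], "mesh", strategies_constructor=object()))   # new call style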
@@ -41,6 +41,7 @@ class StrategiesConstructor:
         self.leaf_strategies = []
         self.strategy_map = {}
         self.solver_options = solver_options
+        self.no_strategy_nodes = []

     def remove_duplicated_strategy(self, strategies_vector):
         '''
@@ -78,12 +79,11 @@ class StrategiesConstructor:

             return _check_no_strategy_for_data(node._meta_data)

-        no_strategy_node = []
        for node in self.nodes:
            strategies_vector = StrategiesVector(node)

            if _check_no_strategy_for_node(node):
-                no_strategy_node.append(node)
+                self.no_strategy_nodes.append(node)
                pass

            # placeholder node
@@ -0,0 +1,214 @@
+import copy
+import random
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import transformers
+from torch.fx import GraphModule
+from transformers.activations import ACT2FN
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+from transformers.pytorch_utils import Conv1D
+
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
+from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
+from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 32
+HIDDEN_DIM = 768
+
+seed = 128
+torch.manual_seed(seed)
+torch.cuda.manual_seed_all(seed)
+np.random.seed(seed)
+random.seed(seed)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+class GPT2MLP(nn.Module):
+
+    def __init__(self, intermediate_size, config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
+        self.act = ACT2FN[config.activation_function]
+        # We temporarily banned the Dropout layer because the rng state need
+        # to process to get the correct result.
+        # self.dropout = nn.Dropout(config.resid_pdrop)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        # TODO: the rng state need to be fixed for distributed runtime
+        # hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+def check_mlp_layer(rank, model_cls, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+    model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
+    input = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda')
+    test_model = copy.deepcopy(model)
+    test_input = copy.deepcopy(input)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+
+    input_sample = {
+        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+    }
+
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    print(graph)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    print(gm)
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions()
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+    ret = solver.call_solver_serialized_args()
+
+    solution = list(ret[0])
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
+        gm, solution, device_mesh, strategies_constructor)
+    gm = runtime_apply_pass(gm)
+    gm.recompile()
+    cuda_rng_state = torch.cuda.get_rng_state()
+    cpu_rng_state = torch.get_rng_state()
+    origin_output = test_model(test_input)
+    torch.cuda.set_rng_state(cuda_rng_state)
+    torch.set_rng_state(cpu_rng_state)
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    assert_close(output, origin_output, rtol=1e-03, atol=1e-04)
+
+    #*******************backward starting*******************
+    cuda_rng_state = torch.cuda.get_rng_state()
+    output.sum().backward()
+    torch.cuda.set_rng_state(cuda_rng_state)
+    origin_output.sum().backward()
+    origin_param_dict = dict(test_model.named_parameters())
+    if rank == 0:
+        print("*******************backward starting*******************")
+        for name, param in model.named_parameters():
+            param_grad = param.grad
+            origin_param_grad = origin_param_dict[name].grad
+            origin_param_size = origin_param_grad.shape[-1]
+            print(name, param_grad, origin_param_grad)
+            if name == 'c_fc.bias':
+                assert_close_loose(param_grad,
+                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
+                                   rtol=1e-03,
+                                   atol=1e-03)
+            else:
+                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+        print("*******************backward finished*******************")
+    if rank == 1:
+        for name, param in model.named_parameters():
+            param_grad = param.grad
+            origin_param_grad = origin_param_dict[name].grad
+            origin_param_size = origin_param_grad.shape[-1]
+            if name == 'c_fc.bias':
+                assert_close_loose(param_grad,
+                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
+                                   rtol=1e-03,
+                                   atol=1e-03)
+            else:
+                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+    if rank == 2:
+        for name, param in model.named_parameters():
+            param_grad = param.grad
+            origin_param_grad = origin_param_dict[name].grad
+            origin_param_size = origin_param_grad.shape[-1]
+            if name == 'c_fc.bias':
+                assert_close_loose(param_grad,
+                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
+                                   rtol=1e-03,
+                                   atol=1e-03)
+            else:
+                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+    if rank == 3:
+        for name, param in model.named_parameters():
+            param_grad = param.grad
+            origin_param_grad = origin_param_dict[name].grad
+            origin_param_size = origin_param_grad.shape[-1]
+            if name == 'c_fc.bias':
+                assert_close_loose(param_grad,
+                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
+                                   rtol=1e-03,
+                                   atol=1e-03)
+            else:
+                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+
+    #*******************backward finished*******************
+
+    #*******************strategy selected*******************
+    if rank == 0:
+        print("*******************strategy selected*******************")
+        strategies_list = solver.last_s_val
+        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+        computation_cost = 0
+        communication_cost = 0
+        memory_cost = 0
+        for index, node in enumerate(nodes):
+            print(node.name, node.strategies_vector[strategies_list[index]].name)
+            computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
+            communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
+            node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
+            if isinstance(node_memory_cost, tuple):
+                node_memory_cost = node_memory_cost[0]
+            memory_cost += node_memory_cost.activation + node_memory_cost.parameter
+
+        print(f'computation cost is {computation_cost}')
+        print(f'communication cost is {communication_cost}')
+        print(f'memory cost is {memory_cost}')
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@parameterize('model_cls', [GPT2MLP])
+@rerun_if_address_is_in_use()
+def test_mlp_layer(model_cls):
+    world_size = 4
+    run_func = partial(check_mlp_layer, model_cls=model_cls, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_mlp_layer()
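For reference, a small sketch (my reading of the expected layout, not something asserted by the commit itself) of why the test compares ranks 0 and 2 against the first half of the c_fc.bias gradient and ranks 1 and 3 against the second half: reshaping physical_mesh_id = [0, 1, 2, 3] into the (2, 2) mesh puts ranks {0, 2} in mesh column 0 and ranks {1, 3} in column 1, so a parameter sharded along the second mesh dimension splits between those two columns:

# Sketch of the (2, 2) device mesh and the per-rank slice of a dim-0 sharded bias.
import numpy as np

physical_mesh_id = np.arange(4)
mesh = physical_mesh_id.reshape(2, 2)      # [[0, 1],
                                           #  [2, 3]]

bias_size = 3072                           # 4 * HIDDEN_DIM in the test
half = bias_size // 2

for rank in range(4):
    row, col = np.argwhere(mesh == rank)[0]
    start = 0 if col == 0 else half        # shard offset along the sharded dim
    print(f"rank {rank}: mesh position ({row}, {col}), bias slice [{start}:{start + half}]")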