Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 18:19:58 +00:00)
[fx] tested the complete workflow for auto-parallel (#1336)
* [fx] tested the complete workflow for auto-parallel
* polish code
* polish code
* polish code
@@ -1,9 +1,16 @@
import torch
import torch.nn as nn
import operator
import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import shard
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
]

ELEMENTWISE_MODULE_OP = [
    torch.nn.Dropout, torch.nn.ReLU, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MaxPool1d,
    torch.nn.MaxPool2d, torch.nn.AvgPool1d, torch.nn.AvgPool2d
]
ELEMENTWISE_FUNC_OP = [
    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout,
    torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d,
    torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d,
    torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d
]

def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
    """weight_split
@@ -21,6 +28,8 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: boo
    else:
        setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
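        # The "fx_attr" tuple above encodes (an assumed reading, for illustration only):
        # the shard dim, the dist spec type ("SHARD"), the parallel mode ("TP"),
        # and a tag distinguishing the col_normal / col_needs_many_outputs cases.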
    return weight


def column_shard_linear_pass(gm: torch.fx.GraphModule):
    # Split all the linear modules with column sharding. Currently for testing only.
    mod_graph = gm.graph
@@ -48,43 +57,95 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
    gm.recompile()
    return gm

def transform_mlp_pass(gm: torch.fx.GraphModule):

def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
    """
    This IR pass checks for a transformer-MLP-like structure and annotates column and row sharding on the linear layers.
    """
    # TODO: needs to handle special cases, like x = linear(x) + linear(x)
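    # Illustrative target pattern (an assumed example, not part of this commit):
    #   h = linear1(x)    -> first linear, annotated for column sharding
    #   h = relu(h)       -> element-wise op, the sharding decision carries over
    #   out = linear2(h)  -> second linear, annotated for row sharding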
    mod_graph = gm.graph
    col_shard = True
    element_op = []
    all_linear_name = []
    linear_name = []
    # Get the names of the element-wise modules (e.g. torch.nn.ReLU)
    # Get the names of all the linear modules and of the repeated linear modules
    for name, func in gm.named_children():
        if not isinstance(func, torch.nn.Linear):
            for i in ELEMENTWISE_MODULE_OP:
                if isinstance(func, i):
                    element_op.append(name)
                    break
        else:
            if name in all_linear_name:
                if name in linear_name:
                    linear_name.remove(name)
            else:
                all_linear_name.append(name)
                linear_name.append(name)
    # If a linear module is called multiple times, set its dist spec to column shard
    # If the module or the function/method is element-wise, col_shard is left unchanged
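    # Illustrative walk-through (an assumed two-linear MLP, not from this commit):
    # with col_shard starting as True, the first linear gets dim=0 (column shard)
    # and flips col_shard to False; an element-wise op such as ReLU leaves it
    # untouched; the second linear then gets dim=-1 (row shard).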
    for node in mod_graph.nodes:
        if node.target in linear_name:
            target_module = node.graph.owning_module.get_submodule(node.target)
            dim = 0 if col_shard else -1
            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=False)
            col_shard = not col_shard
        elif node.target in all_linear_name:
            target_module = node.graph.owning_module.get_submodule(node.target)
            dim = 0 if col_shard else -1
            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=True)
            col_shard = not col_shard
        else:
            if node.target not in element_op and all(node.target != i for i in ELEMENTWISE_FUNC_OP):
                col_shard = True
    gm.recompile()
    return gm
    graph = graph_module.graph
    world_size = process_group.world_size()

    def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
        # traverse the graph to look for consecutive linear layers
        is_linear_module = False

        if node.op == 'call_module':
            # look for the linear layer
            module = node.graph.owning_module.get_submodule(node.target)
            if isinstance(module, nn.Linear):
                is_linear_module = True
                if start_tracking:
                    # when start_tracking = True,
                    # the first linear has already been found and the current
                    # module is the second linear;
                    # set the current linear module to be row-sharded
                    annotation_record['row'] = module

                    for shard_type, module in annotation_record.items():
                        # add row sharding spec
                        if shard_type == 'row':
                            dist_spec = shard(dims=[-1], num_partitions=[world_size])
                            comp_spec = ComputeSpec(ComputePattern.TP1D)
                            setattr(module.weight, 'pg', process_group)
                            setattr(module.weight, 'dist_spec', dist_spec)
                            setattr(module.weight, 'comp_spec', comp_spec)
                        elif shard_type == 'col':
                            weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
                            weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
                            weight_comp_spec.output_replicate = False
                            setattr(module.weight, 'pg', process_group)
                            setattr(module.weight, 'dist_spec', weight_dist_spec)
                            setattr(module.weight, 'comp_spec', weight_comp_spec)

                            if module.bias is not None:
                                bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
                                bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
                                bias_comp_spec.output_replicate = False
                                setattr(module.bias, 'pg', process_group)
                                setattr(module.bias, 'dist_spec', bias_dist_spec)
                                setattr(module.bias, 'comp_spec', bias_comp_spec)
                    start_tracking = False
                    annotation_record.clear()
                else:
                    # when start_tracking = False,
                    # the current layer is the first linear;
                    # set the linear layer to be col-sharded
                    start_tracking = True
                    annotation_record['col'] = module
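        # Note on the annotations above (an assumed reading): the column-sharded
        # linear keeps output_replicate = False, so its output stays sharded along
        # the last dimension and can feed the row-sharded linear directly without
        # an all-gather in between.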

        if start_tracking and not is_linear_module:
            # check against the whitelist;
            # if a non-element-wise op is found, reset the tracking
            if node.op == 'call_module':
                module = node.graph.owning_module.get_submodule(node.target)
                if module.__class__ not in ELEMENTWISE_MODULE_OP:
                    start_tracking = False
            elif node.op == 'call_function' or node.op == 'call_method':
                if node.target not in ELEMENTWISE_FUNC_OP:
                    start_tracking = False
            elif len(node.users.keys()) > 1:
                start_tracking = False

            if not start_tracking:
                annotation_record.clear()

        # stop tracking consecutive linears when a branch is found
        # e.g.
        #   out1 = self.linear1(x)
        #   out2 = self.linear2(x)
        #   return out1 + out2
        next_nodes = list(node.users.keys())
        if len(next_nodes) > 1:
            start_tracking = False
            annotation_record.clear()

        # traverse the successor nodes
        for node in next_nodes:
            _traverse_and_annotate(node, start_tracking, annotation_record, world_size)

    placeholder_node = list(graph.nodes)[0]
    annotate_record = {}
    _traverse_and_annotate(placeholder_node, False, annotate_record, world_size)

    return graph_module
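A minimal usage sketch of the new pass (not part of this commit): a hypothetical two-linear MLP is traced with torch.fx and then annotated by transformer_mlp_pass. The MLP class, the tp_degree argument to ProcessGroup, and the final check are assumptions for illustration; a distributed environment is presumed to be initialized beforehand (e.g. via colossalai.launch).

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from colossalai.tensor import ProcessGroup
# transformer_mlp_pass is the function added in this commit (import path omitted)

class MLP(nn.Module):
    # hypothetical module used only for illustration
    def __init__(self, dim=16):
        super().__init__()
        self.linear1 = nn.Linear(dim, 4 * dim)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.linear2(self.activation(self.linear1(x)))

# trace the model into a torch.fx.GraphModule
gm = symbolic_trace(MLP())

# assumes torch.distributed has already been initialized by the launcher
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
gm = transformer_mlp_pass(gm, pg)

# after the pass, the linear weights should carry the sharding annotations
print(hasattr(gm.linear1.weight, 'dist_spec'), hasattr(gm.linear2.weight, 'dist_spec'))

Note that the pass only attaches metadata (pg, dist_spec, comp_spec) to the parameters; it does not reshape or distribute them itself.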