mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-23 22:19:47 +00:00)

[fx] tested the complete workflow for auto-parallel (#1336)

* [fx] tested the complete workflow for auto-parallel
* polish code
* polish code
* polish code

This commit is contained in:
parent 4631fef8a0
commit 2cc1175c76
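In outline, the workflow this commit completes has four steps: build the model under LazyInitContext so its parameters stay on the meta device, trace it with ColoTracer into an FX GraphModule, run transformer_mlp_pass to annotate consecutive linear layers with 1D column/row sharding specs, and call lazy_init_parameters to materialize and shard the weights. Below is a minimal sketch of that sequence; TinyMLP is an illustrative stand-in for any traceable linear-activation-linear block, and a real run also needs colossalai.launch to set up the distributed environment first, as in the new test at the bottom of this diff.

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer
from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
from colossalai.tensor import ProcessGroup
from colossalai.utils.model.lazy_init_context import LazyInitContext


class TinyMLP(nn.Module):
    # illustrative linear -> ReLU -> linear block, the pattern the pass looks for

    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))


# 1. build the model lazily so that its parameters stay on the meta device
with LazyInitContext() as ctx:
    model = TinyMLP()

# 2. symbolically trace the model into an FX GraphModule
graph = ColoTracer().trace(model)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)

# 3. annotate consecutive linear layers with 1D column/row sharding specs
gm = transformer_mlp_pass(gm, process_group=ProcessGroup())
gm.recompile()

# 4. materialize the meta parameters and apply the recorded sharding specs
ctx.lazy_init_parameters(gm)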
@@ -1,9 +1,16 @@
 import torch
+import torch.nn as nn
 import operator
-import colossalai
+from colossalai.tensor import ProcessGroup
+from colossalai.tensor.distspec import shard
+from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

-ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.AvgPool1d, torch.nn.AvgPool2d]
-ELEMENTWISE_FUNC_OP = [torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d, torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d, torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d]
+ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
+ELEMENTWISE_FUNC_OP = [
+    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
+    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
+]


 def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
     """weight_split
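These two whitelists are what the rewritten pass below consults while walking the traced graph: an op between two linear layers that is not on either list breaks the MLP pattern and resets the tracking. A small illustrative helper (hypothetical, not part of the commit) shows how a traced node would be checked against them, mirroring the membership tests inside transformer_mlp_pass:

import torch.fx


def is_elementwise(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> bool:
    # hypothetical helper for illustration only, mirroring the checks in transformer_mlp_pass
    if node.op == 'call_module':
        submodule = graph_module.get_submodule(node.target)
        return submodule.__class__ in ELEMENTWISE_MODULE_OP
    if node.op in ('call_function', 'call_method'):
        return node.target in ELEMENTWISE_FUNC_OP
    return False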
@@ -21,6 +28,8 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
     else:
         setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
     return weight


 def column_shard_linear_pass(gm: torch.fx.GraphModule):
     # Split all the linear module with column shard. Currently for testing only.
     mod_graph = gm.graph
@@ -48,43 +57,95 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
     gm.recompile()
     return gm


-def transform_mlp_pass(gm: torch.fx.GraphModule):
-    #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
-    mod_graph = gm.graph
-    col_shard = True
-    element_op = []
-    all_linear_name = []
-    linear_name = []
-    # Get the name of element wise module(torch.nn.ReLU)
-    # Get the name of all the linear modules and repeated linear modules
-    for name, func in gm.named_children():
-        if not isinstance(func, torch.nn.Linear):
-            for i in ELEMENTWISE_MODULE_OP:
-                if isinstance(func, i):
-                    element_op.append(name)
-                    break
-        else:
-            if name in all_linear_name:
-                if name in linear_name:
-                    linear_name.remove(name)
-            else:
-                all_linear_name.append(name)
-                linear_name.append(name)
-    # If the linear modules is called multiple times, set the dist spec as col shard
-    # If the module is element wise or the function/method is element wise, remains col_shard
-    for node in mod_graph.nodes:
-        if node.target in linear_name:
-            target_module = node.graph.owning_module.get_submodule(node.target)
-            dim = 0 if col_shard else -1
-            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=False)
-            col_shard = not col_shard
-        elif node.target in all_linear_name:
-            target_module = node.graph.owning_module.get_submodule(node.target)
-            dim = 0 if col_shard else -1
-            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=True)
-            col_shard = not col_shard
-        else:
-            if node.target not in element_op and all(node.target != i for i in ELEMENTWISE_FUNC_OP):
-                col_shard = True
-    gm.recompile()
-    return gm
+def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
+    """
+    This IR pass checks for a transformer-MLP-like structure and annotates column and row sharding for the linear layers.
+    """
+    #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
+    graph = graph_module.graph
+    world_size = process_group.world_size()
+
+    def _traverse_and_annotate(node, start_tracking, annotation_record, world_size):
+        # traverse the graph to look for consecutive linear layers
+        is_linear_module = False
+
+        if node.op == 'call_module':
+            # look for the linear layer
+            module = node.graph.owning_module.get_submodule(node.target)
+
+            if isinstance(module, nn.Linear):
+                is_linear_module = True
+                if start_tracking:
+                    # when start_tracking = True
+                    # it means the first linear has been found and the current module
+                    # is the second linear
+                    # set the current linear module to be row-sharded
+                    annotation_record['row'] = module
+
+                    for shard_type, module in annotation_record.items():
+                        # add row sharding spec
+                        if shard_type == 'row':
+                            dist_spec = shard(dims=[-1], num_partitions=[world_size])
+                            comp_spec = ComputeSpec(ComputePattern.TP1D)
+                            setattr(module.weight, 'pg', process_group)
+                            setattr(module.weight, 'dist_spec', dist_spec)
+                            setattr(module.weight, 'comp_spec', comp_spec)
+                        elif shard_type == 'col':
+                            weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
+                            weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
+                            weight_comp_spec.output_replicate = False
+                            setattr(module.weight, 'pg', process_group)
+                            setattr(module.weight, 'dist_spec', weight_dist_spec)
+                            setattr(module.weight, 'comp_spec', weight_comp_spec)
+
+                            if module.bias is not None:
+                                bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
+                                bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
+                                bias_comp_spec.output_replicate = False
+                                setattr(module.bias, 'pg', process_group)
+                                setattr(module.bias, 'dist_spec', bias_dist_spec)
+                                setattr(module.bias, 'comp_spec', bias_comp_spec)
+                    start_tracking = False
+                    annotation_record.clear()
+                else:
+                    # when start_tracking = False
+                    # it means the current layer is the first linear
+                    # set the linear layer to be col-sharded
+                    start_tracking = True
+                    annotation_record['col'] = module
+
+        if start_tracking and not is_linear_module:
+            # check against the white list
+            # if a non-element-wise op is found, we reset the tracking
+            if node.op == 'call_module':
+                module = node.graph.owning_module.get_submodule(node.target)
+                if module.__class__ not in ELEMENTWISE_MODULE_OP:
+                    start_tracking = False
+            elif node.op == 'call_function' or node.op == 'call_method':
+                if node.target not in ELEMENTWISE_FUNC_OP:
+                    start_tracking = False
+            elif len(node.users.keys()) > 1:
+                start_tracking = False
+
+            if not start_tracking:
+                annotation_record.clear()
+
+        # stop tracking consecutive linears when a branch is found
+        # e.g.
+        # out1 = self.linear1(x)
+        # out2 = self.linear2(x)
+        # return out1 + out2
+        next_nodes = list(node.users.keys())
+        if len(next_nodes) > 1:
+            start_tracking = False
+            annotation_record.clear()
+
+        # traverse
+        for node in next_nodes:
+            _traverse_and_annotate(node, start_tracking, annotation_record, world_size)
+
+    placeholder_node = list(graph.nodes)[0]
+    annotate_record = {}
+    _traverse_and_annotate(placeholder_node, False, annotate_record, world_size)
+
+    return graph_module
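The col-then-row assignment follows the usual 1D tensor-parallel scheme for an MLP block y = f(x W1) W2: the first linear is column-parallel (its out_features dimension, weight dim 0, is split), the elementwise activation is applied locally on each shard, and the second linear is row-parallel (its in_features dimension, weight dim -1, is split), so each rank produces a partial result that is summed across ranks to recover the full output. A small numerical check of that identity is sketched below with NumPy, assuming world_size = 2; bias and the activation are omitted for brevity, and an elementwise activation between the two matmuls acts on each column shard independently, so the same identity holds with it included.

import numpy as np

x = np.random.rand(4, 8)
w1 = np.random.rand(8, 8)    # first linear: column-parallel (output dim split)
w2 = np.random.rand(8, 8)    # second linear: row-parallel (input dim split)

full = x @ w1 @ w2

# rank 0 keeps w1[:, :4] and w2[:4, :], rank 1 keeps the other halves
partial_rank0 = (x @ w1[:, :4]) @ w2[:4, :]
partial_rank1 = (x @ w1[:, 4:]) @ w2[4:, :]

# summing the partial results (what an all-reduce would do) recovers the full output
assert np.allclose(full, partial_rank0 + partial_rank1)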
@@ -175,7 +175,7 @@ class LazyInitContext():
         self._unpatch_nn_init_funcs()
         self._unpatch_torch_tensor_funcs()

-    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
+    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
         """
         Initialize the weights of the meta-tensor model.

@@ -205,6 +205,7 @@ class LazyInitContext():
             # get sharding spec
             dist_spec = getattr(tensor, 'dist_spec', None)
             pg = getattr(tensor, 'pg', None)
+            comp_spec = getattr(tensor, 'comp_spec', None)

             # convert the tensor from meta to materialized one
             if tensor.is_meta:
@@ -224,14 +225,15 @@ class LazyInitContext():
             else:
                 tensor = ColoTensor.from_torch_tensor(tensor)

-            # apply sharding
-            if dist_spec:
-                tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)
-
             # override the original tensor
             with torch.no_grad():
                 setattr(module, name, tensor)

+            # apply sharding
+            if dist_spec:
+                tensor.process_group = pg
+                tensor.set_tensor_spec(dist_spec, comp_spec)
+
         _init_recursively(model)

         return model
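Taken together, these LazyInitContext changes complete the handshake with the annotation pass: transformer_mlp_pass only records the intent, via the 'pg', 'dist_spec' and 'comp_spec' attributes it sets on the still-meta parameters, and lazy_init_parameters now reads all three back, materializes each parameter, installs it on the module, and only then applies the recorded spec through tensor.process_group and set_tensor_spec. Compared with the earlier redistribute call, the spec is applied after the ColoTensor has been assigned back to the module, and the compute spec (for example output_replicate = False on the column-parallel linear) travels along with the distribution spec.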
tests/test_fx/test_complete_workflow.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+import colossalai
+import torch
+import torch.nn as nn
+import pytest
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from colossalai.testing import rerun_if_address_is_in_use
+from functools import partial
+from colossalai.fx import ColoTracer
+from colossalai.utils.model.lazy_init_context import LazyInitContext
+from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
+from colossalai.utils import free_port
+from colossalai.tensor import ProcessGroup
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(dim, dim)
+        self.linear2 = torch.nn.Linear(dim, dim)
+        self.dropout = torch.nn.Dropout(0)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.dropout(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return x
+
+
+def run_workflow(world_size):
+    # initialization
+    with LazyInitContext() as ctx:
+        model = MLP(16)
+
+    # tracing
+    tracer = ColoTracer()
+    graph = tracer.trace(model)
+    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
+
+    # annotate
+    annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup())
+    annotated_gm.recompile()
+
+    # materialization and sharding
+    ctx.lazy_init_parameters(annotated_gm)
+
+    # check sharding
+    assert list(model.linear1.weight.shape) == [16 // world_size, 16]
+    assert list(model.linear1.bias.shape) == [16 // world_size]
+    assert list(model.linear2.weight.shape) == [16, 16 // world_size]
+
+    # test forward to make sure that the IR transform produces the same results
+    # as ColoTensor would normally
+    data = torch.rand(4, 16)
+    non_fx_out = model(data)
+    fx_out = annotated_gm(data)
+    assert torch.equal(non_fx_out, fx_out)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_workflow(world_size)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_complete_workflow(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_complete_workflow(2)
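The three shape assertions encode the intended 1D scheme: linear1.weight loses rows because its out_features dimension is split (column parallel), linear1.bias is split along that same dimension, and linear2.weight loses columns because its in_features dimension is split (row parallel). The final torch.equal check then confirms that the annotated and sharded GraphModule computes the same forward output as the unsharded module.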
(deleted file, 59 lines removed)
@@ -1,59 +0,0 @@
-import torch
-import torch.nn as nn
-import pytest
-import colossalai
-from colossalai.fx import ColoTracer
-from colossalai.fx.passes.shard_1d_pass import transform_mlp_pass
-CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d')))
-
-class MLP(torch.nn.Module):
-
-    def __init__(self, dim: int):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(dim, dim)
-        self.linear2 = torch.nn.Linear(dim, dim)
-        self.linear3 = torch.nn.Linear(dim, dim)
-        self.linear4 = torch.nn.Linear(dim, dim)
-        self.dropout = torch.nn.Dropout()
-        self.relu = torch.nn.ReLU()
-
-    def forward(self, x):
-        x = self.relu(self.linear1(x))
-        x = self.dropout(self.relu(self.linear2(x)))
-        x = self.linear3(x)
-        x = torch.nn.functional.relu(self.linear4(x))
-        return x
-
-def test_out_acc():
-    model = MLP(16).cuda()
-    model.eval()
-    input_tensor = torch.rand(2, 16).cuda()
-    output = model(input_tensor)
-    tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
-    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
-    splitted_gm = transform_mlp_pass(gm)
-    new_output = splitted_gm(input_tensor)
-    assert output.equal(new_output)
-
-def test_linear_acc():
-    input_tensor = torch.rand(2, 16).cuda()
-    model = MLP(16).cuda()
-    tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
-    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
-    splitted_gm = transform_mlp_pass(gm)
-    col_shard = True
-    for node in splitted_gm.graph.nodes:
-        if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), torch.nn.Linear):
-            target_module = node.graph.owning_module.get_submodule(node.target)
-            dim = 0 if col_shard else -1
-            assert target_module.weight.fx_attr == (dim, "SHARD", "TP", "col_needs_many_outputs")
-            col_shard = not col_shard
-
-if __name__ == "__main__":
-    torch.manual_seed(1)
-    torch.cuda.manual_seed(1)
-    # colossalai.launch_from_torch(config=CONFIG)
-    test_out_acc()
-    test_linear_acc()