[autoparallel] refactor runtime pass (#2644)

* [autoparallel] refactor runtime pass

* add unit test

* polish
YuliangLiu0306
2023-02-15 10:36:19 +08:00
committed by GitHub
parent 89f8975fb8
commit cb2c6a2415
5 changed files with 352 additions and 214 deletions


@@ -0,0 +1,54 @@
import torch
import torch.nn.functional as F

from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec


class TestModule(torch.nn.Module):

    def forward(self, x):
        x = x.view(4, 4, 2)
        return x


def insert_narrow(gm, x_node):
    # Insert a `narrow` node after `x_node` to emulate sharding the input along dim 0,
    # then reroute the view node to consume the narrowed tensor.
    graph = gm.graph
    with graph.inserting_after(x_node):
        shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
    view_node = list(x_node.users.keys())[0]
    new_args = list(view_node.args)
    new_args[0] = shard_node
    view_node.args = tuple(new_args)
    return gm


def test_node_args_converting_pass():
    model = TestModule()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    meta_args = {'x': torch.rand(4, 8).to('meta')}
    input = torch.rand(4, 8)

    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args=meta_args)
    x_node = list(graph.nodes)[0]
    view_node = list(graph.nodes)[1]
    sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
    setattr(x_node, 'sharding_spec', sharding_spec)
    setattr(view_node, 'sharding_spec', sharding_spec)

    gm = ColoGraphModule(model, graph)
    gm = node_args_converting_pass(gm, device_mesh)
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    output = gm(input)
    # The pass rewrites the view args from the global shape (4, 4, 2) to the local
    # shard shape (2, 4, 2), so the narrowed (2, 8) input reshapes cleanly.
    assert output.shape == torch.Size([2, 4, 2])


if __name__ == '__main__':
    test_node_args_converting_pass()
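
For reference, the intent checked by this test can be sketched by hand: shape arguments traced against the global tensor are rewritten to the shape of the local shard on each device. A minimal sketch, assuming only that dim 0 is sharded across a mesh axis of size 2 (not the library's implementation):

# The view node was traced with the global shape (4, 4, 2).
global_view_args = (4, 4, 2)
# Dim 0 is partitioned across a mesh axis of length 2, so each device holds half of it.
mesh_axis_len = 2
local_view_args = (global_view_args[0] // mesh_axis_len,) + global_view_args[1:]
assert local_view_args == (2, 4, 2)  # matches the shape asserted in the test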


@@ -0,0 +1,65 @@
import torch
import torch.nn.functional as F

from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec


class TestModule(torch.nn.Module):

    def forward(self, x):
        size = x.size()
        return size


def insert_narrow(gm, x_node):
    # Insert a `narrow` node after `x_node` to emulate sharding the input along dim 0,
    # then reroute the size node to consume the narrowed tensor.
    graph = gm.graph
    with graph.inserting_after(x_node):
        shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
    size_node = list(x_node.users.keys())[0]
    size_node.args = (shard_node,)
    return gm


def recover_narrow(gm, narrow_node):
    # Undo insert_narrow: point the size node back at the original input
    # and erase the narrow node from the graph.
    graph = gm.graph
    size_node = list(graph.nodes)[2]
    x_node = narrow_node.args[0]
    size_node.args = (x_node,)
    graph.erase_node(narrow_node)
    return gm


def test_size_value_converting_pass():
    model = TestModule()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    meta_args = {'x': torch.rand(4, 8).to('meta')}
    input = torch.rand(4, 8)

    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args=meta_args)
    x_node = list(graph.nodes)[0]
    x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
    setattr(x_node, 'sharding_spec', x_sharding_spec)

    gm = ColoGraphModule(model, graph)

    # Without the pass, x.size() reports the local shard shape.
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    size = gm(input)
    assert size == torch.Size([2, 8])

    # With the pass, the size value is converted back to the global shape.
    narrow_node = list(gm.graph.nodes)[1]
    gm = recover_narrow(gm, narrow_node)
    gm = size_value_converting_pass(gm, device_mesh)
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    size = gm(input)
    assert size == torch.Size([4, 8])


if __name__ == '__main__':
    test_size_value_converting_pass()
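
This test covers the complementary direction: tensor.size() executed on the local shard should report the global shape recorded in the sharding spec. A hand-rolled analogue of that conversion for this test case, using a hypothetical helper (not the library API) and assuming dim 0 is sharded across mesh axis 0 of a (2, 2) mesh:

import torch

def to_global_size(local_size, dim_partition_dict, mesh_shape):
    # Scale every sharded dimension back up by the product of its mesh axis lengths.
    global_size = list(local_size)
    for dim, mesh_axes in dim_partition_dict.items():
        for axis in mesh_axes:
            global_size[dim] *= mesh_shape[axis]
    return torch.Size(global_size)

assert to_global_size(torch.Size([2, 8]), {0: [0]}, (2, 2)) == torch.Size([4, 8])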


@@ -1,12 +1,9 @@
from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing_extensions import Self
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (