mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
[fx] support module with bias addition (#1780)
* [autoparallel] refactor tracer to fix bias addition issue * [fx] support module with bias addition * create bias_addition_module * refactor file structure * polish code * fix unit test
This commit is contained in:
114
tests/test_fx/test_tracer/test_bias_addition_module.py
Normal file
114
tests/test_fx/test_tracer/test_bias_addition_module.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
|
||||
|
||||
class LinearModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_features, out_features):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(in_features, out_features)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
x = x * 2
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ConvModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, bias=True):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = x * 2
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def test_linear_module():
|
||||
model = LinearModel(3, 6)
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %linear_weight : [#users=1] = get_attr[target=linear.weight]
|
||||
# %linear_bias : [#users=1] = get_attr[target=linear.bias]
|
||||
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
|
||||
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
|
||||
# return mul
|
||||
graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')})
|
||||
# def forward(self, x : torch.Tensor):
|
||||
# linear_weight = self.linear.weight
|
||||
# linear_bias = self.linear.bias
|
||||
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
|
||||
# add = linear + linear_bias; linear = linear_bias = None
|
||||
# mul = add * 2; add = None
|
||||
# return mul
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm.recompile()
|
||||
node_list = list(graph.nodes)
|
||||
for node in node_list:
|
||||
if node.op == 'output':
|
||||
continue
|
||||
assert hasattr(node, '_meta_data')
|
||||
weight_node = node_list[1]
|
||||
bias_node = node_list[2]
|
||||
linear_node = node_list[3]
|
||||
add_node = node_list[4]
|
||||
assert weight_node._meta_data.shape == (6, 3)
|
||||
assert bias_node._meta_data.shape == (6,)
|
||||
assert linear_node._meta_data.shape == (3, 6)
|
||||
assert add_node._meta_data.shape == (3, 6)
|
||||
|
||||
|
||||
def test_conv_module():
|
||||
model = ConvModel(3, 6, 2)
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
|
||||
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
|
||||
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
|
||||
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
|
||||
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
|
||||
# return mul
|
||||
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
|
||||
# def forward(self, x : torch.Tensor):
|
||||
# conv_weight = self.conv.weight
|
||||
# conv_bias = self.conv.bias
|
||||
# conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
|
||||
# view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
|
||||
# add = conv2d + view; conv2d = view = None
|
||||
# mul = add * 2; add = None
|
||||
# return mul
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
||||
gm.recompile()
|
||||
node_list = list(graph.nodes)
|
||||
for node in node_list:
|
||||
if node.op == 'output':
|
||||
continue
|
||||
assert hasattr(node, '_meta_data')
|
||||
weight_node = node_list[1]
|
||||
bias_node = node_list[2]
|
||||
conv_node = node_list[3]
|
||||
view_node = node_list[4]
|
||||
add_node = node_list[5]
|
||||
assert weight_node._meta_data.shape == (6, 3, 2, 2)
|
||||
assert bias_node._meta_data.shape == (6,)
|
||||
assert conv_node._meta_data.shape == (4, 6, 63, 63)
|
||||
assert view_node._meta_data.shape == (1, 6, 1, 1)
|
||||
assert add_node._meta_data.shape == (4, 6, 63, 63)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_linear_module()
|
||||
test_conv_module()
|
Reference in New Issue
Block a user