mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[fx] supported data-dependent control flow in model tracing (#1185)
* [fx] supported data-dependent control flow in model tracing * polish code
This commit is contained in:
57
tests/test_fx/test_tracer/test_control_flow.py
Normal file
57
tests/test_fx/test_tracer/test_control_flow.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule
|
||||
from colossalai.fx import ColoTracer as Tracer
|
||||
|
||||
|
||||
class ControlFlowModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(10, 10)
|
||||
self.linear2 = nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x, y):
|
||||
x1 = self.linear1(x)
|
||||
y1 = self.linear2(y)
|
||||
|
||||
if x1.dim() == 2:
|
||||
return x1 + y1
|
||||
else:
|
||||
return x1 - y1
|
||||
|
||||
|
||||
def test_control_flow():
|
||||
model = ControlFlowModel()
|
||||
tracer = Tracer()
|
||||
graph_branch_true = tracer.trace(model,
|
||||
meta_args={
|
||||
'x': torch.rand(4, 10, device='meta'),
|
||||
'y': torch.rand(4, 10, device='meta')
|
||||
})
|
||||
graph_branch_false = tracer.trace(model,
|
||||
meta_args={
|
||||
'x': torch.rand(10, device='meta'),
|
||||
'y': torch.rand(4, 10, device='meta')
|
||||
})
|
||||
|
||||
gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__)
|
||||
gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__)
|
||||
gm_branch_true.recompile()
|
||||
gm_branch_false.recompile()
|
||||
|
||||
# test the true branch
|
||||
x = torch.rand(4, 10)
|
||||
y = torch.rand(4, 10)
|
||||
assert torch.all(model(x, y) == gm_branch_true(x, y))
|
||||
assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
|
||||
|
||||
# test the true branch
|
||||
x = torch.rand(10)
|
||||
y = torch.rand(4, 10)
|
||||
assert torch.all(model(x, y) == gm_branch_false(x, y))
|
||||
assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_control_flow()
|
Reference in New Issue
Block a user