[fx] refactor tracer to trace complete graph (#1342)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] refactor tracer to trace complete graph

* add comments and solve conflicts.
This commit is contained in:
YuliangLiu0306
2022-07-20 11:20:38 +08:00
committed by GitHub
parent 2cc1175c76
commit 942c8cd1fb
9 changed files with 160 additions and 20 deletions

View File

@@ -1,17 +1,40 @@
import torch
import torch.nn as nn
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from torch.fx import GraphModule
import pytest
@pytest.mark.skip('skip due to tracer')
class Conv1D(nn.Module):
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = nn.Parameter(w)
self.bias = nn.Parameter(torch.zeros(nf))
def forward(self, x):
size_out = x.shape[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(size_out)
return x
def test_coloproxy():
# create a dummy node only for testing purpose
model = torch.nn.Linear(10, 10)
gm = torch.fx.symbolic_trace(model)
tracer = ColoTracer()
model = Conv1D(3, 3)
input_sample = {'x': torch.rand(3, 3).to('meta')}
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
node = list(gm.graph.nodes)[0]
# create proxy
proxy = ColoProxy(node=node)
proxy = ColoProxy(node=node, tracer=tracer)
proxy.meta_data = torch.empty(4, 2, device='meta')
assert len(proxy) == 4