[fx] support module with bias addition (#1780)

* [autoparallel] refactor tracer to fix bias addition issue

* [fx] support module with bias addition

* create bias_addition_module

* refactor file structure

* polish code

* fix unit test
Author: YuliangLiu0306
Date: 2022-11-01 22:53:51 +08:00
Committed by: GitHub
Parent: f3f19a5c47
Commit: e859380bf7

41 changed files with 624 additions and 259 deletions
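The core of the change: for autoparallel analysis, the tracer needs a module's bias addition (e.g. the `+ bias` inside `nn.Linear`) to appear as its own node in the traced graph, rather than staying fused inside a single `call_module` node, so the matmul and the bias add can be assigned strategies separately. The sketch below illustrates that decomposition with plain `torch.fx`; it is a hypothetical illustration of the idea, not ColossalAI's actual `bias_addition_module` implementation.

```python
# Hypothetical sketch of "bias addition decomposition" using plain torch.fx:
# rewrite a fused Linear call into a bias-free linear followed by an explicit
# add node. Not the commit's actual code path.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace


class FusedLinear(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


model = FusedLinear()
gm = symbolic_trace(model)
x = torch.randn(2, 4)
expected = model(x)

# Rewrite every call to the `linear` submodule into linear(x, w) + b.
for node in list(gm.graph.nodes):
    if node.op == 'call_module' and node.target == 'linear':
        with gm.graph.inserting_after(node):
            w = gm.graph.get_attr('linear.weight')
            b = gm.graph.get_attr('linear.bias')
            mm = gm.graph.call_function(F.linear, args=(node.args[0], w))
            add = gm.graph.call_function(torch.add, args=(mm, b))
        node.replace_all_uses_with(add)
        gm.graph.erase_node(node)

gm.recompile()
gm.graph.print_tabular()    # now shows separate linear and add nodes
assert torch.allclose(gm(x), expected, atol=1e-5)
```

With the rewrite applied, the graph contains a bias-free `linear` call followed by an explicit `add` node, which is the shape of graph the refactored tracer is meant to produce for modules with bias addition.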


@@ -1,7 +1,8 @@
-from colossalai.fx import ColoTracer
 import torch
 from torch.fx import GraphModule, Tracer
+
+from colossalai.fx import ColoTracer
 
 
 def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
     data = data_gen()
@@ -24,8 +25,9 @@ def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
     fx_out = gm(**data)
     if isinstance(fx_out, tuple):
         for non_fx, fx in zip(non_fx_out, fx_out):
-            assert torch.allclose(non_fx,
-                                  fx), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            assert torch.allclose(
+                non_fx, fx, atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
     else:
         assert torch.allclose(
-            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+            fx_out, non_fx_out,
+            atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
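The test-utility change above relaxes the output comparison to `atol=1e-5`. A plausible reason (inferred, not stated in the commit): once the bias addition runs as a separate graph node, the decomposed computation can round differently from the fused module call, so `torch.allclose` at its defaults (`rtol=1e-5`, `atol=1e-8`) can fail even when both outputs are correct. A minimal, self-contained demonstration:

```python
# Why the looser atol: a bias-free matmul followed by a separate add may round
# differently from the fused kernel. Assumed rationale, not from the commit.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(64, 128)
w = torch.randn(256, 128)
b = torch.randn(256)

fused = F.linear(x, w, b)          # bias handled inside the fused op
decomposed = F.linear(x, w) + b    # bias added as a separate op

print(torch.allclose(fused, decomposed))              # may be False at default atol=1e-8
print(torch.allclose(fused, decomposed, atol=1e-5))   # True within the relaxed tolerance
```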