ColossalAI/colossalai/fx/tracer/meta_patch/patched_module.py
Frank Lee 6d86f1bc91 [fx] supported data-dependent control flow in model tracing (#1185)
* [fx] supported data-dependent control flow in model tracing

* polish code
2022-06-29 15:05:25 +08:00

import torch

from .registry import meta_patched_module


@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
    # Shape-only result on the "meta" device: keep the leading dims, swap the last for out_features.
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
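The patch above relies on a registry object imported from `.registry`, which is not shown in this file. As a rough illustration of how such a patch might be looked up and used during tracing, here is a minimal, self-contained sketch. The `PatchRegistry` class, its `get`/`has` methods, and the `infer_output` helper are hypothetical stand-ins invented for this example, not the actual ColossalAI implementation; only `torch.empty(..., device="meta")` and `torch.nn.Linear` are real PyTorch APIs.

```python
import torch


class PatchRegistry:
    """Hypothetical stand-in for the object exported by `.registry`."""

    def __init__(self):
        self._store = {}

    def register(self, source):
        # Used as a decorator: record `func` as the patch for `source`.
        def wrapper(func):
            self._store[source] = func
            return func

        return wrapper

    def get(self, source):
        return self._store[source]

    def has(self, source):
        return source in self._store


# Assumed name, mirroring the import in the file above.
meta_patched_module = PatchRegistry()


@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
    # Same rule as the patch above: keep leading dims, replace the last one.
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")


def infer_output(module, meta_input):
    """Hypothetical helper: dispatch on the module type instead of calling forward()."""
    patch = meta_patched_module.get(type(module))
    return patch(module, meta_input)


layer = torch.nn.Linear(16, 4)
x = torch.empty(2, 3, 16, device="meta")  # carries shape and dtype, no data
out = infer_output(layer, x)
print(tuple(out.shape), out.device.type)
```

Under these assumptions, `out` has shape `(2, 3, 4)` and lives on the meta device, so no weights are read and no matrix multiplication is performed. Keying the registry on the module class keeps each shape rule next to the module type it describes.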