mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[fx] support module with bias addition (#1780)
* [autoparallel] refactor tracer to fix bias addition issue * [fx] support module with bias addition * create bias_addition_module * refactor file structure * polish code * fix unit test
This commit is contained in:
27
colossalai/fx/tracer/registry.py
Normal file
27
colossalai/fx/tracer/registry.py
Normal file
@@ -0,0 +1,27 @@
|
||||
class PatchRegistry:
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.store = {}
|
||||
|
||||
def register(self, source):
|
||||
|
||||
def wrapper(func):
|
||||
self.store[source] = func
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def get(self, source):
|
||||
assert source in self.store
|
||||
target = self.store[source]
|
||||
return target
|
||||
|
||||
def has(self, source):
|
||||
return source in self.store
|
||||
|
||||
|
||||
meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
|
||||
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
|
||||
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
|
||||
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
|
Reference in New Issue
Block a user