mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 01:48:07 +00:00
[autoparallel] refactored the autoparallel module for organization (#1706)
* [autoparallel] refactored the autoparallel module for organization * polish code
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
class Registry:
|
||||
# TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.store = {}
|
||||
|
||||
def register(self, source):
|
||||
|
||||
def wrapper(func):
|
||||
self.store[source] = func
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def get(self, source):
|
||||
assert source in self.store, f'{source} not found in the {self.name} registry'
|
||||
target = self.store[source]
|
||||
return target
|
||||
|
||||
def has(self, source):
|
||||
return source in self.store
|
||||
|
||||
|
||||
operator_registry = Registry('operator')
|
Reference in New Issue
Block a user