mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
@@ -61,7 +61,7 @@ class ModelZooRegistry(dict):
|
||||
"""
|
||||
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
|
||||
|
||||
def get_sub_registry(self, keyword: Union[str, List[str]]):
|
||||
def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None):
|
||||
"""
|
||||
Get a sub registry with models that contain the keyword.
|
||||
|
||||
@@ -76,10 +76,24 @@ class ModelZooRegistry(dict):
|
||||
keyword_list = keyword
|
||||
assert isinstance(keyword_list, (list, tuple))
|
||||
|
||||
if exclude is None:
|
||||
exclude_keywords = []
|
||||
elif isinstance(exclude, str):
|
||||
exclude_keywords = [exclude]
|
||||
else:
|
||||
exclude_keywords = exclude
|
||||
assert isinstance(exclude_keywords, (list, tuple))
|
||||
|
||||
for k, v in self.items():
|
||||
for kw in keyword_list:
|
||||
if kw in k:
|
||||
new_dict[k] = v
|
||||
should_exclude = False
|
||||
for ex_kw in exclude_keywords:
|
||||
if ex_kw in k:
|
||||
should_exclude = True
|
||||
|
||||
if not should_exclude:
|
||||
new_dict[k] = v
|
||||
|
||||
assert len(new_dict) > 0, f"No model found with keyword {keyword}"
|
||||
return new_dict
|
||||
|
Reference in New Issue
Block a user