[ci] fixed ddp test (#5254)

* [ci] fixed ddp test

* polish
This commit is contained in:
Frank Lee
2024-01-11 17:16:32 +08:00
committed by GitHub
parent d5eeeb1416
commit 2b83418719
2 changed files with 17 additions and 3 deletions

View File

@@ -61,7 +61,7 @@ class ModelZooRegistry(dict):
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
def get_sub_registry(self, keyword: Union[str, List[str]]):
def get_sub_registry(self, keyword: Union[str, List[str]], exclude: Union[str, List[str]] = None):
"""
Get a sub registry with models that contain the keyword.
@@ -76,10 +76,24 @@ class ModelZooRegistry(dict):
keyword_list = keyword
assert isinstance(keyword_list, (list, tuple))
if exclude is None:
exclude_keywords = []
elif isinstance(exclude, str):
exclude_keywords = [exclude]
else:
exclude_keywords = exclude
assert isinstance(exclude_keywords, (list, tuple))
for k, v in self.items():
for kw in keyword_list:
if kw in k:
new_dict[k] = v
should_exclude = False
for ex_kw in exclude_keywords:
if ex_kw in k:
should_exclude = True
if not should_exclude:
new_dict[k] = v
assert len(new_dict) > 0, f"No model found with keyword {keyword}"
return new_dict