[workflow] fixed build CI (#5240)

* [workflow] fixed build CI

* polish

* polish

* polish

* polish

* polish
This commit is contained in:
Frank Lee
2024-01-10 22:34:16 +08:00
committed by GitHub
parent 41e52c1c6e
commit edf94a35c3
14 changed files with 101 additions and 156 deletions

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
from dataclasses import dataclass
from typing import Callable
from typing import Callable, List, Union
__all__ = ["ModelZooRegistry", "ModelAttribute", "model_zoo"]
@@ -61,7 +61,7 @@ class ModelZooRegistry(dict):
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
def get_sub_registry(self, keyword: str):
def get_sub_registry(self, keyword: Union[str, List[str]]):
"""
Get a sub registry with models that contain the keyword.
@@ -70,12 +70,15 @@ class ModelZooRegistry(dict):
"""
new_dict = dict()
if isinstance(keyword, str):
keyword_list = [keyword]
else:
keyword_list = keyword
assert isinstance(keyword_list, (list, tuple))
for k, v in self.items():
if keyword == "transformers_gpt":
if keyword in k and not "gptj" in k: # ensure GPT2 does not retrieve GPTJ models
new_dict[k] = v
else:
if keyword in k:
for kw in keyword_list:
if kw in k:
new_dict[k] = v
assert len(new_dict) > 0, f"No model found with keyword {keyword}"