diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index 4632fe39e..6f49242f7 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -144,6 +144,7 @@ LLM_MODEL_CONFIG = { "openchat-3.5": os.path.join(MODEL_PATH, "openchat_3.5"), # https://huggingface.co/openchat/openchat-3.5-1210 "openchat-3.5-1210": os.path.join(MODEL_PATH, "openchat-3.5-1210"), + "openchat-3.6-8b-20240522": os.path.join(MODEL_PATH, "openchat-3.6-8b-20240522"), # https://huggingface.co/hfl/chinese-alpaca-2-7b "chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"), # https://huggingface.co/hfl/chinese-alpaca-2-13b diff --git a/dbgpt/model/adapter/hf_adapter.py b/dbgpt/model/adapter/hf_adapter.py index d4eefa99a..ec24965fc 100644 --- a/dbgpt/model/adapter/hf_adapter.py +++ b/dbgpt/model/adapter/hf_adapter.py @@ -464,6 +464,22 @@ class SQLCoderAdapter(Llama3Adapter): ) +class OpenChatAdapter(Llama3Adapter): + """ + https://huggingface.co/openchat/openchat-3.6-8b-20240522 + """ + + support_4bit: bool = True + support_8bit: bool = True + + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return ( + lower_model_name_or_path + and "openchat" in lower_model_name_or_path + and "3.6" in lower_model_name_or_path + ) + + # The following code is used to register the model adapter # The last registered model adapter is matched first register_model_adapter(YiAdapter) @@ -479,3 +495,4 @@ register_model_adapter(DeepseekV2Adapter) register_model_adapter(SailorAdapter) register_model_adapter(PhiAdapter) register_model_adapter(SQLCoderAdapter) +register_model_adapter(OpenChatAdapter)