diff --git a/libs/langchain/langchain/chat_models/ernie.py b/libs/langchain/langchain/chat_models/ernie.py index d3fdce5c31f..367341c11f3 100644 --- a/libs/langchain/langchain/chat_models/ernie.py +++ b/libs/langchain/langchain/chat_models/ernie.py @@ -98,12 +98,19 @@ class ErnieBotChat(BaseChatModel): def _chat(self, payload: object) -> dict: base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat" - if self.model_name == "ERNIE-Bot-turbo": - url = f"{base_url}/eb-instant" - elif self.model_name == "ERNIE-Bot": - url = f"{base_url}/completions" + model_paths = { + "ERNIE-Bot-turbo": "eb-instant", + "ERNIE-Bot": "completions", + "BLOOMZ-7B": "bloomz_7b1", + "Llama-2-7b-chat": "llama_2_7b", + "Llama-2-13b-chat": "llama_2_13b", + "Llama-2-70b-chat": "llama_2_70b", + } + if self.model_name in model_paths: + url = f"{base_url}/{model_paths[self.model_name]}" else: raise ValueError(f"Got unknown model_name {self.model_name}") + resp = requests.post( url, timeout=self.request_timeout,