add bloomz_7b, llama-2-7b, llama-2-13b, llama-2-70b to ErnieBotChat (#10024)

- Description: Add bloomz_7b, llama-2-7b, llama-2-13b, llama-2-70b to
ErnieBotChat, which only supported ERNIE-Bot-turbo and ERNIE-Bot.
  - Issue: #10022,
  - Dependencies: no extra dependencies

---------

Co-authored-by: hetianfeng <hetianfeng@meituan.com>
This commit is contained in:
Hunsmore 2023-08-31 15:38:55 +08:00 committed by GitHub
parent e37d51cab6
commit 13fef1e5d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -98,12 +98,19 @@ class ErnieBotChat(BaseChatModel):
def _chat(self, payload: object) -> dict: def _chat(self, payload: object) -> dict:
base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat" base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
if self.model_name == "ERNIE-Bot-turbo": model_paths = {
url = f"{base_url}/eb-instant" "ERNIE-Bot-turbo": "eb-instant",
elif self.model_name == "ERNIE-Bot": "ERNIE-Bot": "completions",
url = f"{base_url}/completions" "BLOOMZ-7B": "bloomz_7b1",
"Llama-2-7b-chat": "llama_2_7b",
"Llama-2-13b-chat": "llama_2_13b",
"Llama-2-70b-chat": "llama_2_70b",
}
if self.model_name in model_paths:
url = f"{base_url}/{model_paths[self.model_name]}"
else: else:
raise ValueError(f"Got unknown model_name {self.model_name}") raise ValueError(f"Got unknown model_name {self.model_name}")
resp = requests.post( resp = requests.post(
url, url,
timeout=self.request_timeout, timeout=self.request_timeout,