From 14cfc34c723681c207833081b86e076658b1fddd Mon Sep 17 00:00:00 2001
From: csunny
Date: Tue, 10 Oct 2023 20:48:53 +0800
Subject: [PATCH] feat: + tongyi and wenxin

---
 examples/tongyi.py               | 49 ++++++++++++++++++
 pilot/configs/config.py          |  3 ++
 pilot/model/proxy/llms/tongyi.py | 61 ++++++++++++++++++++++-
 pilot/model/proxy/llms/wenxin.py | 85 +++++++++++++++++++++++++++++++-
 tests/unit_tests/llms/wenxin.py  |  0
 5 files changed, 196 insertions(+), 2 deletions(-)
 create mode 100644 examples/tongyi.py
 create mode 100644 tests/unit_tests/llms/wenxin.py

diff --git a/examples/tongyi.py b/examples/tongyi.py
new file mode 100644
index 000000000..9240b430a
--- /dev/null
+++ b/examples/tongyi.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import dashscope
+import requests
+from http import HTTPStatus
+from dashscope import Generation
+
+def call_with_messages():
+    messages = [{'role': 'system', 'content': '你是生活助手机器人。'},  # "You are a daily-life assistant bot."
+                {'role': 'user', 'content': '如何做西红柿鸡蛋?'}]  # "How do I cook tomato and scrambled eggs?"
+    gen = Generation()
+    responses = gen.call(
+        Generation.Models.qwen_turbo,
+        messages=messages,
+        stream=True,
+        top_p=0.8,
+        result_format='message',  # set the result to be "message" format.
+    )
+
+    for response in responses:
+        # A status_code of HTTPStatus.OK indicates success; otherwise the
+        # request failed, and the error code and message are available in
+        # response.code and response.message.
+        if response.status_code == HTTPStatus.OK:
+            print(response.output)  # The output text
+            print(response.usage)  # The usage information
+        else:
+            print(response.code)  # The error code.
+            print(response.message)  # The error message.
+
+
+
+def build_access_token(api_key: str, secret_key: str) -> str:
+    """
+    Generate a Baidu (wenxin) access token from the API key (AK) and secret key (SK).
+    """
+
+    url = "https://aip.baidubce.com/oauth/2.0/token"
+    params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
+
+    res = requests.get(url=url, params=params)
+
+    if res.status_code == 200:
+        return res.json().get("access_token")
+
+
+if __name__ == '__main__':
+    call_with_messages()
\ No newline at end of file
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 134bd66f6..cbf791489 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -46,10 +46,13 @@ class Config(metaclass=Singleton):
         # This is a proxy server, just for test_py. we will remove this later.
         self.proxy_api_key = os.getenv("PROXY_API_KEY")
         self.bard_proxy_api_key = os.getenv("BARD_PROXY_API_KEY")
+        # In order to be compatible with the new and old model parameter design
         if self.bard_proxy_api_key:
             os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key
 
+        self.tongyi_api_key = os.getenv("TONGYI_PROXY_API_KEY")
+
         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
 
         self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
diff --git a/pilot/model/proxy/llms/tongyi.py b/pilot/model/proxy/llms/tongyi.py
index 0c530cdcf..631424ae7 100644
--- a/pilot/model/proxy/llms/tongyi.py
+++ b/pilot/model/proxy/llms/tongyi.py
@@ -1,7 +1,66 @@
+import os
+import logging
+from typing import List
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
+logger = logging.getLogger(__name__)
 
 def tongyi_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
-    yield "tongyi LLM was not supported!"
+
+    import dashscope
+    from dashscope import Generation
+    model_params = model.get_params()
+    logger.info(f"Model: {model}, model_params: {model_params}")
+
+    # proxy_api_key = model_params.proxy_api_key  # TODO: read the key from model_params instead of the environment
+    dashscope.api_key = os.getenv("TONGYI_PROXY_API_KEY")
+
+
+    proxyllm_backend = model_params.proxyllm_backend
+    if not proxyllm_backend:
+        proxyllm_backend = Generation.Models.qwen_turbo  # qwen_turbo by default
+
+    history = []
+
+    messages: List[ModelMessage] = params["messages"]
+    # Add the conversation history
+
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            history.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            history.append({"role": "system", "content": message.content})
+        elif message.role == ModelMessageRoleType.AI:
+            history.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+    # Move the latest user input to the end of the list so it is the current prompt.
+    temp_his = history[::-1]
+    last_user_input = None
+    for m in temp_his:
+        if m["role"] == "user":
+            last_user_input = m
+            break
+
+    if last_user_input:
+        history.remove(last_user_input)
+        history.append(last_user_input)
+
+    logger.info(f"history: {history}")
+    gen = Generation()
+    res = gen.call(
+        proxyllm_backend,
+        messages=history,
+        top_p=params.get("top_p", 0.8),
+        stream=True,
+        result_format='message'
+    )
+
+    for r in res:
+        content = r["output"]["choices"][0]["message"].get("content")
+        if content is not None:
+            yield content
+
diff --git a/pilot/model/proxy/llms/wenxin.py b/pilot/model/proxy/llms/wenxin.py
index 6c621476e..0053005f5 100644
--- a/pilot/model/proxy/llms/wenxin.py
+++ b/pilot/model/proxy/llms/wenxin.py
@@ -1,7 +1,90 @@
+import os
+import requests
+import json
+from typing import List
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+from cachetools import cached, TTLCache
 
+
+@cached(TTLCache(1, 1800))
+def _build_access_token(api_key: str, secret_key: str) -> str:
+    """
+    Generate a Baidu (wenxin) access token from the API key (AK) and secret key (SK).
+    """
+
+    url = "https://aip.baidubce.com/oauth/2.0/token"
+    params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
+
+    res = requests.get(url=url, params=params)
+
+    if res.status_code == 200:
+        return res.json().get("access_token")
+
 
 def wenxin_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
-    yield "wenxin LLM is not supported!"
+    MODEL_VERSION = {
+        "ERNIE-Bot": "completions",
+        "ERNIE-Bot-turbo": "eb-instant",
+    }
+
+    model_params = model.get_params()
+    model_name = os.getenv("WEN_XIN_MODEL_VERSION")
+    model_version = MODEL_VERSION.get(model_name)
+    if not model_version:
+        yield f"Unsupported model version {model_name}"
+        return
+    proxy_api_key = os.getenv("WEN_XIN_API_KEY")
+    proxy_api_secret = os.getenv("WEN_XIN_SECRET_KEY")
+    access_token = _build_access_token(proxy_api_key, proxy_api_secret)
+
+    headers = {
+        "Content-Type": "application/json",
+        "Accept": "application/json"
+    }
+
+    proxy_server_url = f'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}'
+
+    if not access_token:
+        yield "Failed to get access token. Please check that WEN_XIN_API_KEY and WEN_XIN_SECRET_KEY are set correctly."
+        return
+
+    messages: List[ModelMessage] = params["messages"]
+
+    history = []
+    # Add the conversation history
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            history.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            history.append({"role": "system", "content": message.content})
+        elif message.role == ModelMessageRoleType.AI:
+            history.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+
+    payload = {
+        "messages": history,
+        "temperature": params.get("temperature"),
+        "stream": True
+    }
+
+    text = ""
+    res = requests.post(proxy_server_url, headers=headers, json=payload, stream=True)
+    print(f"Send request to {proxy_server_url} with real model {model_name}")
+    for line in res.iter_lines():
+        if line:
+            if not line.startswith(b"data: "):
+                error_message = line.decode("utf-8")
+                yield error_message
+            else:
+                json_data = line.split(b": ", 1)[1]
+                decoded_line = json_data.decode("utf-8")
+                if decoded_line.lower() != "[DONE]".lower():
+                    obj = json.loads(json_data)
+                    if obj.get("result") is not None:
+                        content = obj["result"]
+                        text += content
+                        yield text
+
+
\ No newline at end of file
diff --git a/tests/unit_tests/llms/wenxin.py b/tests/unit_tests/llms/wenxin.py
new file mode 100644
index 000000000..e69de29bb
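
Note: tests/unit_tests/llms/wenxin.py is created empty by this patch. Below is a minimal sketch of a first test it could hold, assuming pytest-style discovery and unittest.mock to stub out the HTTP call; the test name and the fake token values are illustrative only and not part of the patch.

from unittest import mock

from pilot.model.proxy.llms.wenxin import _build_access_token


def test_build_access_token_returns_token():
    # Stub requests.get so no real request reaches aip.baidubce.com.
    fake_response = mock.Mock(status_code=200)
    fake_response.json.return_value = {"access_token": "fake-token"}

    with mock.patch("requests.get", return_value=fake_response) as fake_get:
        token = _build_access_token("fake-ak", "fake-sk")

    # The helper returns the access_token field of the JSON body.
    assert token == "fake-token"
    fake_get.assert_called_once()

Because _build_access_token is wrapped in a TTLCache (maxsize 1, 1800 seconds), a second call with the same keys inside that window is served from the cache and would not hit the stubbed endpoint again.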