diff --git a/pilot/model/proxy/llms/spark.py b/pilot/model/proxy/llms/spark.py
index dfcfd9ff7..47b665b3c 100644
--- a/pilot/model/proxy/llms/spark.py
+++ b/pilot/model/proxy/llms/spark.py
@@ -4,7 +4,6 @@ import base64
 import hmac
 import hashlib
 import websockets
-import asyncio
 from datetime import datetime
 from typing import List
 from time import mktime
@@ -73,13 +72,12 @@ def spark_generate_stream(
         },
         "payload": {
             "message": {
-                "text": last_user_input.get("")
+                "text": last_user_input.get("content")
             }
         }
     }
-    # TODO
-
+    async_call(request_url, data)
 
 
 async def async_call(request_url, data):
     async with websockets.connect(request_url) as ws:
diff --git a/pilot/model/proxy/llms/zhipu.py b/pilot/model/proxy/llms/zhipu.py
index 50d2b4080..5c97393ca 100644
--- a/pilot/model/proxy/llms/zhipu.py
+++ b/pilot/model/proxy/llms/zhipu.py
@@ -1,5 +1,11 @@
-from pilot.model.proxy.llms.proxy_model import ProxyModel
+import os
+import json
+from typing import List
 
+from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+CHATGLM_DEFAULT_MODEL = "chatglm_pro"
 
 def zhipu_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
@@ -9,10 +15,33 @@ def zhipu_generate_stream(
     print(f"Model: {model}, model_params: {model_params}")
 
     proxy_api_key = model_params.proxy_api_key
-    proxy_server_url = model_params.proxy_server_url
-    proxyllm_backend = model_params.proxyllm_backend
+    proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
 
-    if not proxyllm_backend:
-        proxyllm_backend = "chatglm_pro"
-    # TODO
-    yield "Zhipu LLM was not supported!"
+    import zhipuai
+
+    zhipuai.api_key = proxy_api_key
+
+    history = []
+
+    messages: List[ModelMessage] = params["messages"]
+    # Add history conversation
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            history.append({"role": "user", "content": message.content})
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            history.append({"role": "system", "content": message.content})
+        elif message.role == ModelMessageRoleType.AI:
+            history.append({"role": "assistant", "content": message.content})
+        else:
+            pass
+
+    res = zhipuai.model_api.sse_invoke(
+        model=proxyllm_backend,
+        prompt=history,
+        temperature=params.get("temperature"),
+        top_p=params.get("top_p"),
+        incremental=False,
+    )
+    for r in res.events():
+        if r.event == "add":
+            yield r.data
\ No newline at end of file
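
Note on the spark.py hunk: with `import asyncio` removed, the bare statement `async_call(request_url, data)` only creates a coroutine object; nothing schedules it, so the websocket request is never actually sent (CPython emits a "coroutine was never awaited" warning). A minimal sketch of driving the coroutine from the synchronous `spark_generate_stream`, assuming the asyncio import is restored:

    import asyncio

    # Sketch only: run the websocket coroutine to completion from
    # synchronous code. asyncio.run() creates a fresh event loop,
    # executes async_call, and closes the loop when it returns.
    asyncio.run(async_call(request_url, data))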
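
Note on the zhipu.py backend selection: because `CHATGLM_DEFAULT_MODEL` is a non-empty (truthy) string, `CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend` always short-circuits to the default, so a configured `proxyllm_backend` can never take effect. If the intent is "use the configured backend, fall back to the default", the operands would read the other way around; a sketch:

    # Sketch only: prefer the configured backend, falling back to the
    # default when none is set (None / empty string are falsy).
    proxyllm_backend = model_params.proxyllm_backend or CHATGLM_DEFAULT_MODEL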
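
For reference, `zhipu_generate_stream` is a plain generator, so a caller can iterate it directly. A hypothetical consumption sketch (the `params` keys follow the shape the function reads above; `model`, `tokenizer`, and `device` are whatever the model worker already passes in):

    # Hypothetical usage: only params["messages"], "temperature", and
    # "top_p" are read by zhipu_generate_stream.
    params = {
        "messages": messages,  # List[ModelMessage]
        "temperature": 0.7,
        "top_p": 0.9,
    }
    for chunk in zhipu_generate_stream(model, tokenizer, params, "cpu"):
        print(chunk)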