Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-04 20:28:10 +00:00
community[minor]: Update ChatZhipuAI to support GLM-4 model (#16695)
Description: Update `ChatZhipuAI` to support the latest `glm-4` model.
Issue: N/A
Dependencies: httpx, httpx-sse, PyJWT

The previous `ChatZhipuAI` implementation requires the `zhipuai` package and cannot call the latest GLM models, because:
- the old version `zhipuai==1.*` does not support the latest models;
- `zhipuai==2.*` requires `pydantic` V2, which is incompatible with `langchain-community`.

This re-implementation invokes the GLM models by sending HTTP requests to [open.bigmodel.cn](https://open.bigmodel.cn/dev/api) via the `httpx` package, and uses the `httpx-sse` package to handle stream events.

---------

Co-authored-by: zR <2448370773@qq.com>
This commit is contained in:
parent d25b5b6f25
commit a1f3e9f537
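For a quick sense of the change, here is a minimal usage sketch of the re-implemented wrapper (assuming the new dependencies are installed and a `ZHIPUAI_API_KEY` of the form `<key-id>.<secret>` is set in the environment):

from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage

# model defaults to "glm-4"; the key is read from ZHIPUAI_API_KEY if not passed
chat = ChatZhipuAI(model="glm-4", temperature=0.5)
print(chat.invoke([HumanMessage(content="Hello")]).content)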
Changes to the ZhipuAI chat docs notebook (zhipuai.ipynb):

@@ -32,7 +32,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install --upgrade --quiet zhipuai"
+    "%pip install --quiet httpx[socks]==0.24.1 httpx-sse PyJWT"
    ]
   },
   {
@@ -85,9 +85,9 @@
    "outputs": [],
    "source": [
     "chat = ChatZhipuAI(\n",
-    "    temperature=0.5,\n",
     "    api_key=zhipuai_api_key,\n",
-    "    model=\"chatglm_turbo\",\n",
+    "    model=\"glm-4\",\n",
+    "    temperature=0.5,\n",
     ")"
    ]
   },
@@ -158,9 +158,9 @@
    "outputs": [],
    "source": [
     "streaming_chat = ChatZhipuAI(\n",
-    "    temperature=0.5,\n",
     "    api_key=zhipuai_api_key,\n",
-    "    model=\"chatglm_turbo\",\n",
+    "    model=\"glm-4\",\n",
+    "    temperature=0.5,\n",
     "    streaming=True,\n",
     "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
     ")"
@@ -211,9 +211,9 @@
    "outputs": [],
    "source": [
     "async_chat = ChatZhipuAI(\n",
-    "    temperature=0.5,\n",
     "    api_key=zhipuai_api_key,\n",
-    "    model=\"chatglm_turbo\",\n",
+    "    model=\"glm-4\",\n",
+    "    temperature=0.5,\n",
     ")"
    ]
   },
@@ -280,48 +280,6 @@
     "    ),\n",
     "]"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "character_chat = ChatZhipuAI(\n",
-    "    api_key=zhipuai_api_key,\n",
-    "    meta=meta,\n",
-    "    model=\"characterglm\",\n",
-    "    streaming=True,\n",
-    "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Okay, great! I'm looking forward to it."
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "AIMessage(content=\"Okay, great! I'm looking forward to it.\")"
-      ]
-     },
-     "execution_count": 16,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "character_chat(messages)"
-   ]
-  }
  ],
  "metadata": {
@@ -340,10 +298,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.9.18"
   }
  },
  "nbformat": 4,
 "nbformat_minor": 4
}
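For reference, the streaming cell of the updated notebook assembles to roughly the following (a sketch; `zhipuai_api_key` is defined earlier in the notebook, and the exact callback import paths may differ slightly by LangChain version):

from langchain_community.chat_models import ChatZhipuAI
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain_core.messages import HumanMessage

streaming_chat = ChatZhipuAI(
    api_key=zhipuai_api_key,  # assumed to be defined in an earlier cell
    model="glm-4",
    temperature=0.5,
    streaming=True,
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)
streaming_chat([HumanMessage(content="Tell me a short story.")])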
Changes to libs/community/langchain_community/chat_models/zhipuai.py:

@@ -1,45 +1,158 @@
-"""ZHIPU AI chat models wrapper."""
+"""ZhipuAI chat models wrapper."""

 from __future__ import annotations

-import asyncio
 import json
 import logging
-from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, cast
+import time
+from collections.abc import AsyncIterator, Iterator
+from contextlib import asynccontextmanager, contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
+    agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    BaseMessageChunk,
+    ChatMessage,
+    ChatMessageChunk,
+    HumanMessage,
+    HumanMessageChunk,
+    SystemMessage,
+    SystemMessageChunk,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env

 logger = logging.getLogger(__name__)

-
-class ref(BaseModel):
-    """Reference used in CharacterGLM."""
-
-    enable: bool = Field(True)
-    search_query: str = Field("")
+API_TOKEN_TTL_SECONDS = 3 * 60
+ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions"


-class meta(BaseModel):
-    """Metadata used in CharacterGLM."""
-
-    user_info: str = Field("")
-    bot_info: str = Field("")
-    bot_name: str = Field("")
-    user_name: str = Field("User")
+@contextmanager
+def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator:
+    from httpx_sse import EventSource
+
+    with client.stream(method, url, **kwargs) as response:
+        yield EventSource(response)
+
+
+@asynccontextmanager
+async def aconnect_sse(
+    client: Any, method: str, url: str, **kwargs: Any
+) -> AsyncIterator:
+    from httpx_sse import EventSource
+
+    async with client.stream(method, url, **kwargs) as response:
+        yield EventSource(response)
+
+
+def _get_jwt_token(api_key: str) -> str:
+    """Gets JWT token for ZhipuAI API, see 'https://open.bigmodel.cn/dev/api#nosdk'.
+
+    Args:
+        api_key: The API key for ZhipuAI API.
+
+    Returns:
+        The JWT token.
+    """
+    import jwt
+
+    try:
+        id, secret = api_key.split(".")
+    except ValueError as err:
+        raise ValueError(f"Invalid API key: {api_key}") from err
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )
+
+
+def _convert_dict_to_message(dct: Dict[str, Any]) -> BaseMessage:
+    role = dct.get("role")
+    content = dct.get("content", "")
+    if role == "system":
+        return SystemMessage(content=content)
+    if role == "user":
+        return HumanMessage(content=content)
+    if role == "assistant":
+        additional_kwargs = {}
+        tool_calls = dct.get("tool_calls", None)
+        if tool_calls is not None:
+            additional_kwargs["tool_calls"] = tool_calls
+        return AIMessage(content=content, additional_kwargs=additional_kwargs)
+    return ChatMessage(role=role, content=content)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
+    """Convert a LangChain message to a dictionary.
+
+    Args:
+        message: The LangChain message.
+
+    Returns:
+        The dictionary.
+    """
+    message_dict: Dict[str, Any]
+    if isinstance(message, ChatMessage):
+        message_dict = {"role": message.role, "content": message.content}
+    elif isinstance(message, SystemMessage):
+        message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, HumanMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {"role": "assistant", "content": message.content}
+    else:
+        raise TypeError(f"Got unknown type '{message.__class__.__name__}'.")
+    return message_dict
+
+
+def _convert_delta_to_message_chunk(
+    dct: Dict[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = dct.get("role")
+    content = dct.get("content", "")
+    additional_kwargs = {}
+    tool_calls = dct.get("tool_call", None)
+    if tool_calls is not None:
+        additional_kwargs["tool_calls"] = tool_calls
+
+    if role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    if role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+    if role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    return default_class(content=content)


 class ChatZhipuAI(BaseChatModel):
     """
-    `ZHIPU AI` large language chat models API.
+    `ZhipuAI` large language chat models API.

-    To use, you should have the ``zhipuai`` python package installed.
+    To use, you should have the ``PyJWT`` python package installed.

     Example:
         .. code-block:: python
@@ -49,98 +162,11 @@ class ChatZhipuAI(BaseChatModel):
         zhipuai_chat = ChatZhipuAI(
             temperature=0.5,
             api_key="your-api-key",
-            model="chatglm_turbo",
+            model="glm-4"
         )

     """

-    zhipuai: Any
-    zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
-    """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
-
-    model: str = Field("chatglm_turbo")
-    """
-    Model name to use.
-    -chatglm_turbo:
-        According to the input of natural language instructions to complete a
-        variety of language tasks, it is recommended to use SSE or asynchronous
-        call request interface.
-    -characterglm:
-        It supports human-based role-playing, ultra-long multi-round memory,
-        and thousands of character dialogues. It is widely used in anthropomorphic
-        dialogues or game scenes such as emotional accompaniments, game intelligent
-        NPCS, Internet celebrities/stars/movie and TV series IP clones, digital
-        people/virtual anchors, and text adventure games.
-    """
-
-    temperature: float = Field(0.95)
-    """
-    What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
-    be equal to 0.
-    The larger the value, the more random and creative the output; The smaller
-    the value, the more stable or certain the output will be.
-    You are advised to adjust top_p or temperature parameters based on application
-    scenarios, but do not adjust the two parameters at the same time.
-    """
-
-    top_p: float = Field(0.7)
-    """
-    Another method of sampling temperature is called nuclear sampling. The value
-    ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
-    The model considers the results with top_p probability quality tokens.
-    For example, 0.1 means that the model decoder only considers tokens from the
-    top 10% probability of the candidate set.
-    You are advised to adjust top_p or temperature parameters based on application
-    scenarios, but do not adjust the two parameters at the same time.
-    """
-
-    request_id: Optional[str] = Field(None)
-    """
-    Parameter transmission by the client must ensure uniqueness; A unique
-    identifier used to distinguish each request, which is generated by default
-    by the platform when the client does not transmit it.
-    """
-
-    streaming: bool = Field(False)
-    """Whether to stream the results or not."""
-
-    incremental: bool = Field(True)
-    """
-    When invoked by the SSE interface, it is used to control whether the content
-    is returned incremented or full each time.
-    If this parameter is not provided, the value is returned incremented by default.
-    """
-
-    return_type: str = Field("json_string")
-    """
-    This parameter is used to control the type of content returned each time.
-    - json_string Returns a standard JSON string.
-    - text Returns the original text content.
-    """
-
-    ref: Optional[ref] = Field(None)
-    """
-    This parameter is used to control the reference of external information
-    during the request.
-    Currently, this parameter is used to control whether to reference external
-    information.
-    If this field is empty or absent, the search and parameter passing format
-    is enabled by default.
-    {"enable": "true", "search_query": "history "}
-    """
-
-    meta: Optional[meta] = Field(None)
-    """Used in CharacterGLM"""
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        return {"model_name": self.model}
-
-    @property
-    def _llm_type(self) -> str:
-        """Return the type of chat model."""
-        return "zhipuai"
-
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}
@@ -154,93 +180,109 @@ class ChatZhipuAI(BaseChatModel):
     def lc_attributes(self) -> Dict[str, Any]:
         attributes: Dict[str, Any] = {}

-        if self.model:
-            attributes["model"] = self.model
-
-        if self.streaming:
-            attributes["streaming"] = self.streaming
-
-        if self.return_type:
-            attributes["return_type"] = self.return_type
+        if self.zhipuai_api_base:
+            attributes["zhipuai_api_base"] = self.zhipuai_api_base

         return attributes

-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        try:
-            import zhipuai
-
-            self.zhipuai = zhipuai
-            self.zhipuai.api_key = self.zhipuai_api_key
-        except ImportError:
-            raise RuntimeError(
-                "Could not import zhipuai package. "
-                "Please install it via 'pip install zhipuai'"
-            )
+    @property
+    def _llm_type(self) -> str:
+        """Return the type of chat model."""
+        return "zhipuai-chat"
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling the ZhipuAI API."""
+        params = {
+            "model": self.model_name,
+            "stream": self.streaming,
+            "temperature": self.temperature,
+        }
+        if self.max_tokens is not None:
+            params["max_tokens"] = self.max_tokens
+        return params
+
+    # client:
+    zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
+    """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
+    zhipuai_api_base: Optional[str] = Field(default=None, alias="api_base")
+    """Base URL path for API requests, leave blank if not using a proxy or service
+    emulator.
+    """
+
+    model_name: Optional[str] = Field(default="glm-4", alias="model")
+    """
+    Model name to use; see 'https://open.bigmodel.cn/dev/api#language'.
+    You can also use any fine-tuned model of the GLM series.
+    """
+
+    temperature: float = 0.95
+    """
+    What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
+    be equal to 0.
+    The larger the value, the more random and creative the output; the smaller
+    the value, the more stable or certain the output will be.
+    You are advised to adjust top_p or temperature parameters based on application
+    scenarios, but do not adjust the two parameters at the same time.
+    """
+
+    top_p: float = 0.7
+    """
+    Another method of sampling temperature is called nucleus sampling. The value
+    ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
+    The model only considers the tokens within the top_p probability mass.
+    For example, 0.1 means that the model decoder only considers tokens from the
+    top 10% probability of the candidate set.
+    You are advised to adjust top_p or temperature parameters based on application
+    scenarios, but do not adjust the two parameters at the same time.
+    """
+
+    streaming: bool = False
+    """Whether to stream the results or not."""
+    max_tokens: Optional[int] = None
+    """Maximum number of tokens to generate."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
+    @root_validator()
+    def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        values["zhipuai_api_key"] = get_from_dict_or_env(
+            values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
+        )
+        values["zhipuai_api_base"] = get_from_dict_or_env(
+            values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE
+        )

-    def invoke(self, prompt: Any) -> Any:  # type: ignore[override]
-        if self.model == "chatglm_turbo":
-            return self.zhipuai.model_api.invoke(
-                model=self.model,
-                prompt=prompt,
-                top_p=self.top_p,
-                temperature=self.temperature,
-                request_id=self.request_id,
-                return_type=self.return_type,
-            )
-        elif self.model == "characterglm":
-            _meta = cast(meta, self.meta).dict()
-            return self.zhipuai.model_api.invoke(
-                model=self.model,
-                meta=_meta,
-                prompt=prompt,
-                request_id=self.request_id,
-                return_type=self.return_type,
-            )
-        return None
+        return values

-    def sse_invoke(self, prompt: Any) -> Any:
-        if self.model == "chatglm_turbo":
-            return self.zhipuai.model_api.sse_invoke(
-                model=self.model,
-                prompt=prompt,
-                top_p=self.top_p,
-                temperature=self.temperature,
-                request_id=self.request_id,
-                return_type=self.return_type,
-                incremental=self.incremental,
-            )
-        elif self.model == "characterglm":
-            _meta = cast(meta, self.meta).dict()
-            return self.zhipuai.model_api.sse_invoke(
-                model=self.model,
-                prompt=prompt,
-                meta=_meta,
-                request_id=self.request_id,
-                return_type=self.return_type,
-                incremental=self.incremental,
-            )
-        return None
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+        params = self._default_params
+        if stop is not None:
+            params["stop"] = stop
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
+        return message_dicts, params

-    async def async_invoke(self, prompt: Any) -> Any:
-        loop = asyncio.get_running_loop()
-        partial_func = partial(
-            self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
-        )
-        response = await loop.run_in_executor(
-            None,
-            partial_func,
-        )
-        return response
-
-    async def async_invoke_result(self, task_id: Any) -> Any:
-        loop = asyncio.get_running_loop()
-        response = await loop.run_in_executor(
-            None,
-            self.zhipuai.model_api.query_async_invoke_result,
-            task_id,
-        )
-        return response
+    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+        generations = []
+        if not isinstance(response, dict):
+            response = response.dict()
+        for res in response["choices"]:
+            message = _convert_dict_to_message(res["message"])
+            generation_info = dict(finish_reason=res.get("finish_reason"))
+            generations.append(
+                ChatGeneration(message=message, generation_info=generation_info)
+            )
+        token_usage = response.get("usage", {})
+        llm_output = {
+            "token_usage": token_usage,
+            "model_name": self.model_name,
+        }
+        return ChatResult(generations=generations, llm_output=llm_output)

     def _generate(
         self,
@@ -251,86 +293,163 @@ class ChatZhipuAI(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
-        """Generate a chat response."""
-        prompt: List = []
-        for message in messages:
-            if isinstance(message, AIMessage):
-                role = "assistant"
-            else:  # For both HumanMessage and SystemMessage, role is 'user'
-                role = "user"
-
-            prompt.append({"role": role, "content": message.content})
-
         should_stream = stream if stream is not None else self.streaming
-        if not should_stream:
-            response = self.invoke(prompt)
-
-            if response["code"] != 200:
-                raise RuntimeError(response)
-
-            content = response["data"]["choices"][0]["content"]
-            return ChatResult(
-                generations=[ChatGeneration(message=AIMessage(content=content))]
-            )
-
-        else:
+        if should_stream:
             stream_iter = self._stream(
-                prompt=prompt,
-                stop=stop,
-                run_manager=run_manager,
-                **kwargs,
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
             return generate_from_stream(stream_iter)

-    async def _agenerate(  # type: ignore[override]
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {
+            **params,
+            **kwargs,
+            "messages": message_dicts,
+            "stream": False,
+        }
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+        import httpx
+
+        with httpx.Client(headers=headers) as client:
+            response = client.post(self.zhipuai_api_base, json=payload)
+            response.raise_for_status()
+        return self._create_chat_result(response.json())
+
+    def _stream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = False,
         **kwargs: Any,
-    ) -> ChatResult:
-        """Asynchronously generate a chat response."""
-
-        prompt = []
-        for message in messages:
-            if isinstance(message, AIMessage):
-                role = "assistant"
-            else:  # For both HumanMessage and SystemMessage, role is 'user'
-                role = "user"
-
-            prompt.append({"role": role, "content": message.content})
-
-        invoke_response = await self.async_invoke(prompt)
-        task_id = invoke_response["data"]["task_id"]
-
-        response = await self.async_invoke_result(task_id)
-        while response["data"]["task_status"] != "SUCCESS":
-            await asyncio.sleep(1)
-            response = await self.async_invoke_result(task_id)
-
-        content = response["data"]["choices"][0]["content"]
-        content = json.loads(content)
-        return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=content))]
-        )
-
-    def _stream(  # type: ignore[override]
-        self,
-        prompt: List[Dict[str, str]],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        """Stream the chat response in chunks."""
-        response = self.sse_invoke(prompt)
-
-        for r in response.events():
-            if r.event == "add":
-                delta = r.data
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-                if run_manager:
-                    run_manager.on_llm_new_token(delta, chunk=chunk)
-
-                yield chunk
-
-            elif r.event == "error":
-                raise ValueError(f"Error from ZhipuAI API response: {r.data}")
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        if self.zhipuai_api_base is None:
+            raise ValueError("Did not find zhipuai_api_base.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+
+        default_chunk_class = AIMessageChunk
+        import httpx
+
+        with httpx.Client(headers=headers) as client:
+            with connect_sse(
+                client, "POST", self.zhipuai_api_base, json=payload
+            ) as event_source:
+                for sse in event_source.iter_sse():
+                    chunk = json.loads(sse.data)
+                    if len(chunk["choices"]) == 0:
+                        continue
+                    choice = chunk["choices"][0]
+                    chunk = _convert_delta_to_message_chunk(
+                        choice["delta"], default_chunk_class
+                    )
+                    finish_reason = choice.get("finish_reason", None)
+
+                    generation_info = (
+                        {"finish_reason": finish_reason}
+                        if finish_reason is not None
+                        else None
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=chunk, generation_info=generation_info
+                    )
+                    yield chunk
+                    if run_manager:
+                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    if finish_reason is not None:
+                        break
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {
+            **params,
+            **kwargs,
+            "messages": message_dicts,
+            "stream": False,
+        }
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+        import httpx
+
+        async with httpx.AsyncClient(headers=headers) as client:
+            response = await client.post(self.zhipuai_api_base, json=payload)
+            response.raise_for_status()
+        return self._create_chat_result(response.json())
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        if self.zhipuai_api_base is None:
+            raise ValueError("Did not find zhipuai_api_base.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+
+        default_chunk_class = AIMessageChunk
+        import httpx
+
+        async with httpx.AsyncClient(headers=headers) as client:
+            async with aconnect_sse(
+                client, "POST", self.zhipuai_api_base, json=payload
+            ) as event_source:
+                async for sse in event_source.aiter_sse():
+                    chunk = json.loads(sse.data)
+                    if len(chunk["choices"]) == 0:
+                        continue
+                    choice = chunk["choices"][0]
+                    chunk = _convert_delta_to_message_chunk(
+                        choice["delta"], default_chunk_class
+                    )
+                    finish_reason = choice.get("finish_reason", None)
+
+                    generation_info = (
+                        {"finish_reason": finish_reason}
+                        if finish_reason is not None
+                        else None
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=chunk, generation_info=generation_info
+                    )
+                    yield chunk
+                    if run_manager:
+                        await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    if finish_reason is not None:
+                        break
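The token scheme above is plain HS256 JWT signing with PyJWT, keyed by the two dot-separated halves of the API key; the `sign_type: SIGN` header is what the ZhipuAI endpoint expects per the linked docs. A standalone sketch (with a made-up key; real keys have the same "<id>.<secret>" shape):

import time

import jwt  # PyJWT

api_key = "my-key-id.my-secret"  # hypothetical key, format: "<id>.<secret>"
key_id, secret = api_key.split(".")
payload = {
    "api_key": key_id,
    # note: the API expects millisecond timestamps here
    "exp": int(round(time.time() * 1000)) + 3 * 60 * 1000,
    "timestamp": int(round(time.time() * 1000)),
}
token = jwt.encode(
    payload, secret, algorithm="HS256", headers={"alg": "HS256", "sign_type": "SIGN"}
)
print(token)  # sent as the Authorization header of each request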
Changes to libs/community/poetry.lock (generated, 32 lines changed):
@@ -1407,17 +1407,6 @@ mlflow-skinny = ">=2.4.0,<3"
 protobuf = ">=3.12.0,<5"
 requests = ">=2"

-[[package]]
-name = "dataclasses"
-version = "0.6"
-description = "A backport of the dataclasses module for Python 3.6"
-optional = true
-python-versions = "*"
-files = [
-    {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"},
-    {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"},
-]
-
 [[package]]
 name = "dataclasses-json"
 version = "0.6.4"
@@ -9229,23 +9218,6 @@ files = [
 idna = ">=2.0"
 multidict = ">=4.0"

-[[package]]
-name = "zhipuai"
-version = "1.0.7"
-description = "A SDK library for accessing big model apis from ZhipuAI"
-optional = true
-python-versions = ">=3.6"
-files = [
-    {file = "zhipuai-1.0.7-py3-none-any.whl", hash = "sha256:360c01b8c2698f366061452e86d5a36a5ff68a576ea33940da98e4806f232530"},
-    {file = "zhipuai-1.0.7.tar.gz", hash = "sha256:b80f699543d83cce8648acf1ce32bc2725d1c1c443baffa5882abc2cc704d581"},
-]
-
-[package.dependencies]
-cachetools = "*"
-dataclasses = "*"
-PyJWT = "*"
-requests = "*"
-
 [[package]]
 name = "zipp"
 version = "3.17.0"
@@ -9263,9 +9235,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p

 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict", "zhipuai"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]

 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "45da04abac45743972d1edf62d08d9abaa2bebb473b794e0a0d6f1fdc87773f9"
+content-hash = "67c38c029bb59d45fd0f84a5d48c44f64f1301d6be07f419615d08ba8671a2a7"
Changes to libs/community/pyproject.toml:

@@ -88,7 +88,6 @@ tree-sitter = {version = "^0.20.2", optional = true}
 tree-sitter-languages = {version = "^1.8.0", optional = true}
 azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true}
 oracle-ads = {version = "^2.9.1", optional = true}
-zhipuai = {version = "^1.0.7", optional = true}
 httpx = {version = "^0.24.1", optional = true}
 elasticsearch = {version = "^8.12.0", optional = true}
 hdbcli = {version = "^2.19.21", optional = true}
@@ -99,6 +98,8 @@ tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true}
 friendli-client = {version = "^1.2.4", optional = true}
 premai = {version = "^0.3.25", optional = true}
 vdms = {version = "^0.0.20", optional = true}
+httpx-sse = {version = "^0.4.0", optional = true}
+pyjwt = {version = "^2.8.0", optional = true}

 [tool.poetry.group.test]
 optional = true
@@ -262,7 +263,6 @@ extended_testing = [
     "tree-sitter-languages",
     "azure-ai-documentintelligence",
     "oracle-ads",
-    "zhipuai",
     "httpx",
     "elasticsearch",
     "hdbcli",
@@ -272,7 +272,9 @@ extended_testing = [
     "cloudpickle",
     "friendli-client",
     "premai",
-    "vdms"
+    "vdms",
+    "httpx-sse",
+    "pyjwt"
 ]

 [tool.ruff]
Integration test changes (test_zhipuai.py):

@@ -18,7 +18,7 @@ def test_default_call() -> None:

 def test_model() -> None:
     """Test model kwarg works."""
-    chat = ChatZhipuAI(model="chatglm_turbo")
+    chat = ChatZhipuAI(model="glm-4")
     response = chat(messages=[HumanMessage(content="Hello")])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
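The test above exercises the blocking call path; the new `_astream` implementation can be exercised through the standard async surface as well (a sketch, assuming a valid `ZHIPUAI_API_KEY` in the environment):

import asyncio

from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage

async def main() -> None:
    chat = ChatZhipuAI(model="glm-4")
    # astream drives the SSE-based _astream path added in this PR
    async for chunk in chat.astream([HumanMessage(content="Hello")]):
        print(chunk.content, end="", flush=True)

asyncio.run(main())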
Unit test changes (test_zhipuai.py):

@@ -1,10 +1,13 @@
 """Test ZhipuAI Chat API wrapper"""

+import pytest
+
 from langchain_community.chat_models.zhipuai import ChatZhipuAI


-@pytest.mark.requires("zhipuai")
-def test_integration_initialization() -> None:
-    chat = ChatZhipuAI(model="chatglm_turbo", streaming=False)
-    assert chat.model == "chatglm_turbo"
-    assert chat.streaming is False
+@pytest.mark.requires("httpx", "httpx_sse", "jwt")
+def test_zhipuai_model_param() -> None:
+    llm = ChatZhipuAI(api_key="test", model="foo")
+    assert llm.model_name == "foo"
+    llm = ChatZhipuAI(api_key="test", model_name="foo")
+    assert llm.model_name == "foo"
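Both assertions pass because the `model` alias and the `model_name` field name populate the same pydantic field, thanks to `allow_population_by_field_name = True` in the model's `Config`. A minimal self-contained illustration of that pydantic v1 mechanic (standalone; no ZhipuAI access needed):

from langchain_core.pydantic_v1 import BaseModel, Field

class M(BaseModel):
    model_name: str = Field(default="glm-4", alias="model")

    class Config:
        allow_population_by_field_name = True

assert M(model="foo").model_name == "foo"       # populated via the alias
assert M(model_name="foo").model_name == "foo"  # populated via the field name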