community[minor]: Update ChatZhipuAI to support GLM-4 model (#16695)

Description: Update `ChatZhipuAI` to support the latest `glm-4` model.
Issue: N/A
Dependencies: httpx, httpx-sse, PyJWT

The previous `ChatZhipuAI` implementation required the `zhipuai`
package and could not call the latest GLM models, because:
- The old version `zhipuai==1.*` doesn't support the latest models.
- `zhipuai==2.*` requires Pydantic v2, which is incompatible with
`langchain-community`.

This re-implementation invokes the GLM models by sending HTTP requests to
[open.bigmodel.cn](https://open.bigmodel.cn/dev/api) via the `httpx`
package and uses the `httpx-sse` package to handle stream events, as
sketched below.
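
For reference, a minimal sketch of that request flow (not part of the diff): the endpoint, JWT payload, and SSE loop mirror `_get_jwt_token`, `connect_sse`, and `_stream` in the diff below, while the API key and prompt are placeholders.

```python
# Sketch only: build a short-lived JWT from the "<id>.<secret>" API key with
# PyJWT, then stream chat completions over SSE with httpx + httpx-sse.
import json
import time

import httpx
import jwt
from httpx_sse import EventSource

ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
api_key = "your-id.your-secret"  # placeholder; ZhipuAI keys look like "<id>.<secret>"

key_id, secret = api_key.split(".")
now_ms = int(round(time.time() * 1000))
token = jwt.encode(
    # expiry and timestamp are expressed in milliseconds, with a 3-minute TTL
    {"api_key": key_id, "exp": now_ms + 3 * 60 * 1000, "timestamp": now_ms},
    secret,
    algorithm="HS256",
    headers={"alg": "HS256", "sign_type": "SIGN"},
)

payload = {
    "model": "glm-4",
    "stream": True,
    "messages": [{"role": "user", "content": "Hello"}],  # placeholder prompt
}
headers = {"Authorization": token, "Accept": "application/json"}
with httpx.Client(headers=headers) as client:
    # httpx-sse's EventSource wraps the streaming response and yields
    # server-sent events, one JSON chunk per event
    with client.stream("POST", ZHIPUAI_API_BASE, json=payload) as response:
        for sse in EventSource(response).iter_sse():
            chunk = json.loads(sse.data)
            if not chunk["choices"]:
                continue
            choice = chunk["choices"][0]
            print(choice["delta"].get("content", ""), end="", flush=True)
            if choice.get("finish_reason") is not None:
                break
```

Because the token is deliberately short-lived, a fresh one is minted per request rather than cached.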

---------

Co-authored-by: zR <2448370773@qq.com>
Chenhui Zhang 2024-04-02 02:11:21 +08:00 committed by GitHub
parent d25b5b6f25
commit a1f3e9f537
6 changed files with 694 additions and 641 deletions

View File

@@ -32,7 +32,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install --upgrade --quiet zhipuai"
+    "%pip install --quiet httpx[socks]==0.24.1 httpx-sse PyJWT"
    ]
   },
   {
@@ -85,9 +85,9 @@
    "outputs": [],
    "source": [
     "chat = ChatZhipuAI(\n",
-    "    temperature=0.5,\n",
     "    api_key=zhipuai_api_key,\n",
-    "    model=\"chatglm_turbo\",\n",
+    "    model=\"glm-4\",\n",
+    "    temperature=0.5,\n",
     ")"
    ]
   },
@@ -158,9 +158,9 @@
    "outputs": [],
    "source": [
     "streaming_chat = ChatZhipuAI(\n",
-    "    temperature=0.5,\n",
     "    api_key=zhipuai_api_key,\n",
-    "    model=\"chatglm_turbo\",\n",
+    "    model=\"glm-4\",\n",
+    "    temperature=0.5,\n",
     "    streaming=True,\n",
     "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
     ")"
@@ -211,9 +211,9 @@
    "outputs": [],
    "source": [
     "async_chat = ChatZhipuAI(\n",
-    "    temperature=0.5,\n",
     "    api_key=zhipuai_api_key,\n",
-    "    model=\"chatglm_turbo\",\n",
+    "    model=\"glm-4\",\n",
+    "    temperature=0.5,\n",
     ")"
    ]
   },
@@ -280,48 +280,6 @@
     "    ),\n",
     "]"
    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 15,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "character_chat = ChatZhipuAI(\n",
-     "    api_key=zhipuai_api_key,\n",
-     "    meta=meta,\n",
-     "    model=\"characterglm\",\n",
-     "    streaming=True,\n",
-     "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
-     ")"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 16,
-    "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "Okay, great! I'm looking forward to it."
-      ]
-     },
-     {
-      "data": {
-       "text/plain": [
-        "AIMessage(content=\"Okay, great! I'm looking forward to it.\")"
-       ]
-      },
-      "execution_count": 16,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
-    "source": [
-     "character_chat(messages)"
-    ]
    }
   ],
   "metadata": {
@@ -340,10 +298,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.9.18"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 4
 }

View File

@@ -1,45 +1,158 @@
-"""ZHIPU AI chat models wrapper."""
+"""ZhipuAI chat models wrapper."""
 from __future__ import annotations
 
-import asyncio
 import json
 import logging
-from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, cast
+import time
+from collections.abc import AsyncIterator, Iterator
+from contextlib import asynccontextmanager, contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
+    agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    BaseMessageChunk,
+    ChatMessage,
+    ChatMessageChunk,
+    HumanMessage,
+    HumanMessageChunk,
+    SystemMessage,
+    SystemMessageChunk,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
+API_TOKEN_TTL_SECONDS = 3 * 60
+ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
 
-class ref(BaseModel):
-    """Reference used in CharacterGLM."""
-
-    enable: bool = Field(True)
-    search_query: str = Field("")
-
-
-class meta(BaseModel):
-    """Metadata used in CharacterGLM."""
-
-    user_info: str = Field("")
-    bot_info: str = Field("")
-    bot_name: str = Field("")
-    user_name: str = Field("User")
+
+@contextmanager
+def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator:
+    from httpx_sse import EventSource
+
+    with client.stream(method, url, **kwargs) as response:
+        yield EventSource(response)
+
+
+@asynccontextmanager
+async def aconnect_sse(
+    client: Any, method: str, url: str, **kwargs: Any
+) -> AsyncIterator:
+    from httpx_sse import EventSource
+
+    async with client.stream(method, url, **kwargs) as response:
+        yield EventSource(response)
+
+
+def _get_jwt_token(api_key: str) -> str:
+    """Gets JWT token for ZhipuAI API, see 'https://open.bigmodel.cn/dev/api#nosdk'.
+
+    Args:
+        api_key: The API key for ZhipuAI API.
+
+    Returns:
+        The JWT token.
+    """
+    import jwt
+
+    try:
+        id, secret = api_key.split(".")
+    except ValueError as err:
+        raise ValueError(f"Invalid API key: {api_key}") from err
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )
+
+
+def _convert_dict_to_message(dct: Dict[str, Any]) -> BaseMessage:
+    role = dct.get("role")
+    content = dct.get("content", "")
+    if role == "system":
+        return SystemMessage(content=content)
+    if role == "user":
+        return HumanMessage(content=content)
+    if role == "assistant":
+        additional_kwargs = {}
+        tool_calls = dct.get("tool_calls", None)
+        if tool_calls is not None:
+            additional_kwargs["tool_calls"] = tool_calls
+        return AIMessage(content=content, additional_kwargs=additional_kwargs)
+    return ChatMessage(role=role, content=content)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
+    """Convert a LangChain message to a dictionary.
+
+    Args:
+        message: The LangChain message.
+
+    Returns:
+        The dictionary.
+    """
+    message_dict: Dict[str, Any]
+    if isinstance(message, ChatMessage):
+        message_dict = {"role": message.role, "content": message.content}
+    elif isinstance(message, SystemMessage):
+        message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, HumanMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {"role": "assistant", "content": message.content}
+    else:
+        raise TypeError(f"Got unknown type '{message.__class__.__name__}'.")
+    return message_dict
+
+
+def _convert_delta_to_message_chunk(
+    dct: Dict[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = dct.get("role")
+    content = dct.get("content", "")
+    additional_kwargs = {}
+    tool_calls = dct.get("tool_call", None)
+    if tool_calls is not None:
+        additional_kwargs["tool_calls"] = tool_calls
+
+    if role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    if role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+    if role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    return default_class(content=content)
 
 
 class ChatZhipuAI(BaseChatModel):
     """
-    `ZHIPU AI` large language chat models API.
+    `ZhipuAI` large language chat models API.
 
-    To use, you should have the ``zhipuai`` python package installed.
+    To use, you should have the ``PyJWT`` python package installed.
 
     Example:
     .. code-block:: python
@@ -49,98 +162,11 @@ class ChatZhipuAI(BaseChatModel):
         zhipuai_chat = ChatZhipuAI(
             temperature=0.5,
             api_key="your-api-key",
-            model="chatglm_turbo",
+            model="glm-4"
         )
 
     """
 
-    zhipuai: Any
-
-    zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
-    """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
-
-    model: str = Field("chatglm_turbo")
-    """
-    Model name to use.
-    -chatglm_turbo:
-        According to the input of natural language instructions to complete a
-        variety of language tasks, it is recommended to use SSE or asynchronous
-        call request interface.
-    -characterglm:
-        It supports human-based role-playing, ultra-long multi-round memory,
-        and thousands of character dialogues. It is widely used in anthropomorphic
-        dialogues or game scenes such as emotional accompaniments, game intelligent
-        NPCS, Internet celebrities/stars/movie and TV series IP clones, digital
-        people/virtual anchors, and text adventure games.
-    """
-
-    temperature: float = Field(0.95)
-    """
-    What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
-    be equal to 0.
-    The larger the value, the more random and creative the output; The smaller
-    the value, the more stable or certain the output will be.
-    You are advised to adjust top_p or temperature parameters based on application
-    scenarios, but do not adjust the two parameters at the same time.
-    """
-
-    top_p: float = Field(0.7)
-    """
-    Another method of sampling temperature is called nuclear sampling. The value
-    ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
-    The model considers the results with top_p probability quality tokens.
-    For example, 0.1 means that the model decoder only considers tokens from the
-    top 10% probability of the candidate set.
-    You are advised to adjust top_p or temperature parameters based on application
-    scenarios, but do not adjust the two parameters at the same time.
-    """
-
-    request_id: Optional[str] = Field(None)
-    """
-    Parameter transmission by the client must ensure uniqueness; A unique
-    identifier used to distinguish each request, which is generated by default
-    by the platform when the client does not transmit it.
-    """
-
-    streaming: bool = Field(False)
-    """Whether to stream the results or not."""
-
-    incremental: bool = Field(True)
-    """
-    When invoked by the SSE interface, it is used to control whether the content
-    is returned incremented or full each time.
-    If this parameter is not provided, the value is returned incremented by default.
-    """
-
-    return_type: str = Field("json_string")
-    """
-    This parameter is used to control the type of content returned each time.
-    - json_string Returns a standard JSON string.
-    - text Returns the original text content.
-    """
-
-    ref: Optional[ref] = Field(None)
-    """
-    This parameter is used to control the reference of external information
-    during the request.
-    Currently, this parameter is used to control whether to reference external
-    information.
-    If this field is empty or absent, the search and parameter passing format
-    is enabled by default.
-    {"enable": "true", "search_query": "history "}
-    """
-
-    meta: Optional[meta] = Field(None)
-    """Used in CharacterGLM"""
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        return {"model_name": self.model}
-
-    @property
-    def _llm_type(self) -> str:
-        """Return the type of chat model."""
-        return "zhipuai"
-
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}
@@ -154,93 +180,109 @@ class ChatZhipuAI(BaseChatModel):
     def lc_attributes(self) -> Dict[str, Any]:
         attributes: Dict[str, Any] = {}
 
-        if self.model:
-            attributes["model"] = self.model
-
-        if self.streaming:
-            attributes["streaming"] = self.streaming
-
-        if self.return_type:
-            attributes["return_type"] = self.return_type
+        if self.zhipuai_api_base:
+            attributes["zhipuai_api_base"] = self.zhipuai_api_base
 
         return attributes
 
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        try:
-            import zhipuai
-
-            self.zhipuai = zhipuai
-            self.zhipuai.api_key = self.zhipuai_api_key
-        except ImportError:
-            raise RuntimeError(
-                "Could not import zhipuai package. "
-                "Please install it via 'pip install zhipuai'"
-            )
-
-    def invoke(self, prompt: Any) -> Any:  # type: ignore[override]
-        if self.model == "chatglm_turbo":
-            return self.zhipuai.model_api.invoke(
-                model=self.model,
-                prompt=prompt,
-                top_p=self.top_p,
-                temperature=self.temperature,
-                request_id=self.request_id,
-                return_type=self.return_type,
-            )
-        elif self.model == "characterglm":
-            _meta = cast(meta, self.meta).dict()
-            return self.zhipuai.model_api.invoke(
-                model=self.model,
-                meta=_meta,
-                prompt=prompt,
-                request_id=self.request_id,
-                return_type=self.return_type,
-            )
-        return None
-
-    def sse_invoke(self, prompt: Any) -> Any:
-        if self.model == "chatglm_turbo":
-            return self.zhipuai.model_api.sse_invoke(
-                model=self.model,
-                prompt=prompt,
-                top_p=self.top_p,
-                temperature=self.temperature,
-                request_id=self.request_id,
-                return_type=self.return_type,
-                incremental=self.incremental,
-            )
-        elif self.model == "characterglm":
-            _meta = cast(meta, self.meta).dict()
-            return self.zhipuai.model_api.sse_invoke(
-                model=self.model,
-                prompt=prompt,
-                meta=_meta,
-                request_id=self.request_id,
-                return_type=self.return_type,
-                incremental=self.incremental,
-            )
-        return None
-
-    async def async_invoke(self, prompt: Any) -> Any:
-        loop = asyncio.get_running_loop()
-        partial_func = partial(
-            self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
-        )
-        response = await loop.run_in_executor(
-            None,
-            partial_func,
-        )
-        return response
-
-    async def async_invoke_result(self, task_id: Any) -> Any:
-        loop = asyncio.get_running_loop()
-        response = await loop.run_in_executor(
-            None,
-            self.zhipuai.model_api.query_async_invoke_result,
-            task_id,
-        )
-        return response
+    @property
+    def _llm_type(self) -> str:
+        """Return the type of chat model."""
+        return "zhipuai-chat"
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling the ZhipuAI API."""
+        params = {
+            "model": self.model_name,
+            "stream": self.streaming,
+            "temperature": self.temperature,
+        }
+        if self.max_tokens is not None:
+            params["max_tokens"] = self.max_tokens
+        return params
+
+    # client:
+    zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
+    """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
+    zhipuai_api_base: Optional[str] = Field(default=None, alias="api_base")
+    """Base URL path for API requests, leave blank if not using a proxy or service
+    emulator.
+    """
+
+    model_name: Optional[str] = Field(default="glm-4", alias="model")
+    """
+    Model name to use, see 'https://open.bigmodel.cn/dev/api#language',
+    or you can use any fine-tuned model of the GLM series.
+    """
+
+    temperature: float = 0.95
+    """
+    What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
+    be equal to 0.
+    The larger the value, the more random and creative the output; the smaller
+    the value, the more stable or certain the output will be.
+    You are advised to adjust top_p or temperature parameters based on application
+    scenarios, but do not adjust the two parameters at the same time.
+    """
+
+    top_p: float = 0.7
+    """
+    Another method of sampling temperature is called nucleus sampling. The value
+    ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
+    The model considers the results with top_p probability quality tokens.
+    For example, 0.1 means that the model decoder only considers tokens from the
+    top 10% probability of the candidate set.
+    You are advised to adjust top_p or temperature parameters based on application
+    scenarios, but do not adjust the two parameters at the same time.
+    """
+
+    streaming: bool = False
+    """Whether to stream the results or not."""
+
+    max_tokens: Optional[int] = None
+    """Maximum number of tokens to generate."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
+    @root_validator()
+    def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        values["zhipuai_api_key"] = get_from_dict_or_env(
+            values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
+        )
+        values["zhipuai_api_base"] = get_from_dict_or_env(
+            values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE
+        )
+        return values
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+        params = self._default_params
+        if stop is not None:
+            params["stop"] = stop
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
+        return message_dicts, params
+
+    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+        generations = []
+        if not isinstance(response, dict):
+            response = response.dict()
+        for res in response["choices"]:
+            message = _convert_dict_to_message(res["message"])
+            generation_info = dict(finish_reason=res.get("finish_reason"))
+            generations.append(
+                ChatGeneration(message=message, generation_info=generation_info)
+            )
+        token_usage = response.get("usage", {})
+        llm_output = {
+            "token_usage": token_usage,
+            "model_name": self.model_name,
+        }
+        return ChatResult(generations=generations, llm_output=llm_output)
 
     def _generate(
         self,
@@ -251,86 +293,163 @@ class ChatZhipuAI(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         """Generate a chat response."""
-        prompt: List = []
-        for message in messages:
-            if isinstance(message, AIMessage):
-                role = "assistant"
-            else:  # For both HumanMessage and SystemMessage, role is 'user'
-                role = "user"
-
-            prompt.append({"role": role, "content": message.content})
-
         should_stream = stream if stream is not None else self.streaming
-        if not should_stream:
-            response = self.invoke(prompt)
-            if response["code"] != 200:
-                raise RuntimeError(response)
-            content = response["data"]["choices"][0]["content"]
-            return ChatResult(
-                generations=[ChatGeneration(message=AIMessage(content=content))]
-            )
-        else:
+        if should_stream:
             stream_iter = self._stream(
-                prompt=prompt,
-                stop=stop,
-                run_manager=run_manager,
-                **kwargs,
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
             return generate_from_stream(stream_iter)
 
-    async def _agenerate(  # type: ignore[override]
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = False,
-        **kwargs: Any,
-    ) -> ChatResult:
-        """Asynchronously generate a chat response."""
-        prompt = []
-        for message in messages:
-            if isinstance(message, AIMessage):
-                role = "assistant"
-            else:  # For both HumanMessage and SystemMessage, role is 'user'
-                role = "user"
-
-            prompt.append({"role": role, "content": message.content})
-
-        invoke_response = await self.async_invoke(prompt)
-        task_id = invoke_response["data"]["task_id"]
-
-        response = await self.async_invoke_result(task_id)
-        while response["data"]["task_status"] != "SUCCESS":
-            await asyncio.sleep(1)
-            response = await self.async_invoke_result(task_id)
-
-        content = response["data"]["choices"][0]["content"]
-        content = json.loads(content)
-        return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=content))]
-        )
-
-    def _stream(  # type: ignore[override]
-        self,
-        prompt: List[Dict[str, str]],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        """Stream the chat response in chunks."""
-        response = self.sse_invoke(prompt)
-
-        for r in response.events():
-            if r.event == "add":
-                delta = r.data
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-                if run_manager:
-                    run_manager.on_llm_new_token(delta, chunk=chunk)
-                yield chunk
-
-            elif r.event == "error":
-                raise ValueError(f"Error from ZhipuAI API response: {r.data}")
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {
+            **params,
+            **kwargs,
+            "messages": message_dicts,
+            "stream": False,
+        }
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+        import httpx
+
+        with httpx.Client(headers=headers) as client:
+            response = client.post(self.zhipuai_api_base, json=payload)
+            response.raise_for_status()
+        return self._create_chat_result(response.json())
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        """Stream the chat response in chunks."""
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        if self.zhipuai_api_base is None:
+            raise ValueError("Did not find zhipuai_api_base.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+        default_chunk_class = AIMessageChunk
+        import httpx
+
+        with httpx.Client(headers=headers) as client:
+            with connect_sse(
+                client, "POST", self.zhipuai_api_base, json=payload
+            ) as event_source:
+                for sse in event_source.iter_sse():
+                    chunk = json.loads(sse.data)
+                    if len(chunk["choices"]) == 0:
+                        continue
+                    choice = chunk["choices"][0]
+                    chunk = _convert_delta_to_message_chunk(
+                        choice["delta"], default_chunk_class
+                    )
+                    finish_reason = choice.get("finish_reason", None)
+                    generation_info = (
+                        {"finish_reason": finish_reason}
+                        if finish_reason is not None
+                        else None
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=chunk, generation_info=generation_info
+                    )
+                    yield chunk
+                    if run_manager:
+                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    if finish_reason is not None:
+                        break
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {
+            **params,
+            **kwargs,
+            "messages": message_dicts,
+            "stream": False,
+        }
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+        import httpx
+
+        async with httpx.AsyncClient(headers=headers) as client:
+            response = await client.post(self.zhipuai_api_base, json=payload)
+            response.raise_for_status()
+        return self._create_chat_result(response.json())
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        if self.zhipuai_api_base is None:
+            raise ValueError("Did not find zhipuai_api_base.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+        default_chunk_class = AIMessageChunk
+        import httpx
+
+        async with httpx.AsyncClient(headers=headers) as client:
+            async with aconnect_sse(
+                client, "POST", self.zhipuai_api_base, json=payload
+            ) as event_source:
+                async for sse in event_source.aiter_sse():
+                    chunk = json.loads(sse.data)
+                    if len(chunk["choices"]) == 0:
+                        continue
+                    choice = chunk["choices"][0]
+                    chunk = _convert_delta_to_message_chunk(
+                        choice["delta"], default_chunk_class
+                    )
+                    finish_reason = choice.get("finish_reason", None)
+                    generation_info = (
+                        {"finish_reason": finish_reason}
+                        if finish_reason is not None
+                        else None
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=chunk, generation_info=generation_info
+                    )
+                    yield chunk
+                    if run_manager:
+                        await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    if finish_reason is not None:
+                        break

View File

@@ -1407,17 +1407,6 @@ mlflow-skinny = ">=2.4.0,<3"
 protobuf = ">=3.12.0,<5"
 requests = ">=2"
 
-[[package]]
-name = "dataclasses"
-version = "0.6"
-description = "A backport of the dataclasses module for Python 3.6"
-optional = true
-python-versions = "*"
-files = [
-    {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"},
-    {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"},
-]
-
 [[package]]
 name = "dataclasses-json"
 version = "0.6.4"
@@ -9229,23 +9218,6 @@ files = [
 idna = ">=2.0"
 multidict = ">=4.0"
 
-[[package]]
-name = "zhipuai"
-version = "1.0.7"
-description = "A SDK library for accessing big model apis from ZhipuAI"
-optional = true
-python-versions = ">=3.6"
-files = [
-    {file = "zhipuai-1.0.7-py3-none-any.whl", hash = "sha256:360c01b8c2698f366061452e86d5a36a5ff68a576ea33940da98e4806f232530"},
-    {file = "zhipuai-1.0.7.tar.gz", hash = "sha256:b80f699543d83cce8648acf1ce32bc2725d1c1c443baffa5882abc2cc704d581"},
-]
-
-[package.dependencies]
-cachetools = "*"
-dataclasses = "*"
-PyJWT = "*"
-requests = "*"
-
 [[package]]
 name = "zipp"
 version = "3.17.0"
@@ -9263,9 +9235,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict", "zhipuai"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "45da04abac45743972d1edf62d08d9abaa2bebb473b794e0a0d6f1fdc87773f9"
+content-hash = "67c38c029bb59d45fd0f84a5d48c44f64f1301d6be07f419615d08ba8671a2a7"

View File

@@ -88,7 +88,6 @@ tree-sitter = {version = "^0.20.2", optional = true}
 tree-sitter-languages = {version = "^1.8.0", optional = true}
 azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true}
 oracle-ads = {version = "^2.9.1", optional = true}
-zhipuai = {version = "^1.0.7", optional = true}
 httpx = {version = "^0.24.1", optional = true}
 elasticsearch = {version = "^8.12.0", optional = true}
 hdbcli = {version = "^2.19.21", optional = true}
@@ -99,6 +98,8 @@ tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true}
 friendli-client = {version = "^1.2.4", optional = true}
 premai = {version = "^0.3.25", optional = true}
 vdms = {version = "^0.0.20", optional = true}
+httpx-sse = {version = "^0.4.0", optional = true}
+pyjwt = {version = "^2.8.0", optional = true}
 
 [tool.poetry.group.test]
 optional = true
@@ -262,7 +263,6 @@ extended_testing = [
     "tree-sitter-languages",
     "azure-ai-documentintelligence",
     "oracle-ads",
-    "zhipuai",
     "httpx",
     "elasticsearch",
     "hdbcli",
@@ -272,7 +272,9 @@ extended_testing = [
     "cloudpickle",
     "friendli-client",
     "premai",
-    "vdms"
+    "vdms",
+    "httpx-sse",
+    "pyjwt"
 ]
 
 [tool.ruff]

View File

@@ -18,7 +18,7 @@ def test_default_call() -> None:
 
 def test_model() -> None:
     """Test model kwarg works."""
-    chat = ChatZhipuAI(model="chatglm_turbo")
+    chat = ChatZhipuAI(model="glm-4")
     response = chat(messages=[HumanMessage(content="Hello")])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)

View File

@ -1,10 +1,13 @@
"""Test ZhipuAI Chat API wrapper"""
import pytest import pytest
from langchain_community.chat_models.zhipuai import ChatZhipuAI from langchain_community.chat_models.zhipuai import ChatZhipuAI
@pytest.mark.requires("zhipuai") @pytest.mark.requires("httpx", "httpx_sse", "jwt")
def test_integration_initialization() -> None: def test_zhipuai_model_param() -> None:
chat = ChatZhipuAI(model="chatglm_turbo", streaming=False) llm = ChatZhipuAI(api_key="test", model="foo")
assert chat.model == "chatglm_turbo" assert llm.model_name == "foo"
assert chat.streaming is False llm = ChatZhipuAI(api_key="test", model_name="foo")
assert llm.model_name == "foo"