Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-06 05:08:20 +00:00
community[minor]: Update ChatZhipuAI to support GLM-4 model (#16695)
Description: Update `ChatZhipuAI` to support the latest `glm-4` model. Issue: N/A. Dependencies: httpx, httpx-sse, PyJWT.

The previous `ChatZhipuAI` implementation required the `zhipuai` package and could not call the latest GLM models, because:
- the old `zhipuai==1.*` release does not support the latest models, and
- `zhipuai==2.*` requires Pydantic v2, which is incompatible with `langchain-community`.

This re-implementation invokes the GLM models by sending HTTP requests to [open.bigmodel.cn](https://open.bigmodel.cn/dev/api) via the `httpx` package, and uses the `httpx-sse` package to handle stream events.

---------

Co-authored-by: zR <2448370773@qq.com>
This commit is contained in: parent d25b5b6f25, commit a1f3e9f537
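For orientation before the diff: a minimal usage sketch of the updated wrapper (illustrative, not part of the commit; it assumes this version of `langchain-community` is installed and `ZHIPUAI_API_KEY` is set, mirroring the notebook cells changed below).

```python
# Minimal usage sketch; mirrors the updated notebook cells in this commit.
# Assumes langchain-community with this change installed and ZHIPUAI_API_KEY set.
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage

chat = ChatZhipuAI(
    model="glm-4",
    temperature=0.5,
)  # the key may also be passed explicitly via the `api_key` alias
print(chat([HumanMessage(content="Hello")]).content)
```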
@@ -32,7 +32,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"%pip install --upgrade --quiet zhipuai"
+"%pip install --quiet httpx[socks]==0.24.1 httpx-sse PyJWT"
 ]
 },
 {
@@ -85,9 +85,9 @@
 "outputs": [],
 "source": [
 "chat = ChatZhipuAI(\n",
-"    temperature=0.5,\n",
 "    api_key=zhipuai_api_key,\n",
-"    model=\"chatglm_turbo\",\n",
+"    model=\"glm-4\",\n",
+"    temperature=0.5,\n",
 ")"
 ]
 },
@@ -158,9 +158,9 @@
 "outputs": [],
 "source": [
 "streaming_chat = ChatZhipuAI(\n",
-"    temperature=0.5,\n",
 "    api_key=zhipuai_api_key,\n",
-"    model=\"chatglm_turbo\",\n",
+"    model=\"glm-4\",\n",
+"    temperature=0.5,\n",
 "    streaming=True,\n",
 "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
 ")"
@@ -211,9 +211,9 @@
 "outputs": [],
 "source": [
 "async_chat = ChatZhipuAI(\n",
-"    temperature=0.5,\n",
 "    api_key=zhipuai_api_key,\n",
-"    model=\"chatglm_turbo\",\n",
+"    model=\"glm-4\",\n",
+"    temperature=0.5,\n",
 ")"
 ]
 },
@@ -280,48 +280,6 @@
 "    ),\n",
 "]"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": 15,
-"metadata": {},
-"outputs": [],
-"source": [
-"character_chat = ChatZhipuAI(\n",
-"    api_key=zhipuai_api_key,\n",
-"    meta=meta,\n",
-"    model=\"characterglm\",\n",
-"    streaming=True,\n",
-"    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
-")"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 16,
-"metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Okay, great! I'm looking forward to it."
-]
-},
-{
-"data": {
-"text/plain": [
-"AIMessage(content=\"Okay, great! I'm looking forward to it.\")"
-]
-},
-"execution_count": 16,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
-"source": [
-"character_chat(messages)"
-]
 }
 ],
 "metadata": {
@@ -340,10 +298,9 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.4"
+"version": "3.9.18"
 }
 },
 "nbformat": 4,
 "nbformat_minor": 4
 }
-
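The async cell above only swaps the model name; for completeness, a sketch of exercising the async path end to end (illustrative, same assumptions as the sketch above; the `main` wrapper is just scaffolding for `asyncio.run`):

```python
# Async usage sketch; mirrors the notebook's async_chat cell.
import asyncio

from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage


async def main() -> None:
    async_chat = ChatZhipuAI(model="glm-4", temperature=0.5)
    # agenerate takes a batch of message lists and returns an LLMResult.
    result = await async_chat.agenerate([[HumanMessage(content="Hello")]])
    print(result.generations[0][0].text)


asyncio.run(main())
```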
@@ -1,45 +1,158 @@
-"""ZHIPU AI chat models wrapper."""
+"""ZhipuAI chat models wrapper."""
 
 from __future__ import annotations
 
-import asyncio
 import json
 import logging
-from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, cast
+import time
+from collections.abc import AsyncIterator, Iterator
+from contextlib import asynccontextmanager, contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
+    agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    BaseMessageChunk,
+    ChatMessage,
+    ChatMessageChunk,
+    HumanMessage,
+    HumanMessageChunk,
+    SystemMessage,
+    SystemMessageChunk,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
+API_TOKEN_TTL_SECONDS = 3 * 60
+ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
 
-class ref(BaseModel):
-    """Reference used in CharacterGLM."""
-
-    enable: bool = Field(True)
-    search_query: str = Field("")
-
-
-class meta(BaseModel):
-    """Metadata used in CharacterGLM."""
-
-    user_info: str = Field("")
-    bot_info: str = Field("")
-    bot_name: str = Field("")
-    user_name: str = Field("User")
+
+@contextmanager
+def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator:
+    from httpx_sse import EventSource
+
+    with client.stream(method, url, **kwargs) as response:
+        yield EventSource(response)
+
+
+@asynccontextmanager
+async def aconnect_sse(
+    client: Any, method: str, url: str, **kwargs: Any
+) -> AsyncIterator:
+    from httpx_sse import EventSource
+
+    async with client.stream(method, url, **kwargs) as response:
+        yield EventSource(response)
+
+
+def _get_jwt_token(api_key: str) -> str:
+    """Gets JWT token for ZhipuAI API, see 'https://open.bigmodel.cn/dev/api#nosdk'.
+
+    Args:
+        api_key: The API key for ZhipuAI API.
+
+    Returns:
+        The JWT token.
+    """
+    import jwt
+
+    try:
+        id, secret = api_key.split(".")
+    except ValueError as err:
+        raise ValueError(f"Invalid API key: {api_key}") from err
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )
+
+
+def _convert_dict_to_message(dct: Dict[str, Any]) -> BaseMessage:
+    role = dct.get("role")
+    content = dct.get("content", "")
+    if role == "system":
+        return SystemMessage(content=content)
+    if role == "user":
+        return HumanMessage(content=content)
+    if role == "assistant":
+        additional_kwargs = {}
+        tool_calls = dct.get("tool_calls", None)
+        if tool_calls is not None:
+            additional_kwargs["tool_calls"] = tool_calls
+        return AIMessage(content=content, additional_kwargs=additional_kwargs)
+    return ChatMessage(role=role, content=content)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
+    """Convert a LangChain message to a dictionary.
+
+    Args:
+        message: The LangChain message.
+
+    Returns:
+        The dictionary.
+    """
+    message_dict: Dict[str, Any]
+    if isinstance(message, ChatMessage):
+        message_dict = {"role": message.role, "content": message.content}
+    elif isinstance(message, SystemMessage):
+        message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, HumanMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {"role": "assistant", "content": message.content}
+    else:
+        raise TypeError(f"Got unknown type '{message.__class__.__name__}'.")
+    return message_dict
+
+
+def _convert_delta_to_message_chunk(
+    dct: Dict[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = dct.get("role")
+    content = dct.get("content", "")
+    additional_kwargs = {}
+    tool_calls = dct.get("tool_call", None)
+    if tool_calls is not None:
+        additional_kwargs["tool_calls"] = tool_calls
+
+    if role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    if role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+    if role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    return default_class(content=content)
 
 
 class ChatZhipuAI(BaseChatModel):
     """
-    `ZHIPU AI` large language chat models API.
+    `ZhipuAI` large language chat models API.
 
-    To use, you should have the ``zhipuai`` python package installed.
+    To use, you should have the ``PyJWT`` python package installed.
 
     Example:
     .. code-block:: python
@@ -49,98 +162,11 @@ class ChatZhipuAI(BaseChatModel):
 
         zhipuai_chat = ChatZhipuAI(
             temperature=0.5,
             api_key="your-api-key",
-            model="chatglm_turbo",
+            model="glm-4"
         )
 
     """
 
-    zhipuai: Any
-    zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
-    """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
-
-    model: str = Field("chatglm_turbo")
-    """
-    Model name to use.
-    -chatglm_turbo:
-        According to the input of natural language instructions to complete a
-        variety of language tasks, it is recommended to use SSE or asynchronous
-        call request interface.
-    -characterglm:
-        It supports human-based role-playing, ultra-long multi-round memory,
-        and thousands of character dialogues. It is widely used in anthropomorphic
-        dialogues or game scenes such as emotional accompaniments, game intelligent
-        NPCS, Internet celebrities/stars/movie and TV series IP clones, digital
-        people/virtual anchors, and text adventure games.
-    """
-
-    temperature: float = Field(0.95)
-    """
-    What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
-    be equal to 0.
-    The larger the value, the more random and creative the output; The smaller
-    the value, the more stable or certain the output will be.
-    You are advised to adjust top_p or temperature parameters based on application
-    scenarios, but do not adjust the two parameters at the same time.
-    """
-
-    top_p: float = Field(0.7)
-    """
-    Another method of sampling temperature is called nuclear sampling. The value
-    ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
-    The model considers the results with top_p probability quality tokens.
-    For example, 0.1 means that the model decoder only considers tokens from the
-    top 10% probability of the candidate set.
-    You are advised to adjust top_p or temperature parameters based on application
-    scenarios, but do not adjust the two parameters at the same time.
-    """
-
-    request_id: Optional[str] = Field(None)
-    """
-    Parameter transmission by the client must ensure uniqueness; A unique
-    identifier used to distinguish each request, which is generated by default
-    by the platform when the client does not transmit it.
-    """
-
-    streaming: bool = Field(False)
-    """Whether to stream the results or not."""
-
-    incremental: bool = Field(True)
-    """
-    When invoked by the SSE interface, it is used to control whether the content
-    is returned incremented or full each time.
-    If this parameter is not provided, the value is returned incremented by default.
-    """
-
-    return_type: str = Field("json_string")
-    """
-    This parameter is used to control the type of content returned each time.
-    - json_string Returns a standard JSON string.
-    - text Returns the original text content.
-    """
-
-    ref: Optional[ref] = Field(None)
-    """
-    This parameter is used to control the reference of external information
-    during the request.
-    Currently, this parameter is used to control whether to reference external
-    information.
-    If this field is empty or absent, the search and parameter passing format
-    is enabled by default.
-    {"enable": "true", "search_query": "history "}
-    """
-
-    meta: Optional[meta] = Field(None)
-    """Used in CharacterGLM"""
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        return {"model_name": self.model}
-
-    @property
-    def _llm_type(self) -> str:
-        """Return the type of chat model."""
-        return "zhipuai"
-
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}
@@ -154,93 +180,109 @@ class ChatZhipuAI(BaseChatModel):
     def lc_attributes(self) -> Dict[str, Any]:
         attributes: Dict[str, Any] = {}
 
-        if self.model:
-            attributes["model"] = self.model
-
-        if self.streaming:
-            attributes["streaming"] = self.streaming
-
-        if self.return_type:
-            attributes["return_type"] = self.return_type
+        if self.zhipuai_api_base:
+            attributes["zhipuai_api_base"] = self.zhipuai_api_base
 
         return attributes
 
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        try:
-            import zhipuai
-
-            self.zhipuai = zhipuai
-            self.zhipuai.api_key = self.zhipuai_api_key
-        except ImportError:
-            raise RuntimeError(
-                "Could not import zhipuai package. "
-                "Please install it via 'pip install zhipuai'"
-            )
-
-    def invoke(self, prompt: Any) -> Any:  # type: ignore[override]
-        if self.model == "chatglm_turbo":
-            return self.zhipuai.model_api.invoke(
-                model=self.model,
-                prompt=prompt,
-                top_p=self.top_p,
-                temperature=self.temperature,
-                request_id=self.request_id,
-                return_type=self.return_type,
-            )
-        elif self.model == "characterglm":
-            _meta = cast(meta, self.meta).dict()
-            return self.zhipuai.model_api.invoke(
-                model=self.model,
-                meta=_meta,
-                prompt=prompt,
-                request_id=self.request_id,
-                return_type=self.return_type,
-            )
-        return None
-
-    def sse_invoke(self, prompt: Any) -> Any:
-        if self.model == "chatglm_turbo":
-            return self.zhipuai.model_api.sse_invoke(
-                model=self.model,
-                prompt=prompt,
-                top_p=self.top_p,
-                temperature=self.temperature,
-                request_id=self.request_id,
-                return_type=self.return_type,
-                incremental=self.incremental,
-            )
-        elif self.model == "characterglm":
-            _meta = cast(meta, self.meta).dict()
-            return self.zhipuai.model_api.sse_invoke(
-                model=self.model,
-                prompt=prompt,
-                meta=_meta,
-                request_id=self.request_id,
-                return_type=self.return_type,
-                incremental=self.incremental,
-            )
-        return None
-
-    async def async_invoke(self, prompt: Any) -> Any:
-        loop = asyncio.get_running_loop()
-        partial_func = partial(
-            self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
-        )
-        response = await loop.run_in_executor(
-            None,
-            partial_func,
-        )
-        return response
-
-    async def async_invoke_result(self, task_id: Any) -> Any:
-        loop = asyncio.get_running_loop()
-        response = await loop.run_in_executor(
-            None,
-            self.zhipuai.model_api.query_async_invoke_result,
-            task_id,
-        )
-        return response
+    @property
+    def _llm_type(self) -> str:
+        """Return the type of chat model."""
+        return "zhipuai-chat"
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        params = {
+            "model": self.model_name,
+            "stream": self.streaming,
+            "temperature": self.temperature,
+        }
+        if self.max_tokens is not None:
+            params["max_tokens"] = self.max_tokens
+        return params
+
+    # client:
+    zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
+    """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
+    zhipuai_api_base: Optional[str] = Field(default=None, alias="api_base")
+    """Base URL path for API requests, leave blank if not using a proxy or service
+    emulator.
+    """
+
+    model_name: Optional[str] = Field(default="glm-4", alias="model")
+    """
+    Model name to use, see 'https://open.bigmodel.cn/dev/api#language'.
+    or you can use any finetune model of glm series.
+    """
+
+    temperature: float = 0.95
+    """
+    What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
+    be equal to 0.
+    The larger the value, the more random and creative the output; The smaller
+    the value, the more stable or certain the output will be.
+    You are advised to adjust top_p or temperature parameters based on application
+    scenarios, but do not adjust the two parameters at the same time.
+    """
+
+    top_p: float = 0.7
+    """
+    Another method of sampling temperature is called nuclear sampling. The value
+    ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
+    The model considers the results with top_p probability quality tokens.
+    For example, 0.1 means that the model decoder only considers tokens from the
+    top 10% probability of the candidate set.
+    You are advised to adjust top_p or temperature parameters based on application
+    scenarios, but do not adjust the two parameters at the same time.
+    """
+
+    streaming: bool = False
+    """Whether to stream the results or not."""
+    max_tokens: Optional[int] = None
+    """Maximum number of tokens to generate."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
+    @root_validator()
+    def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        values["zhipuai_api_key"] = get_from_dict_or_env(
+            values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
+        )
+        values["zhipuai_api_base"] = get_from_dict_or_env(
+            values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE
+        )
+        return values
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+        params = self._default_params
+        if stop is not None:
+            params["stop"] = stop
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
+        return message_dicts, params
+
+    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+        generations = []
+        if not isinstance(response, dict):
+            response = response.dict()
+        for res in response["choices"]:
+            message = _convert_dict_to_message(res["message"])
+            generation_info = dict(finish_reason=res.get("finish_reason"))
+            generations.append(
+                ChatGeneration(message=message, generation_info=generation_info)
+            )
+        token_usage = response.get("usage", {})
+        llm_output = {
+            "token_usage": token_usage,
+            "model_name": self.model_name,
+        }
+        return ChatResult(generations=generations, llm_output=llm_output)
 
     def _generate(
         self,
@@ -251,86 +293,163 @@ class ChatZhipuAI(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         """Generate a chat response."""
-        prompt: List = []
-        for message in messages:
-            if isinstance(message, AIMessage):
-                role = "assistant"
-            else:  # For both HumanMessage and SystemMessage, role is 'user'
-                role = "user"
-
-            prompt.append({"role": role, "content": message.content})
-
         should_stream = stream if stream is not None else self.streaming
-        if not should_stream:
-            response = self.invoke(prompt)
-
-            if response["code"] != 200:
-                raise RuntimeError(response)
-
-            content = response["data"]["choices"][0]["content"]
-            return ChatResult(
-                generations=[ChatGeneration(message=AIMessage(content=content))]
-            )
-
-        else:
+        if should_stream:
             stream_iter = self._stream(
-                prompt=prompt,
-                stop=stop,
-                run_manager=run_manager,
-                **kwargs,
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
             return generate_from_stream(stream_iter)
 
-    async def _agenerate(  # type: ignore[override]
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {
+            **params,
+            **kwargs,
+            "messages": message_dicts,
+            "stream": False,
+        }
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+        import httpx
+
+        with httpx.Client(headers=headers) as client:
+            response = client.post(self.zhipuai_api_base, json=payload)
+            response.raise_for_status()
+        return self._create_chat_result(response.json())
+
+    def _stream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = False,
-        **kwargs: Any,
-    ) -> ChatResult:
-        """Asynchronously generate a chat response."""
-
-        prompt = []
-        for message in messages:
-            if isinstance(message, AIMessage):
-                role = "assistant"
-            else:  # For both HumanMessage and SystemMessage, role is 'user'
-                role = "user"
-
-            prompt.append({"role": role, "content": message.content})
-
-        invoke_response = await self.async_invoke(prompt)
-        task_id = invoke_response["data"]["task_id"]
-
-        response = await self.async_invoke_result(task_id)
-        while response["data"]["task_status"] != "SUCCESS":
-            await asyncio.sleep(1)
-            response = await self.async_invoke_result(task_id)
-
-        content = response["data"]["choices"][0]["content"]
-        content = json.loads(content)
-        return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=content))]
-        )
-
-    def _stream(  # type: ignore[override]
-        self,
-        prompt: List[Dict[str, str]],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         """Stream the chat response in chunks."""
-        response = self.sse_invoke(prompt)
-
-        for r in response.events():
-            if r.event == "add":
-                delta = r.data
-                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-                if run_manager:
-                    run_manager.on_llm_new_token(delta, chunk=chunk)
-                yield chunk
-
-            elif r.event == "error":
-                raise ValueError(f"Error from ZhipuAI API response: {r.data}")
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        if self.zhipuai_api_base is None:
+            raise ValueError("Did not find zhipu_api_base.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+
+        default_chunk_class = AIMessageChunk
+        import httpx
+
+        with httpx.Client(headers=headers) as client:
+            with connect_sse(
+                client, "POST", self.zhipuai_api_base, json=payload
+            ) as event_source:
+                for sse in event_source.iter_sse():
+                    chunk = json.loads(sse.data)
+                    if len(chunk["choices"]) == 0:
+                        continue
+                    choice = chunk["choices"][0]
+                    chunk = _convert_delta_to_message_chunk(
+                        choice["delta"], default_chunk_class
+                    )
+                    finish_reason = choice.get("finish_reason", None)
+
+                    generation_info = (
+                        {"finish_reason": finish_reason}
+                        if finish_reason is not None
+                        else None
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=chunk, generation_info=generation_info
+                    )
+                    yield chunk
+                    if run_manager:
+                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    if finish_reason is not None:
+                        break
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {
+            **params,
+            **kwargs,
+            "messages": message_dicts,
+            "stream": False,
+        }
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+        import httpx
+
+        async with httpx.AsyncClient(headers=headers) as client:
+            response = await client.post(self.zhipuai_api_base, json=payload)
+            response.raise_for_status()
+        return self._create_chat_result(response.json())
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        if self.zhipuai_api_key is None:
+            raise ValueError("Did not find zhipuai_api_key.")
+        if self.zhipuai_api_base is None:
+            raise ValueError("Did not find zhipu_api_base.")
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        headers = {
+            "Authorization": _get_jwt_token(self.zhipuai_api_key),
+            "Accept": "application/json",
+        }
+
+        default_chunk_class = AIMessageChunk
+        import httpx
+
+        async with httpx.AsyncClient(headers=headers) as client:
+            async with aconnect_sse(
+                client, "POST", self.zhipuai_api_base, json=payload
+            ) as event_source:
+                async for sse in event_source.aiter_sse():
+                    chunk = json.loads(sse.data)
+                    if len(chunk["choices"]) == 0:
+                        continue
+                    choice = chunk["choices"][0]
+                    chunk = _convert_delta_to_message_chunk(
+                        choice["delta"], default_chunk_class
+                    )
+                    finish_reason = choice.get("finish_reason", None)
+
+                    generation_info = (
+                        {"finish_reason": finish_reason}
+                        if finish_reason is not None
+                        else None
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=chunk, generation_info=generation_info
+                    )
+                    yield chunk
+                    if run_manager:
+                        await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    if finish_reason is not None:
+                        break
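Taken together, the helpers above form a small request flow. The following standalone sketch shows how the JWT `Authorization` header is built and a streaming completion consumed, mirroring `_get_jwt_token` and `connect_sse` from the diff; the helper name `make_token` and the script framing are illustrative, not part of the commit.

```python
# Standalone sketch of the request flow used by the new ChatZhipuAI.
# Assumes a ZhipuAI key of the form "<id>.<secret>" in ZHIPUAI_API_KEY.
import json
import os
import time

import httpx
import jwt  # PyJWT
from httpx_sse import EventSource

API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions"


def make_token(api_key: str, ttl_seconds: int = 180) -> str:
    # The key splits into an id and a signing secret; the token is short-lived.
    key_id, secret = api_key.split(".")
    now_ms = int(round(time.time() * 1000))
    payload = {"api_key": key_id, "exp": now_ms + ttl_seconds * 1000, "timestamp": now_ms}
    # ZhipuAI expects a custom "sign_type" header alongside HS256.
    return jwt.encode(
        payload, secret, algorithm="HS256", headers={"alg": "HS256", "sign_type": "SIGN"}
    )


headers = {
    "Authorization": make_token(os.environ["ZHIPUAI_API_KEY"]),
    "Accept": "application/json",
}
payload = {
    "model": "glm-4",
    "stream": True,
    "messages": [{"role": "user", "content": "Hello"}],
}

with httpx.Client(headers=headers) as client:
    with client.stream("POST", API_BASE, json=payload) as response:
        for sse in EventSource(response).iter_sse():
            chunk = json.loads(sse.data)
            if not chunk["choices"]:
                continue
            choice = chunk["choices"][0]
            print(choice["delta"].get("content", ""), end="", flush=True)
            if choice.get("finish_reason") is not None:
                break
```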
libs/community/poetry.lock (generated)
@@ -1407,17 +1407,6 @@ mlflow-skinny = ">=2.4.0,<3"
 protobuf = ">=3.12.0,<5"
 requests = ">=2"
 
-[[package]]
-name = "dataclasses"
-version = "0.6"
-description = "A backport of the dataclasses module for Python 3.6"
-optional = true
-python-versions = "*"
-files = [
-    {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"},
-    {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"},
-]
-
 [[package]]
 name = "dataclasses-json"
 version = "0.6.4"
@@ -9229,23 +9218,6 @@ files = [
 idna = ">=2.0"
 multidict = ">=4.0"
 
-[[package]]
-name = "zhipuai"
-version = "1.0.7"
-description = "A SDK library for accessing big model apis from ZhipuAI"
-optional = true
-python-versions = ">=3.6"
-files = [
-    {file = "zhipuai-1.0.7-py3-none-any.whl", hash = "sha256:360c01b8c2698f366061452e86d5a36a5ff68a576ea33940da98e4806f232530"},
-    {file = "zhipuai-1.0.7.tar.gz", hash = "sha256:b80f699543d83cce8648acf1ce32bc2725d1c1c443baffa5882abc2cc704d581"},
-]
-
-[package.dependencies]
-cachetools = "*"
-dataclasses = "*"
-PyJWT = "*"
-requests = "*"
-
 [[package]]
 name = "zipp"
 version = "3.17.0"
@@ -9263,9 +9235,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict", "zhipuai"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "45da04abac45743972d1edf62d08d9abaa2bebb473b794e0a0d6f1fdc87773f9"
+content-hash = "67c38c029bb59d45fd0f84a5d48c44f64f1301d6be07f419615d08ba8671a2a7"
@@ -88,7 +88,6 @@ tree-sitter = {version = "^0.20.2", optional = true}
 tree-sitter-languages = {version = "^1.8.0", optional = true}
 azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true}
 oracle-ads = {version = "^2.9.1", optional = true}
-zhipuai = {version = "^1.0.7", optional = true}
 httpx = {version = "^0.24.1", optional = true}
 elasticsearch = {version = "^8.12.0", optional = true}
 hdbcli = {version = "^2.19.21", optional = true}
@@ -99,6 +98,8 @@ tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true}
 friendli-client = {version = "^1.2.4", optional = true}
 premai = {version = "^0.3.25", optional = true}
 vdms = {version = "^0.0.20", optional = true}
+httpx-sse = {version = "^0.4.0", optional = true}
+pyjwt = {version = "^2.8.0", optional = true}
 
 [tool.poetry.group.test]
 optional = true
@@ -262,7 +263,6 @@ extended_testing = [
     "tree-sitter-languages",
     "azure-ai-documentintelligence",
     "oracle-ads",
-    "zhipuai",
     "httpx",
     "elasticsearch",
     "hdbcli",
@@ -272,7 +272,9 @@ extended_testing = [
     "cloudpickle",
     "friendli-client",
     "premai",
-    "vdms"
+    "vdms",
+    "httpx-sse",
+    "pyjwt"
 ]
 
 [tool.ruff]
@@ -18,7 +18,7 @@ def test_default_call() -> None:
 
 def test_model() -> None:
     """Test model kwarg works."""
-    chat = ChatZhipuAI(model="chatglm_turbo")
+    chat = ChatZhipuAI(model="glm-4")
     response = chat(messages=[HumanMessage(content="Hello")])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
@@ -1,10 +1,13 @@
+"""Test ZhipuAI Chat API wrapper"""
+
 import pytest
 
 from langchain_community.chat_models.zhipuai import ChatZhipuAI
 
 
-@pytest.mark.requires("zhipuai")
-def test_integration_initialization() -> None:
-    chat = ChatZhipuAI(model="chatglm_turbo", streaming=False)
-    assert chat.model == "chatglm_turbo"
-    assert chat.streaming is False
+@pytest.mark.requires("httpx", "httpx_sse", "jwt")
+def test_zhipuai_model_param() -> None:
+    llm = ChatZhipuAI(api_key="test", model="foo")
+    assert llm.model_name == "foo"
+    llm = ChatZhipuAI(api_key="test", model_name="foo")
+    assert llm.model_name == "foo"