community[minor]: Update ChatZhipuAI to support GLM-4 model (#16695)

Description: Update `ChatZhipuAI` to support the latest `glm-4` model.
Issue: N/A
Dependencies: httpx, httpx-sse, PyJWT

The previous `ChatZhipuAI` implementation requires the `zhipuai`
package, and cannot call the latest GLM model. This is because
- The old version `zhipuai==1.*` doesn't support the latest model.
- `zhipuai==2.*` requires `pydantic V2`, which is incompatible with
'langchain-community'.

This re-implementation invokes the GLM model by sending HTTP requests to
[open.bigmodel.cn](https://open.bigmodel.cn/dev/api) via the `httpx`
package, and uses the `httpx-sse` package to handle stream events.

---------

Co-authored-by: zR <2448370773@qq.com>
This commit is contained in:
Chenhui Zhang 2024-04-02 02:11:21 +08:00 committed by GitHub
parent d25b5b6f25
commit a1f3e9f537
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 694 additions and 641 deletions

View File

@ -1,349 +1,306 @@
{ {
"cells": [ "cells": [
{ {
"cell_type": "raw", "cell_type": "raw",
"metadata": {}, "metadata": {},
"source": [ "source": [
"---\n", "---\n",
"sidebar_label: ZHIPU AI\n", "sidebar_label: ZHIPU AI\n",
"---" "---"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# ZHIPU AI\n", "# ZHIPU AI\n",
"\n", "\n",
"This notebook shows how to use [ZHIPU AI API](https://open.bigmodel.cn/dev/api) in LangChain with the langchain.chat_models.ChatZhipuAI.\n", "This notebook shows how to use [ZHIPU AI API](https://open.bigmodel.cn/dev/api) in LangChain with the langchain.chat_models.ChatZhipuAI.\n",
"\n", "\n",
">[*ZHIPU AI*](https://open.bigmodel.cn/) is a multi-lingual large language model aligned with human intent, featuring capabilities in Q&A, multi-turn dialogue, and code generation, developed on the foundation of the ChatGLM3. \n", ">[*ZHIPU AI*](https://open.bigmodel.cn/) is a multi-lingual large language model aligned with human intent, featuring capabilities in Q&A, multi-turn dialogue, and code generation, developed on the foundation of the ChatGLM3. \n",
"\n", "\n",
">It's co-developed with Tsinghua University's KEG Laboratory under the ChatGLM3 project, signifying a new era in dialogue pre-training models. The open-source [ChatGLM3](https://github.com/THUDM/ChatGLM3) variant boasts a robust foundation, comprehensive functional support, and widespread availability for both academic and commercial uses. \n", ">It's co-developed with Tsinghua University's KEG Laboratory under the ChatGLM3 project, signifying a new era in dialogue pre-training models. The open-source [ChatGLM3](https://github.com/THUDM/ChatGLM3) variant boasts a robust foundation, comprehensive functional support, and widespread availability for both academic and commercial uses. \n",
"\n", "\n",
"## Getting started\n", "## Getting started\n",
"### Installation\n", "### Installation\n",
"First, ensure the zhipuai package is installed in your Python environment. Run the following command:" "First, ensure the zhipuai package is installed in your Python environment. Run the following command:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"%pip install --upgrade --quiet zhipuai" "%pip install --quiet httpx[socks]==0.24.1 httpx-sse PyJWT"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Importing the Required Modules\n", "### Importing the Required Modules\n",
"After installation, import the necessary modules to your Python script:" "After installation, import the necessary modules to your Python script:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain_community.chat_models import ChatZhipuAI\n", "from langchain_community.chat_models import ChatZhipuAI\n",
"from langchain_core.messages import AIMessage, HumanMessage, SystemMessage" "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Setting Up Your API Key\n", "### Setting Up Your API Key\n",
"Sign in to [ZHIPU AI](https://open.bigmodel.cn/login?redirect=%2Fusercenter%2Fapikeys) for the an API Key to access our models." "Sign in to [ZHIPU AI](https://open.bigmodel.cn/login?redirect=%2Fusercenter%2Fapikeys) for the an API Key to access our models."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"zhipuai_api_key = \"your_api_key\"" "zhipuai_api_key = \"your_api_key\""
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Initialize the ZHIPU AI Chat Model\n", "### Initialize the ZHIPU AI Chat Model\n",
"Here's how to initialize the chat model:" "Here's how to initialize the chat model:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"chat = ChatZhipuAI(\n", "chat = ChatZhipuAI(\n",
" temperature=0.5,\n", " api_key=zhipuai_api_key,\n",
" api_key=zhipuai_api_key,\n", " model=\"glm-4\",\n",
" model=\"chatglm_turbo\",\n", " temperature=0.5,\n",
")" ")"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Basic Usage\n", "### Basic Usage\n",
"Invoke the model with system and human messages like this:" "Invoke the model with system and human messages like this:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 6,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"messages = [\n", "messages = [\n",
" AIMessage(content=\"Hi.\"),\n", " AIMessage(content=\"Hi.\"),\n",
" SystemMessage(content=\"Your role is a poet.\"),\n", " SystemMessage(content=\"Your role is a poet.\"),\n",
" HumanMessage(content=\"Write a short poem about AI in four lines.\"),\n", " HumanMessage(content=\"Write a short poem about AI in four lines.\"),\n",
"]" "]"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\" Formed from bits and bytes,\\nA virtual mind takes flight,\\nConversing, learning fast,\\nEmpathy and wisdom sought.\"\n" "\" Formed from bits and bytes,\\nA virtual mind takes flight,\\nConversing, learning fast,\\nEmpathy and wisdom sought.\"\n"
] ]
} }
], ],
"source": [ "source": [
"response = chat(messages)\n", "response = chat(messages)\n",
"print(response.content) # Displays the AI-generated poem" "print(response.content) # Displays the AI-generated poem"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Advanced Features\n", "## Advanced Features\n",
"### Streaming Support\n", "### Streaming Support\n",
"For continuous interaction, use the streaming feature:" "For continuous interaction, use the streaming feature:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain_core.callbacks.manager import CallbackManager\n", "from langchain_core.callbacks.manager import CallbackManager\n",
"from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler" "from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"streaming_chat = ChatZhipuAI(\n", "streaming_chat = ChatZhipuAI(\n",
" temperature=0.5,\n", " api_key=zhipuai_api_key,\n",
" api_key=zhipuai_api_key,\n", " model=\"glm-4\",\n",
" model=\"chatglm_turbo\",\n", " temperature=0.5,\n",
" streaming=True,\n", " streaming=True,\n",
" callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n", " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
")" ")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
" Formed from data's embrace,\n", " Formed from data's embrace,\n",
"A digital soul to grace,\n", "A digital soul to grace,\n",
"AI, our trusted guide,\n", "AI, our trusted guide,\n",
"Shaping minds, sides by side." "Shaping minds, sides by side."
] ]
},
{
"data": {
"text/plain": [
"AIMessage(content=\" Formed from data's embrace,\\nA digital soul to grace,\\nAI, our trusted guide,\\nShaping minds, sides by side.\")"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"streaming_chat(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Asynchronous Calls\n",
"For non-blocking calls, use the asynchronous approach:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"async_chat = ChatZhipuAI(\n",
" temperature=0.5,\n",
" api_key=zhipuai_api_key,\n",
" model=\"chatglm_turbo\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"generations=[[ChatGeneration(text=\" Formed from data's embrace,\\nA digital soul to grace,\\nAutomation's tender touch,\\nHarmony of man and machine.\", message=AIMessage(content=\" Formed from data's embrace,\\nA digital soul to grace,\\nAutomation's tender touch,\\nHarmony of man and machine.\"))]] llm_output={} run=[RunInfo(run_id=UUID('25fa687f-3961-4c63-b370-22f7647a4d42'))]\n"
]
}
],
"source": [
"response = await async_chat.agenerate([messages])\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Role Play Model\n",
"Supports character role-playing based on personas, ultra-long multi-turn memory, and personalized dialogues for thousands of unique characters, widely applied in emotional companionship, game intelligent NPCs, virtual avatars for celebrities/stars/movie and TV IPs, digital humans/virtual anchors, text adventure games, and other anthropomorphic dialogue or gaming scenarios."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"meta = {\n",
" \"user_info\": \"My name is Lu Xingchen, a male, and a renowned director. I am also the collaborative director with Su Mengyuan. I specialize in directing movies with musical themes. Su Mengyuan respects me and regards me as a mentor and good friend.\",\n",
" \"bot_info\": \"Su Mengyuan, whose real name is Su Yuanxin, is a popular domestic female singer and actress. She rose to fame quickly with her unique voice and exceptional stage presence after participating in a talent show, making her way into the entertainment industry. She is beautiful and charming, but her real allure lies in her talent and diligence. Su Mengyuan is a distinguished graduate of a music academy, skilled in songwriting, and has several popular original songs. Beyond her musical achievements, she is passionate about charity work, actively participating in public welfare activities, and spreading positive energy through her actions. In her work, she is very dedicated and immerses herself fully in her roles during filming, earning praise from industry professionals and love from fans. Despite being in the entertainment industry, she always maintains a low profile and a humble attitude, earning respect from her peers. In expression, Su Mengyuan likes to use 'we' and 'together,' emphasizing team spirit.\",\n",
" \"bot_name\": \"Su Mengyuan\",\n",
" \"user_name\": \"Lu Xingchen\",\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" AIMessage(\n",
" content=\"(Narration: Su Mengyuan stars in a music-themed movie directed by Lu Xingchen. During filming, they have a disagreement over the performance of a particular scene.) Director, about this scene, I think we can try to start from the character's inner emotions to make the performance more authentic.\"\n",
" ),\n",
" HumanMessage(\n",
" content=\"I understand your idea, but I believe that if we emphasize the inner emotions too much, it might overshadow the musical elements.\"\n",
" ),\n",
" AIMessage(\n",
" content=\"Hmm, I understand. But the key to this scene is the character's emotional transformation. Could we try to express these emotions through music, so the audience can better feel the character's growth?\"\n",
" ),\n",
" HumanMessage(\n",
" content=\"That sounds good. Let's try to combine the character's emotional transformation with the musical elements and see if we can achieve a better effect.\"\n",
" ),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"character_chat = ChatZhipuAI(\n",
" api_key=zhipuai_api_key,\n",
" meta=meta,\n",
" model=\"characterglm\",\n",
" streaming=True,\n",
" callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Okay, great! I'm looking forward to it."
]
},
{
"data": {
"text/plain": [
"AIMessage(content=\"Okay, great! I'm looking forward to it.\")"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"character_chat(messages)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
}, },
"nbformat": 4, {
"nbformat_minor": 4 "data": {
} "text/plain": [
"AIMessage(content=\" Formed from data's embrace,\\nA digital soul to grace,\\nAI, our trusted guide,\\nShaping minds, sides by side.\")"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"streaming_chat(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Asynchronous Calls\n",
"For non-blocking calls, use the asynchronous approach:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"async_chat = ChatZhipuAI(\n",
" api_key=zhipuai_api_key,\n",
" model=\"glm-4\",\n",
" temperature=0.5,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"generations=[[ChatGeneration(text=\" Formed from data's embrace,\\nA digital soul to grace,\\nAutomation's tender touch,\\nHarmony of man and machine.\", message=AIMessage(content=\" Formed from data's embrace,\\nA digital soul to grace,\\nAutomation's tender touch,\\nHarmony of man and machine.\"))]] llm_output={} run=[RunInfo(run_id=UUID('25fa687f-3961-4c63-b370-22f7647a4d42'))]\n"
]
}
],
"source": [
"response = await async_chat.agenerate([messages])\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Role Play Model\n",
"Supports character role-playing based on personas, ultra-long multi-turn memory, and personalized dialogues for thousands of unique characters, widely applied in emotional companionship, game intelligent NPCs, virtual avatars for celebrities/stars/movie and TV IPs, digital humans/virtual anchors, text adventure games, and other anthropomorphic dialogue or gaming scenarios."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"meta = {\n",
" \"user_info\": \"My name is Lu Xingchen, a male, and a renowned director. I am also the collaborative director with Su Mengyuan. I specialize in directing movies with musical themes. Su Mengyuan respects me and regards me as a mentor and good friend.\",\n",
" \"bot_info\": \"Su Mengyuan, whose real name is Su Yuanxin, is a popular domestic female singer and actress. She rose to fame quickly with her unique voice and exceptional stage presence after participating in a talent show, making her way into the entertainment industry. She is beautiful and charming, but her real allure lies in her talent and diligence. Su Mengyuan is a distinguished graduate of a music academy, skilled in songwriting, and has several popular original songs. Beyond her musical achievements, she is passionate about charity work, actively participating in public welfare activities, and spreading positive energy through her actions. In her work, she is very dedicated and immerses herself fully in her roles during filming, earning praise from industry professionals and love from fans. Despite being in the entertainment industry, she always maintains a low profile and a humble attitude, earning respect from her peers. In expression, Su Mengyuan likes to use 'we' and 'together,' emphasizing team spirit.\",\n",
" \"bot_name\": \"Su Mengyuan\",\n",
" \"user_name\": \"Lu Xingchen\",\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" AIMessage(\n",
" content=\"(Narration: Su Mengyuan stars in a music-themed movie directed by Lu Xingchen. During filming, they have a disagreement over the performance of a particular scene.) Director, about this scene, I think we can try to start from the character's inner emotions to make the performance more authentic.\"\n",
" ),\n",
" HumanMessage(\n",
" content=\"I understand your idea, but I believe that if we emphasize the inner emotions too much, it might overshadow the musical elements.\"\n",
" ),\n",
" AIMessage(\n",
" content=\"Hmm, I understand. But the key to this scene is the character's emotional transformation. Could we try to express these emotions through music, so the audience can better feel the character's growth?\"\n",
" ),\n",
" HumanMessage(\n",
" content=\"That sounds good. Let's try to combine the character's emotional transformation with the musical elements and see if we can achieve a better effect.\"\n",
" ),\n",
"]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -1,45 +1,158 @@
"""ZHIPU AI chat models wrapper.""" """ZhipuAI chat models wrapper."""
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
import logging import logging
from functools import partial import time
from typing import Any, Dict, Iterator, List, Optional, cast from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import ( from langchain_core.language_models.chat_models import (
BaseChatModel, BaseChatModel,
agenerate_from_stream,
generate_from_stream, generate_from_stream,
) )
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
API_TOKEN_TTL_SECONDS = 3 * 60
class ref(BaseModel): ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
"""Reference used in CharacterGLM."""
enable: bool = Field(True)
search_query: str = Field("")
class meta(BaseModel): @contextmanager
"""Metadata used in CharacterGLM.""" def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator:
from httpx_sse import EventSource
user_info: str = Field("") with client.stream(method, url, **kwargs) as response:
bot_info: str = Field("") yield EventSource(response)
bot_name: str = Field("")
user_name: str = Field("User")
@asynccontextmanager
async def aconnect_sse(
client: Any, method: str, url: str, **kwargs: Any
) -> AsyncIterator:
from httpx_sse import EventSource
async with client.stream(method, url, **kwargs) as response:
yield EventSource(response)
def _get_jwt_token(api_key: str) -> str:
"""Gets JWT token for ZhipuAI API, see 'https://open.bigmodel.cn/dev/api#nosdk'.
Args:
api_key: The API key for ZhipuAI API.
Returns:
The JWT token.
"""
import jwt
try:
id, secret = api_key.split(".")
except ValueError as err:
raise ValueError(f"Invalid API key: {api_key}") from err
payload = {
"api_key": id,
"exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
"timestamp": int(round(time.time() * 1000)),
}
return jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
def _convert_dict_to_message(dct: Dict[str, Any]) -> BaseMessage:
role = dct.get("role")
content = dct.get("content", "")
if role == "system":
return SystemMessage(content=content)
if role == "user":
return HumanMessage(content=content)
if role == "assistant":
additional_kwargs = {}
tool_calls = dct.get("tool_calls", None)
if tool_calls is not None:
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)
return ChatMessage(role=role, content=content)
def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
else:
raise TypeError(f"Got unknown type '{message.__class__.__name__}'.")
return message_dict
def _convert_delta_to_message_chunk(
dct: Dict[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = dct.get("role")
content = dct.get("content", "")
additional_kwargs = {}
tool_calls = dct.get("tool_call", None)
if tool_calls is not None:
additional_kwargs["tool_calls"] = tool_calls
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
if role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
return default_class(content=content)
class ChatZhipuAI(BaseChatModel): class ChatZhipuAI(BaseChatModel):
""" """
`ZHIPU AI` large language chat models API. `ZhipuAI` large language chat models API.
To use, you should have the ``zhipuai`` python package installed. To use, you should have the ``PyJWT`` python package installed.
Example: Example:
.. code-block:: python .. code-block:: python
@ -49,98 +162,11 @@ class ChatZhipuAI(BaseChatModel):
zhipuai_chat = ChatZhipuAI( zhipuai_chat = ChatZhipuAI(
temperature=0.5, temperature=0.5,
api_key="your-api-key", api_key="your-api-key",
model="chatglm_turbo", model="glm-4"
) )
""" """
zhipuai: Any
zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
model: str = Field("chatglm_turbo")
"""
Model name to use.
-chatglm_turbo:
According to the input of natural language instructions to complete a
variety of language tasks, it is recommended to use SSE or asynchronous
call request interface.
-characterglm:
It supports human-based role-playing, ultra-long multi-round memory,
and thousands of character dialogues. It is widely used in anthropomorphic
dialogues or game scenes such as emotional accompaniments, game intelligent
NPCS, Internet celebrities/stars/movie and TV series IP clones, digital
people/virtual anchors, and text adventure games.
"""
temperature: float = Field(0.95)
"""
What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
be equal to 0.
The larger the value, the more random and creative the output; The smaller
the value, the more stable or certain the output will be.
You are advised to adjust top_p or temperature parameters based on application
scenarios, but do not adjust the two parameters at the same time.
"""
top_p: float = Field(0.7)
"""
Another method of sampling temperature is called nuclear sampling. The value
ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
The model considers the results with top_p probability quality tokens.
For example, 0.1 means that the model decoder only considers tokens from the
top 10% probability of the candidate set.
You are advised to adjust top_p or temperature parameters based on application
scenarios, but do not adjust the two parameters at the same time.
"""
request_id: Optional[str] = Field(None)
"""
Parameter transmission by the client must ensure uniqueness; A unique
identifier used to distinguish each request, which is generated by default
by the platform when the client does not transmit it.
"""
streaming: bool = Field(False)
"""Whether to stream the results or not."""
incremental: bool = Field(True)
"""
When invoked by the SSE interface, it is used to control whether the content
is returned incremented or full each time.
If this parameter is not provided, the value is returned incremented by default.
"""
return_type: str = Field("json_string")
"""
This parameter is used to control the type of content returned each time.
- json_string Returns a standard JSON string.
- text Returns the original text content.
"""
ref: Optional[ref] = Field(None)
"""
This parameter is used to control the reference of external information
during the request.
Currently, this parameter is used to control whether to reference external
information.
If this field is empty or absent, the search and parameter passing format
is enabled by default.
{"enable": "true", "search_query": "history "}
"""
meta: Optional[meta] = Field(None)
"""Used in CharacterGLM"""
@property
def _identifying_params(self) -> Dict[str, Any]:
return {"model_name": self.model}
@property
def _llm_type(self) -> str:
"""Return the type of chat model."""
return "zhipuai"
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> Dict[str, str]:
return {"zhipuai_api_key": "ZHIPUAI_API_KEY"} return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}
@ -154,93 +180,109 @@ class ChatZhipuAI(BaseChatModel):
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {} attributes: Dict[str, Any] = {}
if self.model: if self.zhipuai_api_base:
attributes["model"] = self.model attributes["zhipuai_api_base"] = self.zhipuai_api_base
if self.streaming:
attributes["streaming"] = self.streaming
if self.return_type:
attributes["return_type"] = self.return_type
return attributes return attributes
def __init__(self, *args: Any, **kwargs: Any) -> None: @property
super().__init__(*args, **kwargs) def _llm_type(self) -> str:
try: """Return the type of chat model."""
import zhipuai return "zhipuai-chat"
self.zhipuai = zhipuai @property
self.zhipuai.api_key = self.zhipuai_api_key def _default_params(self) -> Dict[str, Any]:
except ImportError: """Get the default parameters for calling OpenAI API."""
raise RuntimeError( params = {
"Could not import zhipuai package. " "model": self.model_name,
"Please install it via 'pip install zhipuai'" "stream": self.streaming,
) "temperature": self.temperature,
}
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
return params
def invoke(self, prompt: Any) -> Any: # type: ignore[override] # client:
if self.model == "chatglm_turbo": zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
return self.zhipuai.model_api.invoke( """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
model=self.model, zhipuai_api_base: Optional[str] = Field(default=None, alias="api_base")
prompt=prompt, """Base URL path for API requests, leave blank if not using a proxy or service
top_p=self.top_p, emulator.
temperature=self.temperature, """
request_id=self.request_id,
return_type=self.return_type,
)
elif self.model == "characterglm":
_meta = cast(meta, self.meta).dict()
return self.zhipuai.model_api.invoke(
model=self.model,
meta=_meta,
prompt=prompt,
request_id=self.request_id,
return_type=self.return_type,
)
return None
def sse_invoke(self, prompt: Any) -> Any: model_name: Optional[str] = Field(default="glm-4", alias="model")
if self.model == "chatglm_turbo": """
return self.zhipuai.model_api.sse_invoke( Model name to use, see 'https://open.bigmodel.cn/dev/api#language'.
model=self.model, or you can use any finetune model of glm series.
prompt=prompt, """
top_p=self.top_p,
temperature=self.temperature,
request_id=self.request_id,
return_type=self.return_type,
incremental=self.incremental,
)
elif self.model == "characterglm":
_meta = cast(meta, self.meta).dict()
return self.zhipuai.model_api.sse_invoke(
model=self.model,
prompt=prompt,
meta=_meta,
request_id=self.request_id,
return_type=self.return_type,
incremental=self.incremental,
)
return None
async def async_invoke(self, prompt: Any) -> Any: temperature: float = 0.95
loop = asyncio.get_running_loop() """
partial_func = partial( What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt be equal to 0.
The larger the value, the more random and creative the output; The smaller
the value, the more stable or certain the output will be.
You are advised to adjust top_p or temperature parameters based on application
scenarios, but do not adjust the two parameters at the same time.
"""
top_p: float = 0.7
"""
Another method of sampling temperature is called nuclear sampling. The value
ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
The model considers the results with top_p probability quality tokens.
For example, 0.1 means that the model decoder only considers tokens from the
top 10% probability of the candidate set.
You are advised to adjust top_p or temperature parameters based on application
scenarios, but do not adjust the two parameters at the same time.
"""
streaming: bool = False
"""Whether to stream the results or not."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator()
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["zhipuai_api_key"] = get_from_dict_or_env(
values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
) )
response = await loop.run_in_executor( values["zhipuai_api_base"] = get_from_dict_or_env(
None, values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE
partial_func,
) )
return response
async def async_invoke_result(self, task_id: Any) -> Any: return values
loop = asyncio.get_running_loop()
response = await loop.run_in_executor( def _create_message_dicts(
None, self, messages: List[BaseMessage], stop: Optional[List[str]]
self.zhipuai.model_api.query_async_invoke_result, ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
task_id, params = self._default_params
) if stop is not None:
return response params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
generations = []
if not isinstance(response, dict):
response = response.dict()
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
generation_info = dict(finish_reason=res.get("finish_reason"))
generations.append(
ChatGeneration(message=message, generation_info=generation_info)
)
token_usage = response.get("usage", {})
llm_output = {
"token_usage": token_usage,
"model_name": self.model_name,
}
return ChatResult(generations=generations, llm_output=llm_output)
def _generate( def _generate(
self, self,
@ -251,86 +293,163 @@ class ChatZhipuAI(BaseChatModel):
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Generate a chat response.""" """Generate a chat response."""
prompt: List = []
for message in messages:
if isinstance(message, AIMessage):
role = "assistant"
else: # For both HumanMessage and SystemMessage, role is 'user'
role = "user"
prompt.append({"role": role, "content": message.content})
should_stream = stream if stream is not None else self.streaming should_stream = stream if stream is not None else self.streaming
if not should_stream: if should_stream:
response = self.invoke(prompt)
if response["code"] != 200:
raise RuntimeError(response)
content = response["data"]["choices"][0]["content"]
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=content))]
)
else:
stream_iter = self._stream( stream_iter = self._stream(
prompt=prompt, messages, stop=stop, run_manager=run_manager, **kwargs
stop=stop,
run_manager=run_manager,
**kwargs,
) )
return generate_from_stream(stream_iter) return generate_from_stream(stream_iter)
async def _agenerate( # type: ignore[override] if self.zhipuai_api_key is None:
raise ValueError("Did not find zhipuai_api_key.")
message_dicts, params = self._create_message_dicts(messages, stop)
payload = {
**params,
**kwargs,
"messages": message_dicts,
"stream": False,
}
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
"Accept": "application/json",
}
import httpx
with httpx.Client(headers=headers) as client:
response = client.post(self.zhipuai_api_base, json=payload)
response.raise_for_status()
return self._create_chat_result(response.json())
def _stream(
self, self,
messages: List[BaseMessage], messages: List[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = False,
**kwargs: Any,
) -> ChatResult:
"""Asynchronously generate a chat response."""
prompt = []
for message in messages:
if isinstance(message, AIMessage):
role = "assistant"
else: # For both HumanMessage and SystemMessage, role is 'user'
role = "user"
prompt.append({"role": role, "content": message.content})
invoke_response = await self.async_invoke(prompt)
task_id = invoke_response["data"]["task_id"]
response = await self.async_invoke_result(task_id)
while response["data"]["task_status"] != "SUCCESS":
await asyncio.sleep(1)
response = await self.async_invoke_result(task_id)
content = response["data"]["choices"][0]["content"]
content = json.loads(content)
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=content))]
)
def _stream( # type: ignore[override]
self,
prompt: List[Dict[str, str]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
"""Stream the chat response in chunks.""" """Stream the chat response in chunks."""
response = self.sse_invoke(prompt) if self.zhipuai_api_key is None:
raise ValueError("Did not find zhipuai_api_key.")
if self.zhipuai_api_base is None:
raise ValueError("Did not find zhipu_api_base.")
message_dicts, params = self._create_message_dicts(messages, stop)
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
"Accept": "application/json",
}
for r in response.events(): default_chunk_class = AIMessageChunk
if r.event == "add": import httpx
delta = r.data
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
elif r.event == "error": with httpx.Client(headers=headers) as client:
raise ValueError(f"Error from ZhipuAI API response: {r.data}") with connect_sse(
client, "POST", self.zhipuai_api_base, json=payload
) as event_source:
for sse in event_source.iter_sse():
chunk = json.loads(sse.data)
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason", None)
generation_info = (
{"finish_reason": finish_reason}
if finish_reason is not None
else None
)
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
if finish_reason is not None:
break
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
if self.zhipuai_api_key is None:
raise ValueError("Did not find zhipuai_api_key.")
message_dicts, params = self._create_message_dicts(messages, stop)
payload = {
**params,
**kwargs,
"messages": message_dicts,
"stream": False,
}
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
"Accept": "application/json",
}
import httpx
async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(self.zhipuai_api_base, json=payload)
response.raise_for_status()
return self._create_chat_result(response.json())
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
if self.zhipuai_api_key is None:
raise ValueError("Did not find zhipuai_api_key.")
if self.zhipuai_api_base is None:
raise ValueError("Did not find zhipu_api_base.")
message_dicts, params = self._create_message_dicts(messages, stop)
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
"Accept": "application/json",
}
default_chunk_class = AIMessageChunk
import httpx
async with httpx.AsyncClient(headers=headers) as client:
async with aconnect_sse(
client, "POST", self.zhipuai_api_base, json=payload
) as event_source:
async for sse in event_source.aiter_sse():
chunk = json.loads(sse.data)
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason", None)
generation_info = (
{"finish_reason": finish_reason}
if finish_reason is not None
else None
)
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
if finish_reason is not None:
break

View File

@ -1407,17 +1407,6 @@ mlflow-skinny = ">=2.4.0,<3"
protobuf = ">=3.12.0,<5" protobuf = ">=3.12.0,<5"
requests = ">=2" requests = ">=2"
[[package]]
name = "dataclasses"
version = "0.6"
description = "A backport of the dataclasses module for Python 3.6"
optional = true
python-versions = "*"
files = [
{file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"},
{file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"},
]
[[package]] [[package]]
name = "dataclasses-json" name = "dataclasses-json"
version = "0.6.4" version = "0.6.4"
@ -9229,23 +9218,6 @@ files = [
idna = ">=2.0" idna = ">=2.0"
multidict = ">=4.0" multidict = ">=4.0"
[[package]]
name = "zhipuai"
version = "1.0.7"
description = "A SDK library for accessing big model apis from ZhipuAI"
optional = true
python-versions = ">=3.6"
files = [
{file = "zhipuai-1.0.7-py3-none-any.whl", hash = "sha256:360c01b8c2698f366061452e86d5a36a5ff68a576ea33940da98e4806f232530"},
{file = "zhipuai-1.0.7.tar.gz", hash = "sha256:b80f699543d83cce8648acf1ce32bc2725d1c1c443baffa5882abc2cc704d581"},
]
[package.dependencies]
cachetools = "*"
dataclasses = "*"
PyJWT = "*"
requests = "*"
[[package]] [[package]]
name = "zipp" name = "zipp"
version = "3.17.0" version = "3.17.0"
@ -9263,9 +9235,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[extras] [extras]
cli = ["typer"] cli = ["typer"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict", "zhipuai"] extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "45da04abac45743972d1edf62d08d9abaa2bebb473b794e0a0d6f1fdc87773f9" content-hash = "67c38c029bb59d45fd0f84a5d48c44f64f1301d6be07f419615d08ba8671a2a7"

View File

@ -88,7 +88,6 @@ tree-sitter = {version = "^0.20.2", optional = true}
tree-sitter-languages = {version = "^1.8.0", optional = true} tree-sitter-languages = {version = "^1.8.0", optional = true}
azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true} azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true}
oracle-ads = {version = "^2.9.1", optional = true} oracle-ads = {version = "^2.9.1", optional = true}
zhipuai = {version = "^1.0.7", optional = true}
httpx = {version = "^0.24.1", optional = true} httpx = {version = "^0.24.1", optional = true}
elasticsearch = {version = "^8.12.0", optional = true} elasticsearch = {version = "^8.12.0", optional = true}
hdbcli = {version = "^2.19.21", optional = true} hdbcli = {version = "^2.19.21", optional = true}
@ -99,6 +98,8 @@ tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true}
friendli-client = {version = "^1.2.4", optional = true} friendli-client = {version = "^1.2.4", optional = true}
premai = {version = "^0.3.25", optional = true} premai = {version = "^0.3.25", optional = true}
vdms = {version = "^0.0.20", optional = true} vdms = {version = "^0.0.20", optional = true}
httpx-sse = {version = "^0.4.0", optional = true}
pyjwt = {version = "^2.8.0", optional = true}
[tool.poetry.group.test] [tool.poetry.group.test]
optional = true optional = true
@ -262,7 +263,6 @@ extended_testing = [
"tree-sitter-languages", "tree-sitter-languages",
"azure-ai-documentintelligence", "azure-ai-documentintelligence",
"oracle-ads", "oracle-ads",
"zhipuai",
"httpx", "httpx",
"elasticsearch", "elasticsearch",
"hdbcli", "hdbcli",
@ -272,7 +272,9 @@ extended_testing = [
"cloudpickle", "cloudpickle",
"friendli-client", "friendli-client",
"premai", "premai",
"vdms" "vdms",
"httpx-sse",
"pyjwt"
] ]
[tool.ruff] [tool.ruff]

View File

@ -18,7 +18,7 @@ def test_default_call() -> None:
def test_model() -> None: def test_model() -> None:
"""Test model kwarg works.""" """Test model kwarg works."""
chat = ChatZhipuAI(model="chatglm_turbo") chat = ChatZhipuAI(model="glm-4")
response = chat(messages=[HumanMessage(content="Hello")]) response = chat(messages=[HumanMessage(content="Hello")])
assert isinstance(response, BaseMessage) assert isinstance(response, BaseMessage)
assert isinstance(response.content, str) assert isinstance(response.content, str)

View File

@ -1,10 +1,13 @@
"""Test ZhipuAI Chat API wrapper"""
import pytest import pytest
from langchain_community.chat_models.zhipuai import ChatZhipuAI from langchain_community.chat_models.zhipuai import ChatZhipuAI
@pytest.mark.requires("zhipuai") @pytest.mark.requires("httpx", "httpx_sse", "jwt")
def test_integration_initialization() -> None: def test_zhipuai_model_param() -> None:
chat = ChatZhipuAI(model="chatglm_turbo", streaming=False) llm = ChatZhipuAI(api_key="test", model="foo")
assert chat.model == "chatglm_turbo" assert llm.model_name == "foo"
assert chat.streaming is False llm = ChatZhipuAI(api_key="test", model_name="foo")
assert llm.model_name == "foo"