mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 02:13:23 +00:00
Implement Alibaba Tongyi chat model apis. (#10922)
Hi there This PR is aim to implement chat model for Alibaba Tongyi LLM model. It contains work below: 1.Implement ChatTongyi chat model in langchain.chat_models.tongyi. Note this is different with tongyi llm model to another PR https://github.com/langchain-ai/langchain/pull/10878. For detail it implements _generate() and _stream() function in ChatTongyi. 2. Add some examples in chat/tongyi.ipynb. 3. Add integration test in chat_models/test_tongyi.py Note async completion for the Text API is not yet supported. Dependencies: dashscope. It will be installed manually cause it is not need by everyone.
This commit is contained in:
parent
008348ce71
commit
11cdfe44af
163
docs/extras/integrations/chat/tongyi.ipynb
Normal file
163
docs/extras/integrations/chat/tongyi.ipynb
Normal file
@ -0,0 +1,163 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# Tongyi Qwen\n",
|
||||
"Tongyi Qwen is a large language model developed by Alibaba's Damo Academy. It is capable of understanding user intent through natural language understanding and semantic analysis, based on user input in natural language. It provides services and assistance to users in different domains and tasks. By providing clear and detailed instructions, you can obtain results that better align with your expectations.\n",
|
||||
"In this notebook, we will introduce how to use langchain with [Tongyi](https://www.aliyun.com/product/dashscope) mainly in `Chat` corresponding\n",
|
||||
" to the package `langchain/chat_models` in langchain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install the package\n",
|
||||
"!pip install dashscope"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdin",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" ········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Get a new token: https://help.aliyun.com/document_detail/611472.html?spm=a2c4g.2399481.0.0\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"DASHSCOPE_API_KEY = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"DASHSCOPE_API_KEY\"] = DASHSCOPE_API_KEY"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"chat resp: content='Hello! How' additional_kwargs={} example=False\n",
|
||||
"chat resp: content=' can I assist you today?' additional_kwargs={} example=False\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chat_models.tongyi import ChatTongyi\n",
|
||||
"from langchain.schema import HumanMessage\n",
|
||||
"\n",
|
||||
"chatLLM = ChatTongyi(\n",
|
||||
" streaming=True,\n",
|
||||
")\n",
|
||||
"res = chatLLM.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
|
||||
"for r in res:\n",
|
||||
" print(\"chat resp:\", r)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessageChunk(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.schema import AIMessage, HumanMessage, SystemMessage\n",
|
||||
"messages = [\n",
|
||||
" SystemMessage(\n",
|
||||
" content=\"You are a helpful assistant that translates English to French.\"\n",
|
||||
" ),\n",
|
||||
" HumanMessage(\n",
|
||||
" content=\"Translate this sentence from English to French. I love programming.\"\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"chatLLM(messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
391
libs/langchain/langchain/chat_models/tongyi.py
Normal file
391
libs/langchain/langchain/chat_models/tongyi.py
Normal file
@ -0,0 +1,391 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
from tenacity import (
|
||||
RetryCallState,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import (
|
||||
BaseChatModel,
|
||||
_generate_from_stream,
|
||||
)
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
from langchain.schema import ChatGeneration, ChatResult
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
content = _dict.get("content", "") or ""
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
elif role == "function":
|
||||
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
# If function call only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
def _stream_response_to_generation_chunk(
|
||||
stream_response: Dict[str, Any],
|
||||
length: int,
|
||||
) -> GenerationChunk:
|
||||
"""Convert a stream response to a generation chunk.
|
||||
|
||||
As the low level API implement is different from openai and other llm.
|
||||
Stream response of Tongyi is not split into chunks, but all data generated before.
|
||||
For example, the answer 'Hi Pickle Rick! How can I assist you today?'
|
||||
Other llm will stream answer:
|
||||
'Hi Pickle',
|
||||
' Rick!',
|
||||
' How can I assist you today?'.
|
||||
|
||||
Tongyi answer:
|
||||
'Hi Pickle',
|
||||
'Hi Pickle Rick!',
|
||||
'Hi Pickle Rick! How can I assist you today?'.
|
||||
|
||||
As the GenerationChunk is implemented with chunks. Only return full_text[length:]
|
||||
for new chunk.
|
||||
"""
|
||||
full_text = stream_response["output"]["text"]
|
||||
text = full_text[length:]
|
||||
finish_reason = stream_response["output"].get("finish_reason", None)
|
||||
|
||||
return GenerationChunk(
|
||||
text=text,
|
||||
generation_info=dict(
|
||||
finish_reason=finish_reason,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: ChatTongyi,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
def _before_sleep(retry_state: RetryCallState) -> None:
|
||||
if run_manager:
|
||||
run_manager.on_retry(retry_state)
|
||||
return None
|
||||
|
||||
min_seconds = 1
|
||||
max_seconds = 4
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(llm.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
retry=(retry_if_exception_type(HTTPError)),
|
||||
before_sleep=_before_sleep,
|
||||
)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any],
|
||||
default_class: type[BaseMessageChunk],
|
||||
length: int,
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
full_content = _dict.get("content") or ""
|
||||
content = full_content[length:]
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content)
|
||||
|
||||
|
||||
class ChatTongyi(BaseChatModel):
|
||||
"""Alibaba Tongyi Qwen chat models API.
|
||||
|
||||
To use, you should have the ``dashscope`` python package installed,
|
||||
and set env ``DASHSCOPE_API_KEY`` with your API key, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import Tongyi
|
||||
Tongyi_chat = ChatTongyi()
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"dashscope_api_key": "DASHSCOPE_API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = Field(default="qwen-turbo", alias="model")
|
||||
|
||||
"""Model name to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
top_p: float = 0.8
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
|
||||
dashscope_api_key: Optional[str] = None
|
||||
"""Dashscope api key provide by alicloud."""
|
||||
|
||||
n: int = 1
|
||||
"""How many completions to generate for each prompt."""
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
|
||||
max_retries: int = 10
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
prefix_messages: List = Field(default_factory=list)
|
||||
"""Series of messages for Chat input."""
|
||||
|
||||
result_format: str = Field(default="message")
|
||||
"""Return result format"""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "tongyi"
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
get_from_dict_or_env(values, "dashscope_api_key", "DASHSCOPE_API_KEY")
|
||||
try:
|
||||
import dashscope
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import dashscope python package. "
|
||||
"Please install it with `pip install dashscope --upgrade`."
|
||||
)
|
||||
try:
|
||||
values["client"] = dashscope.Generation
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`dashscope` has no `Generation` attribute, this is likely "
|
||||
"due to an old version of the dashscope package. Try upgrading it "
|
||||
"with `pip install --upgrade dashscope`."
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"top_p": self.top_p,
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
"result_format": self.result_format,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
def completion_with_retry(
|
||||
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(**_kwargs: Any) -> Any:
|
||||
resp = self.client.call(**_kwargs)
|
||||
if resp.status_code == 200:
|
||||
return resp
|
||||
elif resp.status_code in [400, 401]:
|
||||
raise ValueError(
|
||||
f"status_code: {resp.status_code} \n "
|
||||
f"code: {resp.code} \n message: {resp.message}"
|
||||
)
|
||||
else:
|
||||
raise HTTPError(
|
||||
f"HTTP error occurred: status_code: {resp.status_code} \n "
|
||||
f"code: {resp.code} \n message: {resp.message}"
|
||||
)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
def stream_completion_with_retry(
|
||||
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
def _stream_completion_with_retry(**_kwargs: Any) -> Any:
|
||||
return self.client.call(**_kwargs)
|
||||
|
||||
return _stream_completion_with_retry(**kwargs)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return _generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
# Mark current chunk total length
|
||||
length = 0
|
||||
default_chunk_class = AIMessageChunk
|
||||
for chunk in self.stream_completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk["output"]["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["output"]["choices"][0]
|
||||
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["message"], default_chunk_class, length
|
||||
)
|
||||
finish_reason = choice.get("finish_reason")
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||
)
|
||||
default_chunk_class = chunk.__class__
|
||||
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.content, chunk=chunk)
|
||||
length = len(choice["message"]["content"])
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = self._client_params()
|
||||
|
||||
# Ensure `stop` is a list of strings
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
|
||||
message_dicts = [convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used for the openai client."""
|
||||
creds: Dict[str, Any] = {
|
||||
"dashscope_api_key": self.dashscope_api_key,
|
||||
}
|
||||
return {**self._default_params, **creds}
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for res in response["output"]["choices"]:
|
||||
message = convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(finish_reason=res.get("finish_reason")),
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = response.get("usage", {})
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
@ -0,0 +1,77 @@
|
||||
"""Test Alibaba Tongyi Chat Model."""
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chat_models.tongyi import ChatTongyi
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
HumanMessage,
|
||||
LLMResult,
|
||||
)
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
def test_default_call() -> None:
|
||||
"""Test default model call."""
|
||||
chat = ChatTongyi()
|
||||
response = chat(messages=[HumanMessage(content="Hello")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_model() -> None:
|
||||
"""Test model kwarg works."""
|
||||
chat = ChatTongyi(model="qwen-plus")
|
||||
response = chat(messages=[HumanMessage(content="Hello")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_history() -> None:
|
||||
"""Tests multiple history works."""
|
||||
chat = ChatTongyi()
|
||||
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(content="Hello."),
|
||||
AIMessage(content="Hello!"),
|
||||
HumanMessage(content="How are you doing?"),
|
||||
]
|
||||
)
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test that stream works."""
|
||||
chat = ChatTongyi(streaming=True)
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(content="Hello."),
|
||||
AIMessage(content="Hello!"),
|
||||
HumanMessage(content="Who are you?"),
|
||||
],
|
||||
stream=True,
|
||||
callbacks=callback_manager,
|
||||
)
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_messages() -> None:
|
||||
"""Tests multiple messages works."""
|
||||
chat = ChatTongyi()
|
||||
message = HumanMessage(content="Hi, how are you.")
|
||||
response = chat.generate([[message], [message]])
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
Loading…
Reference in New Issue
Block a user