mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 08:32:32 +00:00
community: Implement bind_tools
for ChatTongyi (#20725)
## Description Implement `bind_tools` in ChatTongyi. Usage example: ```py from langchain_core.tools import tool from langchain_community.chat_models.tongyi import ChatTongyi @tool def multiply(first_int: int, second_int: int) -> int: """Multiply two integers together.""" return first_int * second_int llm = ChatTongyi(model="qwen-turbo") llm_with_tools = llm.bind_tools([multiply]) msg = llm_with_tools.invoke("What's 5 times forty two") print(msg) ``` Streaming is also supported. ## Dependencies No Dependency is required for this change. --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
b216a1dddb
commit
0ead09f84d
@ -26,14 +26,22 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Install the package\n",
|
||||
"%pip install --upgrade --quiet dashscope"
|
||||
@ -48,15 +56,7 @@
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" ········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Get a new token: https://help.aliyun.com/document_detail/611472.html?spm=a2c4g.2399481.0.0\n",
|
||||
"from getpass import getpass\n",
|
||||
@ -94,8 +94,12 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"chat resp: content='Hello! How' additional_kwargs={} example=False\n",
|
||||
"chat resp: content=' can I assist you today?' additional_kwargs={} example=False\n"
|
||||
"chat resp: content='Hello' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
|
||||
"chat resp: content='!' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
|
||||
"chat resp: content=' How' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
|
||||
"chat resp: content=' can I assist you today' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
|
||||
"chat resp: content='?' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
|
||||
"chat resp: content='' response_metadata={'finish_reason': 'stop', 'request_id': '921db2c5-4d53-9a89-8e87-e4ad6a671237', 'token_usage': {'input_tokens': 20, 'output_tokens': 9, 'total_tokens': 29}} id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -116,10 +120,18 @@
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/cheese/PARA/Projects/langchain-contribution/langchain/libs/core/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The method `BaseChatModel.__call__` was deprecated in langchain-core 0.1.7 and will be removed in 0.2.0. Use invoke instead.\n",
|
||||
" warn_deprecated(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessageChunk(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
|
||||
"AIMessage(content=\"J'adore programmer.\", response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'stop', 'request_id': 'ae725086-0ffa-9728-8c72-b204c7bc7eeb', 'token_usage': {'input_tokens': 36, 'output_tokens': 6, 'total_tokens': 42}}, id='run-060cc103-ef5f-4c8a-af40-792ac7f40c26-0')"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
@ -149,18 +161,65 @@
|
||||
"ChatTongyi supports tool calling API that lets you describe tools and their arguments, and have the model return a JSON object with a tool to invoke and the inputs to that tool."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Use with `bind_tools`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"content='' additional_kwargs={'tool_calls': [{'function': {'name': 'multiply', 'arguments': '{\"first_int\": 5, \"second_int\": 42}'}, 'id': '', 'type': 'function'}]} response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': '4acf0e36-44af-987a-a0c0-8b5c5eaa1a8b', 'token_usage': {'input_tokens': 200, 'output_tokens': 25, 'total_tokens': 225}} id='run-0ecd0f09-1d20-4e55-a4f3-f14d1f710ae7-0' tool_calls=[{'name': 'multiply', 'args': {'first_int': 5, 'second_int': 42}, 'id': ''}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_community.chat_models.tongyi import ChatTongyi\n",
|
||||
"from langchain_core.tools import tool\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def multiply(first_int: int, second_int: int) -> int:\n",
|
||||
" \"\"\"Multiply two integers together.\"\"\"\n",
|
||||
" return first_int * second_int\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"llm = ChatTongyi(model=\"qwen-turbo\")\n",
|
||||
"\n",
|
||||
"llm_with_tools = llm.bind_tools([multiply])\n",
|
||||
"\n",
|
||||
"msg = llm_with_tools.invoke(\"What's 5 times forty two\")\n",
|
||||
"\n",
|
||||
"print(msg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Construct args manually"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'name': 'get_current_weather', 'arguments': '{\"location\": \"San Francisco\"}'}, 'id': '', 'type': 'function'}]}, response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': 'dae79197-8780-9b7e-8c15-6a83e2a53534', 'token_usage': {'input_tokens': 229, 'output_tokens': 19, 'total_tokens': 248}}, id='run-9e06f837-582b-473b-bb1f-5e99a68ecc10-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco'}, 'id': ''}])"
|
||||
"AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'name': 'get_current_weather', 'arguments': '{\"location\": \"San Francisco\"}'}, 'id': '', 'type': 'function'}]}, response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': '87ef33d2-5c6b-9457-91e2-39faad7120eb', 'token_usage': {'input_tokens': 229, 'output_tokens': 19, 'total_tokens': 248}}, id='run-7939ba7f-e3f7-46f8-980b-30499b52723c-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco'}, 'id': ''}])"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -224,7 +283,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
@ -12,6 +13,8 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -20,6 +23,7 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@ -32,6 +36,8 @@ from langchain_core.messages import (
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
make_invalid_tool_call,
|
||||
@ -42,8 +48,11 @@ from langchain_core.outputs import (
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from requests.exceptions import HTTPError
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
@ -68,6 +77,7 @@ def convert_dict_to_message(
|
||||
"""Convert a dict to a message."""
|
||||
role = _dict["role"]
|
||||
content = _dict["content"]
|
||||
|
||||
if role == "user":
|
||||
return (
|
||||
HumanMessageChunk(content=content)
|
||||
@ -79,17 +89,39 @@ def convert_dict_to_message(
|
||||
invalid_tool_calls = []
|
||||
if "tool_calls" in _dict:
|
||||
additional_kwargs = {"tool_calls": _dict["tool_calls"]}
|
||||
for raw_tool_call in _dict["tool_calls"]:
|
||||
try:
|
||||
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
|
||||
except Exception as e:
|
||||
invalid_tool_calls.append(
|
||||
make_invalid_tool_call(raw_tool_call, str(e))
|
||||
)
|
||||
|
||||
for index, value in enumerate(_dict["tool_calls"]):
|
||||
if is_chunk:
|
||||
try:
|
||||
tool_calls.append(
|
||||
{
|
||||
"name": value["function"].get("name"),
|
||||
"args": value["function"].get("arguments"),
|
||||
"id": value.get("id"),
|
||||
# Tongyi does not respond with index,
|
||||
# use index in the list instead
|
||||
"index": index,
|
||||
}
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
parsed_tool = parse_tool_call(value, return_id=True)
|
||||
if parsed_tool:
|
||||
tool_calls.append(parsed_tool)
|
||||
except Exception as e:
|
||||
invalid_tool_calls.append(make_invalid_tool_call(value, str(e)))
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
|
||||
return (
|
||||
AIMessageChunk(content=content)
|
||||
AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_calls,
|
||||
id=_dict.get("id"),
|
||||
)
|
||||
if is_chunk
|
||||
else AIMessage(
|
||||
content=content,
|
||||
@ -104,6 +136,23 @@ def convert_dict_to_message(
|
||||
if is_chunk
|
||||
else SystemMessage(content=content)
|
||||
)
|
||||
elif role == "tool":
|
||||
additional_kwargs = {}
|
||||
if "name" in _dict:
|
||||
additional_kwargs["name"] = _dict["name"]
|
||||
return (
|
||||
ToolMessageChunk(
|
||||
content=_dict.get("content", ""),
|
||||
tool_call_id=_dict.get("tool_call_id"),
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
if is_chunk
|
||||
else ToolMessage(
|
||||
content=_dict.get("content", ""),
|
||||
tool_call_id=_dict.get("tool_call_id"),
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return (
|
||||
ChatMessageChunk(role=role, content=content)
|
||||
@ -113,17 +162,23 @@ def convert_dict_to_message(
|
||||
|
||||
|
||||
def convert_message_chunk_to_message(message_chunk: BaseMessageChunk) -> BaseMessage:
|
||||
"""Convert a message chunk to a message."""
|
||||
if isinstance(message_chunk, HumanMessageChunk):
|
||||
return HumanMessage(content=message_chunk.content)
|
||||
elif isinstance(message_chunk, AIMessageChunk):
|
||||
return AIMessage(content=message_chunk.content)
|
||||
elif isinstance(message_chunk, SystemMessageChunk):
|
||||
return SystemMessage(content=message_chunk.content)
|
||||
elif isinstance(message_chunk, ChatMessageChunk):
|
||||
return ChatMessage(role=message_chunk.role, content=message_chunk.content)
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message_chunk}")
|
||||
"""Convert a message chunk to a message.
|
||||
|
||||
Args:
|
||||
chunk: Message chunk to convert.
|
||||
|
||||
Returns:
|
||||
Message.
|
||||
"""
|
||||
if not isinstance(message_chunk, BaseMessageChunk):
|
||||
return message_chunk
|
||||
# chunk classes always have the equivalent non-chunk class as their first parent
|
||||
ignore_keys = ["type"]
|
||||
if isinstance(message_chunk, AIMessageChunk):
|
||||
ignore_keys.append("tool_call_chunks")
|
||||
return message_chunk.__class__.__mro__[1](
|
||||
**{k: v for k, v in message_chunk.__dict__.items() if k not in ignore_keys}
|
||||
)
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
@ -136,8 +191,17 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
@ -256,11 +320,57 @@ class ChatTongyi(BaseChatModel):
|
||||
@retry_decorator
|
||||
def _stream_completion_with_retry(**_kwargs: Any) -> Any:
|
||||
responses = self.client.call(**_kwargs)
|
||||
prev_resp = None
|
||||
|
||||
for resp in responses:
|
||||
yield check_response(resp)
|
||||
# If we are streaming without `incremental_output = True`,
|
||||
# we need to calculate the delta response manually
|
||||
if _kwargs.get("stream") and not _kwargs.get(
|
||||
"incremental_output", False
|
||||
):
|
||||
if prev_resp is None:
|
||||
delta_resp = resp
|
||||
else:
|
||||
delta_resp = self.subtract_client_response(resp, prev_resp)
|
||||
prev_resp = resp
|
||||
yield check_response(delta_resp)
|
||||
else:
|
||||
yield check_response(resp)
|
||||
|
||||
return _stream_completion_with_retry(**kwargs)
|
||||
|
||||
def subtract_client_response(self, resp: Any, prev_resp: Any) -> Any:
|
||||
"""Subtract prev response from curr response.
|
||||
|
||||
Useful when streaming without `incremental_output = True`
|
||||
"""
|
||||
|
||||
resp_copy = json.loads(json.dumps(resp))
|
||||
choice = resp_copy["output"]["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
prev_resp_copy = json.loads(json.dumps(prev_resp))
|
||||
prev_choice = prev_resp_copy["output"]["choices"][0]
|
||||
prev_message = prev_choice["message"]
|
||||
|
||||
message["content"] = message["content"].replace(prev_message["content"], "")
|
||||
|
||||
if message.get("tool_calls"):
|
||||
for index, tool_call in enumerate(message["tool_calls"]):
|
||||
function = tool_call["function"]
|
||||
|
||||
if prev_message.get("tool_calls"):
|
||||
prev_function = prev_message["tool_calls"][index]["function"]
|
||||
|
||||
function["name"] = function["name"].replace(
|
||||
prev_function["name"], ""
|
||||
)
|
||||
function["arguments"] = function["arguments"].replace(
|
||||
prev_function["arguments"], ""
|
||||
)
|
||||
|
||||
return resp_copy
|
||||
|
||||
async def astream_completion_with_retry(self, **kwargs: Any) -> Any:
|
||||
"""Because the dashscope SDK doesn't provide an async API,
|
||||
we wrap `stream_generate_with_retry` with an async generator."""
|
||||
@ -301,16 +411,16 @@ class ChatTongyi(BaseChatModel):
|
||||
) -> ChatResult:
|
||||
generations = []
|
||||
if self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
generation_chunk: Optional[ChatGenerationChunk] = None
|
||||
for chunk in self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
if generation_chunk is None:
|
||||
generation_chunk = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
generations.append(self._chunk_to_generation(generation))
|
||||
generation_chunk += chunk
|
||||
assert generation_chunk is not None
|
||||
generations.append(self._chunk_to_generation(generation_chunk))
|
||||
else:
|
||||
params: Dict[str, Any] = self._invocation_params(
|
||||
messages=messages, stop=stop, **kwargs
|
||||
@ -373,9 +483,19 @@ class ChatTongyi(BaseChatModel):
|
||||
params: Dict[str, Any] = self._invocation_params(
|
||||
messages=messages, stop=stop, stream=True, **kwargs
|
||||
)
|
||||
|
||||
for stream_resp, is_last_chunk in generate_with_last_element_mark(
|
||||
self.stream_completion_with_retry(**params)
|
||||
):
|
||||
choice = stream_resp["output"]["choices"][0]
|
||||
message = choice["message"]
|
||||
if (
|
||||
choice["finish_reason"] == "null"
|
||||
and message["content"] == ""
|
||||
and "tool_calls" not in message
|
||||
):
|
||||
continue
|
||||
|
||||
chunk = ChatGenerationChunk(
|
||||
**self._chat_generation_from_qwen_resp(
|
||||
stream_resp, is_chunk=True, is_last_chunk=is_last_chunk
|
||||
@ -413,14 +533,13 @@ class ChatTongyi(BaseChatModel):
|
||||
params = {**self._default_params, **kwargs}
|
||||
if stop is not None:
|
||||
params["stop"] = stop
|
||||
if params.get("stream"):
|
||||
# According to the Tongyi official docs,
|
||||
# `incremental_output` with `tools` is not supported yet
|
||||
if params.get("stream") and not params.get("tools"):
|
||||
params["incremental_output"] = True
|
||||
|
||||
message_dicts = [convert_message_to_dict(m) for m in messages]
|
||||
|
||||
# According to the docs, the last message should be a `user` message
|
||||
if message_dicts[-1]["role"] != "user":
|
||||
raise ValueError("Last message should be user message.")
|
||||
# And the `system` message should be the first message if present
|
||||
system_message_indices = [
|
||||
i for i, m in enumerate(message_dicts) if m["role"] == "system"
|
||||
@ -470,3 +589,22 @@ class ChatTongyi(BaseChatModel):
|
||||
message=convert_message_chunk_to_message(chunk.message),
|
||||
generation_info=chunk.generation_info,
|
||||
)
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
@ -55,17 +55,17 @@ def _create_retry_decorator(llm: Tongyi) -> Callable[[Any], Any]:
|
||||
|
||||
def check_response(resp: Any) -> Any:
|
||||
"""Check the response from the completion call."""
|
||||
if resp.status_code == 200:
|
||||
if resp["status_code"] == 200:
|
||||
return resp
|
||||
elif resp.status_code in [400, 401]:
|
||||
elif resp["status_code"] in [400, 401]:
|
||||
raise ValueError(
|
||||
f"status_code: {resp.status_code} \n "
|
||||
f"code: {resp.code} \n message: {resp.message}"
|
||||
f"status_code: {resp['status_code']} \n "
|
||||
f"code: {resp['code']} \n message: {resp['message']}"
|
||||
)
|
||||
else:
|
||||
raise HTTPError(
|
||||
f"HTTP error occurred: status_code: {resp.status_code} \n "
|
||||
f"code: {resp.code} \n message: {resp.message}",
|
||||
f"HTTP error occurred: status_code: {resp['status_code']} \n "
|
||||
f"code: {resp['code']} \n message: {resp['message']}",
|
||||
response=resp,
|
||||
)
|
||||
|
||||
|
@ -1,11 +1,14 @@
|
||||
"""Test Alibaba Tongyi Chat Model."""
|
||||
from typing import Any, cast
|
||||
|
||||
from typing import Any, List, cast
|
||||
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall, ToolMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr
|
||||
from pytest import CaptureFixture
|
||||
|
||||
from langchain_community.chat_models.tongyi import ChatTongyi
|
||||
@ -138,3 +141,76 @@ def test_multiple_messages() -> None:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
class GenerateUsername(BaseModel):
|
||||
"Get a username based on someone's name and hair color."
|
||||
|
||||
name: str
|
||||
hair_color: str
|
||||
|
||||
|
||||
def test_tool_use() -> None:
|
||||
llm = ChatTongyi(model="qwen-turbo", temperature=0)
|
||||
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
|
||||
msgs: List = [HumanMessage("Sally has green hair, what would her username be?")]
|
||||
ai_msg = llm_with_tool.invoke(msgs)
|
||||
# assert ai_msg is None
|
||||
# ai_msg.content = " "
|
||||
|
||||
assert isinstance(ai_msg, AIMessage)
|
||||
assert isinstance(ai_msg.tool_calls, list)
|
||||
assert len(ai_msg.tool_calls) == 1
|
||||
tool_call = ai_msg.tool_calls[0]
|
||||
assert "args" in tool_call
|
||||
|
||||
tool_msg = ToolMessage(
|
||||
"sally_green_hair",
|
||||
tool_call_id=ai_msg.tool_calls[0]["id"],
|
||||
name=ai_msg.tool_calls[0]["name"],
|
||||
)
|
||||
msgs.extend([ai_msg, tool_msg])
|
||||
llm_with_tool.invoke(msgs)
|
||||
|
||||
# Test streaming
|
||||
ai_messages = llm_with_tool.stream(msgs)
|
||||
first = True
|
||||
for message in ai_messages:
|
||||
if first:
|
||||
gathered = message
|
||||
first = False
|
||||
else:
|
||||
gathered = gathered + message # type: ignore
|
||||
assert isinstance(gathered, AIMessageChunk)
|
||||
|
||||
streaming_tool_msg = ToolMessage(
|
||||
"sally_green_hair",
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"] if tool_call["id"] else " ",
|
||||
)
|
||||
msgs.extend([gathered, streaming_tool_msg])
|
||||
llm_with_tool.invoke(msgs)
|
||||
|
||||
|
||||
def test_manual_tool_call_msg() -> None:
|
||||
"""Test passing in manually construct tool call message."""
|
||||
llm = ChatTongyi(model="qwen-turbo", temperature=0)
|
||||
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
|
||||
msgs: List = [
|
||||
HumanMessage("Sally has green hair, what would her username be?"),
|
||||
AIMessage(
|
||||
content=" ",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id="foo",
|
||||
)
|
||||
],
|
||||
),
|
||||
ToolMessage("sally_green_hair", tool_call_id="foo"),
|
||||
]
|
||||
output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs))
|
||||
assert output.content
|
||||
# Should not have called the tool again.
|
||||
assert not output.tool_calls and not output.invalid_tool_calls
|
||||
|
Loading…
Reference in New Issue
Block a user