community: Implement `bind_tools` for ChatTongyi (#20725)
## Description

Implement `bind_tools` in ChatTongyi. Usage example:

```py
from langchain_core.tools import tool
from langchain_community.chat_models.tongyi import ChatTongyi


@tool
def multiply(first_int: int, second_int: int) -> int:
    """Multiply two integers together."""
    return first_int * second_int


llm = ChatTongyi(model="qwen-turbo")
llm_with_tools = llm.bind_tools([multiply])
msg = llm_with_tools.invoke("What's 5 times forty two")
print(msg)
```

Streaming is also supported.

## Dependencies

No dependency is required for this change.

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Parent: b216a1dddb
Commit: 0ead09f84d
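The description above shows a blocking `invoke` call and notes that streaming is also supported. As a quick illustration (not part of the diff), here is a minimal streaming sketch with the same `multiply` tool; it assumes a configured DashScope API key, and the exact chunk contents will vary:

```python
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_core.tools import tool


@tool
def multiply(first_int: int, second_int: int) -> int:
    """Multiply two integers together."""
    return first_int * second_int


llm_with_tools = ChatTongyi(model="qwen-turbo").bind_tools([multiply])

# Accumulate streamed chunks: AIMessageChunk supports `+`, which concatenates
# content and merges tool-call fragments (tool_call_chunks).
gathered = None
for chunk in llm_with_tools.stream("What's 5 times forty two"):
    gathered = chunk if gathered is None else gathered + chunk

print(gathered.tool_calls)
```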
Tongyi chat integration notebook:

```diff
@@ -26,14 +26,22 @@
  },
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
-  "outputs": [],
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "Note: you may need to restart the kernel to use updated packages.\n"
+    ]
+   }
+  ],
   "source": [
    "# Install the package\n",
    "%pip install --upgrade --quiet dashscope"
@@ -48,15 +56,7 @@
     "outputs_hidden": false
    }
   },
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     " ········\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "# Get a new token: https://help.aliyun.com/document_detail/611472.html?spm=a2c4g.2399481.0.0\n",
    "from getpass import getpass\n",
@@ -94,8 +94,12 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "chat resp: content='Hello! How' additional_kwargs={} example=False\n",
-    "chat resp: content=' can I assist you today?' additional_kwargs={} example=False\n"
+    "chat resp: content='Hello' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
+    "chat resp: content='!' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
+    "chat resp: content=' How' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
+    "chat resp: content=' can I assist you today' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
+    "chat resp: content='?' id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n",
+    "chat resp: content='' response_metadata={'finish_reason': 'stop', 'request_id': '921db2c5-4d53-9a89-8e87-e4ad6a671237', 'token_usage': {'input_tokens': 20, 'output_tokens': 9, 'total_tokens': 29}} id='run-f2301962-6d46-423c-8afa-1e667bd11e2b'\n"
    ]
   }
  ],
@@ -116,10 +120,18 @@
  "execution_count": 5,
  "metadata": {},
  "outputs": [
+  {
+   "name": "stderr",
+   "output_type": "stream",
+   "text": [
+    "/Users/cheese/PARA/Projects/langchain-contribution/langchain/libs/core/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The method `BaseChatModel.__call__` was deprecated in langchain-core 0.1.7 and will be removed in 0.2.0. Use invoke instead.\n",
+    "  warn_deprecated(\n"
+   ]
+  },
   {
    "data": {
     "text/plain": [
-     "AIMessageChunk(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
+     "AIMessage(content=\"J'adore programmer.\", response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'stop', 'request_id': 'ae725086-0ffa-9728-8c72-b204c7bc7eeb', 'token_usage': {'input_tokens': 36, 'output_tokens': 6, 'total_tokens': 42}}, id='run-060cc103-ef5f-4c8a-af40-792ac7f40c26-0')"
    ]
   },
   "execution_count": 5,
@@ -149,18 +161,65 @@
   "ChatTongyi supports tool calling API that lets you describe tools and their arguments, and have the model return a JSON object with a tool to invoke and the inputs to that tool."
  ]
 },
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Use with `bind_tools`"
+ ]
+},
 {
  "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "content='' additional_kwargs={'tool_calls': [{'function': {'name': 'multiply', 'arguments': '{\"first_int\": 5, \"second_int\": 42}'}, 'id': '', 'type': 'function'}]} response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': '4acf0e36-44af-987a-a0c0-8b5c5eaa1a8b', 'token_usage': {'input_tokens': 200, 'output_tokens': 25, 'total_tokens': 225}} id='run-0ecd0f09-1d20-4e55-a4f3-f14d1f710ae7-0' tool_calls=[{'name': 'multiply', 'args': {'first_int': 5, 'second_int': 42}, 'id': ''}]\n"
+   ]
+  }
+ ],
+ "source": [
+  "from langchain_community.chat_models.tongyi import ChatTongyi\n",
+  "from langchain_core.tools import tool\n",
+  "\n",
+  "\n",
+  "@tool\n",
+  "def multiply(first_int: int, second_int: int) -> int:\n",
+  "    \"\"\"Multiply two integers together.\"\"\"\n",
+  "    return first_int * second_int\n",
+  "\n",
+  "\n",
+  "llm = ChatTongyi(model=\"qwen-turbo\")\n",
+  "\n",
+  "llm_with_tools = llm.bind_tools([multiply])\n",
+  "\n",
+  "msg = llm_with_tools.invoke(\"What's 5 times forty two\")\n",
+  "\n",
+  "print(msg)"
+ ]
+},
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Construct args manually"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": 7,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
-     "AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'name': 'get_current_weather', 'arguments': '{\"location\": \"San Francisco\"}'}, 'id': '', 'type': 'function'}]}, response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': 'dae79197-8780-9b7e-8c15-6a83e2a53534', 'token_usage': {'input_tokens': 229, 'output_tokens': 19, 'total_tokens': 248}}, id='run-9e06f837-582b-473b-bb1f-5e99a68ecc10-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco'}, 'id': ''}])"
+     "AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'name': 'get_current_weather', 'arguments': '{\"location\": \"San Francisco\"}'}, 'id': '', 'type': 'function'}]}, response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': '87ef33d2-5c6b-9457-91e2-39faad7120eb', 'token_usage': {'input_tokens': 229, 'output_tokens': 19, 'total_tokens': 248}}, id='run-7939ba7f-e3f7-46f8-980b-30499b52723c-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco'}, 'id': ''}])"
    ]
   },
-  "execution_count": 5,
+  "execution_count": 7,
   "metadata": {},
   "output_type": "execute_result"
  }
```
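The new notebook cell above only prints the returned `AIMessage`. As an aside (not part of the diff), a natural follow-up is to execute the requested tool and feed the result back as a `ToolMessage`; this sketch assumes the `multiply` tool and `llm_with_tools` from that cell and a valid DashScope API key:

```python
from langchain_core.messages import HumanMessage, ToolMessage

msgs = [HumanMessage("What's 5 times forty two")]
ai_msg = llm_with_tools.invoke(msgs)

# Each entry in `tool_calls` carries the tool name and already-parsed args.
tool_call = ai_msg.tool_calls[0]
result = multiply.invoke(tool_call["args"])

# Return the tool result; Tongyi may send an empty id, so fall back to a placeholder.
msgs.extend(
    [
        ai_msg,
        ToolMessage(
            str(result),
            name=tool_call["name"],
            tool_call_id=tool_call["id"] or " ",
        ),
    ]
)
print(llm_with_tools.invoke(msgs).content)
```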
```diff
@@ -224,7 +283,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.12"
+  "version": "3.12.2"
  }
 },
 "nbformat": 4,
```
`langchain_community/chat_models/tongyi.py`:

```diff
@@ -2,6 +2,7 @@ from __future__ import annotations

 import asyncio
 import functools
+import json
 import logging
 from typing import (
     Any,
@@ -12,6 +13,8 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
+    Type,
     Union,
     cast,
 )
@@ -20,6 +23,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
@@ -32,6 +36,8 @@ from langchain_core.messages import (
     HumanMessageChunk,
     SystemMessage,
     SystemMessageChunk,
+    ToolMessage,
+    ToolMessageChunk,
 )
 from langchain_core.output_parsers.openai_tools import (
     make_invalid_tool_call,
@@ -42,8 +48,11 @@ from langchain_core.outputs import (
     ChatGenerationChunk,
     ChatResult,
 )
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.utils.function_calling import convert_to_openai_tool
 from requests.exceptions import HTTPError
 from tenacity import (
     before_sleep_log,
@@ -68,6 +77,7 @@ def convert_dict_to_message(
     """Convert a dict to a message."""
     role = _dict["role"]
     content = _dict["content"]
+
     if role == "user":
         return (
             HumanMessageChunk(content=content)
@@ -79,17 +89,39 @@ def convert_dict_to_message(
         invalid_tool_calls = []
         if "tool_calls" in _dict:
             additional_kwargs = {"tool_calls": _dict["tool_calls"]}
-            for raw_tool_call in _dict["tool_calls"]:
-                try:
-                    tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
-                except Exception as e:
-                    invalid_tool_calls.append(
-                        make_invalid_tool_call(raw_tool_call, str(e))
-                    )
+
+            for index, value in enumerate(_dict["tool_calls"]):
+                if is_chunk:
+                    try:
+                        tool_calls.append(
+                            {
+                                "name": value["function"].get("name"),
+                                "args": value["function"].get("arguments"),
+                                "id": value.get("id"),
+                                # Tongyi does not respond with index,
+                                # use index in the list instead
+                                "index": index,
+                            }
+                        )
+                    except KeyError:
+                        pass
+                else:
+                    try:
+                        parsed_tool = parse_tool_call(value, return_id=True)
+                        if parsed_tool:
+                            tool_calls.append(parsed_tool)
+                    except Exception as e:
+                        invalid_tool_calls.append(make_invalid_tool_call(value, str(e)))
         else:
             additional_kwargs = {}
+
         return (
-            AIMessageChunk(content=content)
+            AIMessageChunk(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                tool_call_chunks=tool_calls,
+                id=_dict.get("id"),
+            )
             if is_chunk
             else AIMessage(
                 content=content,
```
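For context on the new chunk branch above, here is a standalone sketch of the dictionary mapping it performs while streaming. The sample payload mirrors Tongyi's OpenAI-style `tool_calls` field; the values are invented for illustration:

```python
# A streamed fragment of a tool call as Tongyi returns it (illustrative values).
raw_tool_calls = [
    {
        "id": "",
        "type": "function",
        "function": {"name": "multiply", "arguments": '{"first_int": 5'},
    }
]

tool_call_chunks = [
    {
        "name": call["function"].get("name"),
        "args": call["function"].get("arguments"),  # still a partial JSON string
        "id": call.get("id"),
        "index": index,  # Tongyi omits `index`, so the list position is used
    }
    for index, call in enumerate(raw_tool_calls)
]
print(tool_call_chunks)
```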
```diff
@@ -104,6 +136,23 @@ def convert_dict_to_message(
             if is_chunk
             else SystemMessage(content=content)
         )
+    elif role == "tool":
+        additional_kwargs = {}
+        if "name" in _dict:
+            additional_kwargs["name"] = _dict["name"]
+        return (
+            ToolMessageChunk(
+                content=_dict.get("content", ""),
+                tool_call_id=_dict.get("tool_call_id"),
+                additional_kwargs=additional_kwargs,
+            )
+            if is_chunk
+            else ToolMessage(
+                content=_dict.get("content", ""),
+                tool_call_id=_dict.get("tool_call_id"),
+                additional_kwargs=additional_kwargs,
+            )
+        )
     else:
         return (
             ChatMessageChunk(role=role, content=content)
@@ -113,17 +162,23 @@ def convert_dict_to_message(


 def convert_message_chunk_to_message(message_chunk: BaseMessageChunk) -> BaseMessage:
-    """Convert a message chunk to a message."""
-    if isinstance(message_chunk, HumanMessageChunk):
-        return HumanMessage(content=message_chunk.content)
-    elif isinstance(message_chunk, AIMessageChunk):
-        return AIMessage(content=message_chunk.content)
-    elif isinstance(message_chunk, SystemMessageChunk):
-        return SystemMessage(content=message_chunk.content)
-    elif isinstance(message_chunk, ChatMessageChunk):
-        return ChatMessage(role=message_chunk.role, content=message_chunk.content)
-    else:
-        raise TypeError(f"Got unknown type {message_chunk}")
+    """Convert a message chunk to a message.
+
+    Args:
+        chunk: Message chunk to convert.
+
+    Returns:
+        Message.
+    """
+    if not isinstance(message_chunk, BaseMessageChunk):
+        return message_chunk
+    # chunk classes always have the equivalent non-chunk class as their first parent
+    ignore_keys = ["type"]
+    if isinstance(message_chunk, AIMessageChunk):
+        ignore_keys.append("tool_call_chunks")
+    return message_chunk.__class__.__mro__[1](
+        **{k: v for k, v in message_chunk.__dict__.items() if k not in ignore_keys}
+    )


 def convert_message_to_dict(message: BaseMessage) -> dict:
@@ -136,8 +191,17 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict = {"role": "user", "content": message.content}
     elif isinstance(message, AIMessage):
         message_dict = {"role": "assistant", "content": message.content}
+        if "tool_calls" in message.additional_kwargs:
+            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
     elif isinstance(message, SystemMessage):
         message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, ToolMessage):
+        message_dict = {
+            "role": "tool",
+            "tool_call_id": message.tool_call_id,
+            "content": message.content,
+            "name": message.name,
+        }
     else:
         raise TypeError(f"Got unknown type {message}")
     return message_dict
```
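As a quick illustration of the new `ToolMessage` branch (not from the diff), the conversion produces a plain `role: tool` dict that can be sent back to the Tongyi API. The sketch below assumes the module-level helper is importable and uses made-up values:

```python
from langchain_core.messages import ToolMessage

from langchain_community.chat_models.tongyi import convert_message_to_dict

msg = ToolMessage("210", tool_call_id="call_0", name="multiply")
print(convert_message_to_dict(msg))
# Expected shape per the new branch:
# {'role': 'tool', 'tool_call_id': 'call_0', 'content': '210', 'name': 'multiply'}
```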
```diff
@@ -256,11 +320,57 @@ class ChatTongyi(BaseChatModel):
         @retry_decorator
         def _stream_completion_with_retry(**_kwargs: Any) -> Any:
             responses = self.client.call(**_kwargs)
+            prev_resp = None
+
             for resp in responses:
-                yield check_response(resp)
+                # If we are streaming without `incremental_output = True`,
+                # we need to calculate the delta response manually
+                if _kwargs.get("stream") and not _kwargs.get(
+                    "incremental_output", False
+                ):
+                    if prev_resp is None:
+                        delta_resp = resp
+                    else:
+                        delta_resp = self.subtract_client_response(resp, prev_resp)
+                    prev_resp = resp
+                    yield check_response(delta_resp)
+                else:
+                    yield check_response(resp)

         return _stream_completion_with_retry(**kwargs)

+    def subtract_client_response(self, resp: Any, prev_resp: Any) -> Any:
+        """Subtract prev response from curr response.
+
+        Useful when streaming without `incremental_output = True`
+        """
+
+        resp_copy = json.loads(json.dumps(resp))
+        choice = resp_copy["output"]["choices"][0]
+        message = choice["message"]
+
+        prev_resp_copy = json.loads(json.dumps(prev_resp))
+        prev_choice = prev_resp_copy["output"]["choices"][0]
+        prev_message = prev_choice["message"]
+
+        message["content"] = message["content"].replace(prev_message["content"], "")
+
+        if message.get("tool_calls"):
+            for index, tool_call in enumerate(message["tool_calls"]):
+                function = tool_call["function"]
+
+                if prev_message.get("tool_calls"):
+                    prev_function = prev_message["tool_calls"][index]["function"]
+
+                    function["name"] = function["name"].replace(
+                        prev_function["name"], ""
+                    )
+                    function["arguments"] = function["arguments"].replace(
+                        prev_function["arguments"], ""
+                    )
+
+        return resp_copy
+
     async def astream_completion_with_retry(self, **kwargs: Any) -> Any:
         """Because the dashscope SDK doesn't provide an async API,
         we wrap `stream_generate_with_retry` with an async generator."""
```
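Because `incremental_output` cannot be combined with `tools`, the stream returns cumulative snapshots, and the new `subtract_client_response` helper diffs consecutive snapshots. A self-contained sketch of that subtraction on made-up cumulative payloads (the content part only; the real helper also diffs tool-call name and arguments strings):

```python
import json


def subtract(resp, prev_resp):
    """Mimic the helper: strip the previous cumulative content from the current one."""
    resp = json.loads(json.dumps(resp))  # cheap deep copy, as the helper does
    msg = resp["output"]["choices"][0]["message"]
    prev = prev_resp["output"]["choices"][0]["message"]
    msg["content"] = msg["content"].replace(prev["content"], "")
    return resp


prev = {"output": {"choices": [{"message": {"content": "Hello"}}]}}
curr = {"output": {"choices": [{"message": {"content": "Hello! How can I help?"}}]}}
print(subtract(curr, prev)["output"]["choices"][0]["message"]["content"])
# -> "! How can I help?"
```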
```diff
@@ -301,16 +411,16 @@ class ChatTongyi(BaseChatModel):
     ) -> ChatResult:
         generations = []
         if self.streaming:
-            generation: Optional[ChatGenerationChunk] = None
+            generation_chunk: Optional[ChatGenerationChunk] = None
             for chunk in self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             ):
-                if generation is None:
-                    generation = chunk
+                if generation_chunk is None:
+                    generation_chunk = chunk
                 else:
-                    generation += chunk
-            assert generation is not None
-            generations.append(self._chunk_to_generation(generation))
+                    generation_chunk += chunk
+            assert generation_chunk is not None
+            generations.append(self._chunk_to_generation(generation_chunk))
         else:
             params: Dict[str, Any] = self._invocation_params(
                 messages=messages, stop=stop, **kwargs
@@ -373,9 +483,19 @@ class ChatTongyi(BaseChatModel):
         params: Dict[str, Any] = self._invocation_params(
             messages=messages, stop=stop, stream=True, **kwargs
         )
+
         for stream_resp, is_last_chunk in generate_with_last_element_mark(
             self.stream_completion_with_retry(**params)
         ):
+            choice = stream_resp["output"]["choices"][0]
+            message = choice["message"]
+            if (
+                choice["finish_reason"] == "null"
+                and message["content"] == ""
+                and "tool_calls" not in message
+            ):
+                continue
+
             chunk = ChatGenerationChunk(
                 **self._chat_generation_from_qwen_resp(
                     stream_resp, is_chunk=True, is_last_chunk=is_last_chunk
@@ -413,14 +533,13 @@ class ChatTongyi(BaseChatModel):
         params = {**self._default_params, **kwargs}
         if stop is not None:
             params["stop"] = stop
-        if params.get("stream"):
+        # According to the Tongyi official docs,
+        # `incremental_output` with `tools` is not supported yet
+        if params.get("stream") and not params.get("tools"):
             params["incremental_output"] = True

         message_dicts = [convert_message_to_dict(m) for m in messages]

-        # According to the docs, the last message should be a `user` message
-        if message_dicts[-1]["role"] != "user":
-            raise ValueError("Last message should be user message.")
         # And the `system` message should be the first message if present
         system_message_indices = [
             i for i, m in enumerate(message_dicts) if m["role"] == "system"
```
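The comment added in the hunk above explains why incremental output is skipped when tools are bound. A standalone sketch (not library code) of that parameter selection:

```python
# Incremental output is only requested when no tools are bound, since Tongyi
# rejects the `incremental_output` + `tools` combination.
def streaming_params(params: dict) -> dict:
    if params.get("stream") and not params.get("tools"):
        params["incremental_output"] = True
    return params


print(streaming_params({"stream": True}))
# {'stream': True, 'incremental_output': True}
print(streaming_params({"stream": True, "tools": [{"type": "function"}]}))
# unchanged: the model streams cumulative snapshots instead
```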
```diff
@@ -470,3 +589,22 @@ class ChatTongyi(BaseChatModel):
             message=convert_message_chunk_to_message(chunk.message),
             generation_info=chunk.generation_info,
         )
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)
```
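To make the new `bind_tools` concrete: `convert_to_openai_tool` turns a pydantic class, `@tool` function, or dict into an OpenAI-style tool schema, and the result is simply passed through `bind(tools=...)`. A small sketch of the formatting step; the `GetCurrentWeather` class is illustrative:

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_tool


class GetCurrentWeather(BaseModel):
    """Get the current weather in a given location."""

    location: str = Field(description="City name, e.g. San Francisco")


print(convert_to_openai_tool(GetCurrentWeather))
# -> {'type': 'function',
#     'function': {'name': 'GetCurrentWeather',
#                  'description': 'Get the current weather in a given location.',
#                  'parameters': {...JSON schema for `location`...}}}
```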
Tongyi `check_response` helper:

```diff
@@ -55,17 +55,17 @@ def _create_retry_decorator(llm: Tongyi) -> Callable[[Any], Any]:

 def check_response(resp: Any) -> Any:
     """Check the response from the completion call."""
-    if resp.status_code == 200:
+    if resp["status_code"] == 200:
         return resp
-    elif resp.status_code in [400, 401]:
+    elif resp["status_code"] in [400, 401]:
         raise ValueError(
-            f"status_code: {resp.status_code} \n "
-            f"code: {resp.code} \n message: {resp.message}"
+            f"status_code: {resp['status_code']} \n "
+            f"code: {resp['code']} \n message: {resp['message']}"
         )
     else:
         raise HTTPError(
-            f"HTTP error occurred: status_code: {resp.status_code} \n "
-            f"code: {resp.code} \n message: {resp.message}",
+            f"HTTP error occurred: status_code: {resp['status_code']} \n "
+            f"code: {resp['code']} \n message: {resp['message']}",
             response=resp,
         )

```
ChatTongyi integration tests:

```diff
@@ -1,11 +1,14 @@
 """Test Alibaba Tongyi Chat Model."""
-from typing import Any, cast
+
+from typing import Any, List, cast

 from langchain_core.callbacks import CallbackManager
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.messages.ai import AIMessageChunk
+from langchain_core.messages.tool import ToolCall, ToolMessage
 from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
-from langchain_core.pydantic_v1 import SecretStr
+from langchain_core.pydantic_v1 import BaseModel, SecretStr
 from pytest import CaptureFixture

 from langchain_community.chat_models.tongyi import ChatTongyi
@@ -138,3 +141,76 @@ def test_multiple_messages() -> None:
         assert isinstance(generation, ChatGeneration)
         assert isinstance(generation.text, str)
         assert generation.text == generation.message.content
+
+
+class GenerateUsername(BaseModel):
+    "Get a username based on someone's name and hair color."
+
+    name: str
+    hair_color: str
+
+
+def test_tool_use() -> None:
+    llm = ChatTongyi(model="qwen-turbo", temperature=0)
+    llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
+    msgs: List = [HumanMessage("Sally has green hair, what would her username be?")]
+    ai_msg = llm_with_tool.invoke(msgs)
+    # assert ai_msg is None
+    # ai_msg.content = " "
+
+    assert isinstance(ai_msg, AIMessage)
+    assert isinstance(ai_msg.tool_calls, list)
+    assert len(ai_msg.tool_calls) == 1
+    tool_call = ai_msg.tool_calls[0]
+    assert "args" in tool_call
+
+    tool_msg = ToolMessage(
+        "sally_green_hair",
+        tool_call_id=ai_msg.tool_calls[0]["id"],
+        name=ai_msg.tool_calls[0]["name"],
+    )
+    msgs.extend([ai_msg, tool_msg])
+    llm_with_tool.invoke(msgs)
+
+    # Test streaming
+    ai_messages = llm_with_tool.stream(msgs)
+    first = True
+    for message in ai_messages:
+        if first:
+            gathered = message
+            first = False
+        else:
+            gathered = gathered + message  # type: ignore
+    assert isinstance(gathered, AIMessageChunk)
+
+    streaming_tool_msg = ToolMessage(
+        "sally_green_hair",
+        name=tool_call["name"],
+        tool_call_id=tool_call["id"] if tool_call["id"] else " ",
+    )
+    msgs.extend([gathered, streaming_tool_msg])
+    llm_with_tool.invoke(msgs)
+
+
+def test_manual_tool_call_msg() -> None:
+    """Test passing in manually construct tool call message."""
+    llm = ChatTongyi(model="qwen-turbo", temperature=0)
+    llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
+    msgs: List = [
+        HumanMessage("Sally has green hair, what would her username be?"),
+        AIMessage(
+            content=" ",
+            tool_calls=[
+                ToolCall(
+                    name="GenerateUsername",
+                    args={"name": "Sally", "hair_color": "green"},
+                    id="foo",
+                )
+            ],
+        ),
+        ToolMessage("sally_green_hair", tool_call_id="foo"),
+    ]
+    output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs))
+    assert output.content
+    # Should not have called the tool again.
+    assert not output.tool_calls and not output.invalid_tool_calls
```