mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
feat: support ChatModels Qianfan QianfanChatEndpoint
function_call (#11107)
- **Description:** * feature for `QianfanChatEndpoint` function_call ability, add integration_test for it * add `model`, `endpoint` supported in calling params * add raw response in ChatModel Message - **Issue:** * #10867 * #11105 * #10215 - **Dependencies:** no - **Tag maintainer:** @baskaryan - **Twitter handle:** no
This commit is contained in:
parent
67300567d3
commit
b647505280
@ -22,7 +22,6 @@ from langchain.schema.messages import (
|
|||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
BaseMessageChunk,
|
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
FunctionMessage,
|
FunctionMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
@ -34,13 +33,6 @@ from langchain.utils import get_from_dict_or_env
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
|
|
||||||
return AIMessageChunk(
|
|
||||||
content=resp["result"],
|
|
||||||
role="assistant",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
"""Convert a message to a dictionary that can be passed to the API."""
|
"""Convert a message to a dictionary that can be passed to the API."""
|
||||||
message_dict: Dict[str, Any]
|
message_dict: Dict[str, Any]
|
||||||
@ -51,7 +43,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
elif isinstance(message, AIMessage):
|
elif isinstance(message, AIMessage):
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
if "function_call" in message.additional_kwargs:
|
if "function_call" in message.additional_kwargs:
|
||||||
message_dict["functions"] = message.additional_kwargs["function_call"]
|
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||||
# If function call only, content is None not empty string
|
# If function call only, content is None not empty string
|
||||||
if message_dict["content"] == "":
|
if message_dict["content"] == "":
|
||||||
message_dict["content"] = None
|
message_dict["content"] = None
|
||||||
@ -67,6 +59,21 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
|
||||||
|
content = _dict.get("result", "") or ""
|
||||||
|
if _dict.get("function_call"):
|
||||||
|
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||||
|
if "thoughts" in additional_kwargs["function_call"]:
|
||||||
|
# align to api sample, which affects the llm function_call output
|
||||||
|
additional_kwargs["function_call"].pop("thoughts")
|
||||||
|
else:
|
||||||
|
additional_kwargs = {}
|
||||||
|
return AIMessage(
|
||||||
|
content=content,
|
||||||
|
additional_kwargs={**_dict.get("body", {}), **additional_kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QianfanChatEndpoint(BaseChatModel):
|
class QianfanChatEndpoint(BaseChatModel):
|
||||||
"""Baidu Qianfan chat models.
|
"""Baidu Qianfan chat models.
|
||||||
|
|
||||||
@ -164,6 +171,8 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
def _default_params(self) -> Dict[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
"""Get the default parameters for calling OpenAI API."""
|
"""Get the default parameters for calling OpenAI API."""
|
||||||
normal_params = {
|
normal_params = {
|
||||||
|
"model": self.model,
|
||||||
|
"endpoint": self.endpoint,
|
||||||
"stream": self.streaming,
|
"stream": self.streaming,
|
||||||
"request_timeout": self.request_timeout,
|
"request_timeout": self.request_timeout,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
@ -243,10 +252,13 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
)
|
)
|
||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
response_payload = self.client.do(**params)
|
response_payload = self.client.do(**params)
|
||||||
lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
|
lc_msg = _convert_dict_to_message(response_payload)
|
||||||
gen = ChatGeneration(
|
gen = ChatGeneration(
|
||||||
message=lc_msg,
|
message=lc_msg,
|
||||||
generation_info=dict(finish_reason="stop"),
|
generation_info={
|
||||||
|
"finish_reason": "stop",
|
||||||
|
**response_payload.get("body", {}),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
token_usage = response_payload.get("usage", {})
|
token_usage = response_payload.get("usage", {})
|
||||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||||
@ -276,11 +288,14 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
)
|
)
|
||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
response_payload = await self.client.ado(**params)
|
response_payload = await self.client.ado(**params)
|
||||||
lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
|
lc_msg = _convert_dict_to_message(response_payload)
|
||||||
generations = []
|
generations = []
|
||||||
gen = ChatGeneration(
|
gen = ChatGeneration(
|
||||||
message=lc_msg,
|
message=lc_msg,
|
||||||
generation_info=dict(finish_reason="stop"),
|
generation_info={
|
||||||
|
"finish_reason": "stop",
|
||||||
|
**response_payload.get("body", {}),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
generations.append(gen)
|
generations.append(gen)
|
||||||
token_usage = response_payload.get("usage", {})
|
token_usage = response_payload.get("usage", {})
|
||||||
@ -297,9 +312,14 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
for res in self.client.do(**params):
|
for res in self.client.do(**params):
|
||||||
if res:
|
if res:
|
||||||
|
msg = _convert_dict_to_message(res)
|
||||||
chunk = ChatGenerationChunk(
|
chunk = ChatGenerationChunk(
|
||||||
text=res["result"],
|
text=res["result"],
|
||||||
message=_convert_resp_to_message_chunk(res),
|
message=AIMessageChunk(
|
||||||
|
content=msg.content,
|
||||||
|
role="assistant",
|
||||||
|
additional_kwargs=msg.additional_kwargs,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
if run_manager:
|
if run_manager:
|
||||||
@ -315,9 +335,14 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
async for res in await self.client.ado(**params):
|
async for res in await self.client.ado(**params):
|
||||||
if res:
|
if res:
|
||||||
|
msg = _convert_dict_to_message(res)
|
||||||
chunk = ChatGenerationChunk(
|
chunk = ChatGenerationChunk(
|
||||||
text=res["result"],
|
text=res["result"],
|
||||||
message=_convert_resp_to_message_chunk(res),
|
message=AIMessageChunk(
|
||||||
|
content=msg.content,
|
||||||
|
role="assistant",
|
||||||
|
additional_kwargs=msg.additional_kwargs,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
if run_manager:
|
if run_manager:
|
||||||
|
@ -118,6 +118,8 @@ class QianfanLLMEndpoint(LLM):
|
|||||||
def _default_params(self) -> Dict[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
"""Get the default parameters for calling OpenAI API."""
|
"""Get the default parameters for calling OpenAI API."""
|
||||||
normal_params = {
|
normal_params = {
|
||||||
|
"model": self.model,
|
||||||
|
"endpoint": self.endpoint,
|
||||||
"stream": self.streaming,
|
"stream": self.streaming,
|
||||||
"request_timeout": self.request_timeout,
|
"request_timeout": self.request_timeout,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
|
@ -1,16 +1,87 @@
|
|||||||
"""Test Baidu Qianfan Chat Endpoint."""
|
"""Test Baidu Qianfan Chat Endpoint."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
from langchain.chains.openai_functions import (
|
||||||
|
create_openai_fn_chain,
|
||||||
|
)
|
||||||
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
||||||
|
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
|
FunctionMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
LLMResult,
|
LLMResult,
|
||||||
)
|
)
|
||||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
|
_FUNCTIONS: Any = [
|
||||||
|
{
|
||||||
|
"name": "format_person_info",
|
||||||
|
"description": (
|
||||||
|
"Output formatter. Should always be used to format your response to the"
|
||||||
|
" user."
|
||||||
|
),
|
||||||
|
"parameters": {
|
||||||
|
"title": "Person",
|
||||||
|
"description": "Identifying information about a person.",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"title": "Name",
|
||||||
|
"description": "The person's name",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"age": {
|
||||||
|
"title": "Age",
|
||||||
|
"description": "The person's age",
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
"fav_food": {
|
||||||
|
"title": "Fav Food",
|
||||||
|
"description": "The person's favorite food",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["name", "age"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "get_current_temperature",
|
||||||
|
"description": ("Used to get the location's temperature."),
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "city name",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["centigrade", "Fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location", "unit"],
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"temperature": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "city temperature",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["centigrade", "Fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_default_call() -> None:
|
def test_default_call() -> None:
|
||||||
"""Test default model(`ERNIE-Bot`) call."""
|
"""Test default model(`ERNIE-Bot`) call."""
|
||||||
@ -28,6 +99,14 @@ def test_model() -> None:
|
|||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_param() -> None:
|
||||||
|
"""Test model params works."""
|
||||||
|
chat = QianfanChatEndpoint()
|
||||||
|
response = chat(model="BLOOMZ-7B", messages=[HumanMessage(content="Hello")])
|
||||||
|
assert isinstance(response, BaseMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
def test_endpoint() -> None:
|
def test_endpoint() -> None:
|
||||||
"""Test user custom model deployments like some open source models."""
|
"""Test user custom model deployments like some open source models."""
|
||||||
chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed")
|
chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed")
|
||||||
@ -36,6 +115,18 @@ def test_endpoint() -> None:
|
|||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_endpoint_param() -> None:
|
||||||
|
"""Test user custom model deployments like some open source models."""
|
||||||
|
chat = QianfanChatEndpoint()
|
||||||
|
response = chat(
|
||||||
|
messages=[
|
||||||
|
HumanMessage(endpoint="qianfan_bloomz_7b_compressed", content="Hello")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert isinstance(response, BaseMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_history() -> None:
|
def test_multiple_history() -> None:
|
||||||
"""Tests multiple history works."""
|
"""Tests multiple history works."""
|
||||||
chat = QianfanChatEndpoint()
|
chat = QianfanChatEndpoint()
|
||||||
@ -83,3 +174,60 @@ def test_multiple_messages() -> None:
|
|||||||
assert isinstance(generation, ChatGeneration)
|
assert isinstance(generation, ChatGeneration)
|
||||||
assert isinstance(generation.text, str)
|
assert isinstance(generation.text, str)
|
||||||
assert generation.text == generation.message.content
|
assert generation.text == generation.message.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_functions_call_thoughts() -> None:
|
||||||
|
chat = QianfanChatEndpoint(model="ERNIE-Bot")
|
||||||
|
|
||||||
|
prompt_tmpl = "Use the given functions to answer following question: {input}"
|
||||||
|
prompt_msgs = [
|
||||||
|
HumanMessagePromptTemplate.from_template(prompt_tmpl),
|
||||||
|
]
|
||||||
|
prompt = ChatPromptTemplate(messages=prompt_msgs)
|
||||||
|
|
||||||
|
chain = create_openai_fn_chain(
|
||||||
|
_FUNCTIONS,
|
||||||
|
chat,
|
||||||
|
prompt,
|
||||||
|
output_parser=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = HumanMessage(content="What's the temperature in Shanghai today?")
|
||||||
|
response = chain.generate([{"input": message}])
|
||||||
|
assert isinstance(response.generations[0][0], ChatGeneration)
|
||||||
|
assert isinstance(response.generations[0][0].message, AIMessage)
|
||||||
|
assert "function_call" in response.generations[0][0].message.additional_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_functions_call() -> None:
|
||||||
|
chat = QianfanChatEndpoint(model="ERNIE-Bot")
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate(
|
||||||
|
messages=[
|
||||||
|
HumanMessage(content="What's the temperature in Shanghai today?"),
|
||||||
|
AIMessage(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {
|
||||||
|
"name": "get_current_temperature",
|
||||||
|
"thoughts": "i will use get_current_temperature "
|
||||||
|
"to resolve the questions",
|
||||||
|
"arguments": '{"location":"Shanghai","unit":"centigrade"}',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
FunctionMessage(
|
||||||
|
name="get_current_weather",
|
||||||
|
content='{"temperature": "25", \
|
||||||
|
"unit": "摄氏度", "description": "晴朗"}',
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
llm_chain = create_openai_fn_chain(
|
||||||
|
_FUNCTIONS,
|
||||||
|
chat,
|
||||||
|
prompt,
|
||||||
|
output_parser=None,
|
||||||
|
)
|
||||||
|
resp = llm_chain.generate([{}])
|
||||||
|
assert isinstance(resp, LLMResult)
|
||||||
|
Loading…
Reference in New Issue
Block a user