feat: Add ChatTongyi structured output (#24187)

- **Description:** Add `with_structured_output` method to ChatTongyi to
support structured output.
This commit is contained in:
maang-h 2024-07-16 03:57:21 +08:00 committed by GitHub
parent 8f4620f4b8
commit 6c7d9f93b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 130 additions and 1 deletions

View File

@ -4,6 +4,7 @@ import asyncio
import functools
import json
import logging
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
@ -40,7 +41,10 @@ from langchain_core.messages import (
ToolMessage,
ToolMessageChunk,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
@ -50,7 +54,7 @@ from langchain_core.outputs import (
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import Runnable
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
@ -372,6 +376,33 @@ class ChatTongyi(BaseChatModel):
}
]
Structured output:
.. code-block:: python
from typing import Optional
from langchain_core.pydantic_v1 import BaseModel, Field
class Joke(BaseModel):
'''Joke to tell user.'''
setup: str = Field(description="The setup of the joke")
punchline: str = Field(description="The punchline to the joke")
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
structured_chat = tongyi_chat.with_structured_output(Joke)
structured_chat.invoke("Tell me a joke about cats")
.. code-block:: python
Joke(
setup='Why did the cat join the band?',
punchline='Because it wanted to be a solo purr-sonality!',
rating=None
)
Response metadata
.. code-block:: python
@ -791,3 +822,70 @@ class ChatTongyi(BaseChatModel):
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
    self,
    schema: Union[Dict, Type[BaseModel]],
    *,
    include_raw: bool = False,
    **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
    """Wrap the model so its output is parsed into the given schema.

    The schema is bound to the model as a single tool via ``bind_tools`` and
    the first tool call of the response is parsed into the structured result.

    Args:
        schema: The output schema, either a Pydantic class or a dict.
            With a Pydantic class the result is parsed by
            ``PydanticToolsParser`` into an instance of that class. With a
            dict the result is parsed by ``JsonOutputKeyToolsParser`` and is
            a dict; the dict is handed to ``convert_to_openai_tool`` to
            obtain the function name, so it must be in a form that helper
            accepts (presumably the OpenAI function-calling spec).
        include_raw: If False, only the parsed output is returned and any
            parsing error propagates to the caller. If True, the result is
            a dict with keys "raw", "parsed" and "parsing_error"; a parsing
            failure is captured under "parsing_error" instead of raised,
            and "parsed" is then None.
        **kwargs: Not supported; present only so that unexpected arguments
            are rejected explicitly.

    Returns:
        A Runnable taking any chat-model input. With ``include_raw=False``
        it yields the parsed object (Pydantic instance or dict, depending on
        ``schema``). With ``include_raw=True`` it yields a dict holding the
        raw model message, the parsed object (or None) and the parsing
        error (or None).

    Raises:
        ValueError: If any extra keyword arguments are supplied.
    """
    if kwargs:
        raise ValueError(f"Received unsupported arguments {kwargs}")

    # Expose the schema to the model as its one and only tool.
    tool_bound_llm = self.bind_tools([schema])

    parser: OutputParserLike
    if isinstance(schema, type) and issubclass(schema, BaseModel):
        # Pydantic class: parse into an instance of the class.
        parser = PydanticToolsParser(
            tools=[schema],  # type: ignore[list-item]
            first_tool_only=True,
        )
    else:
        # Dict schema: pick the arguments of the tool call named after it.
        tool_name = convert_to_openai_tool(schema)["function"]["name"]
        parser = JsonOutputKeyToolsParser(
            key_name=tool_name, first_tool_only=True
        )

    if not include_raw:
        return tool_bound_llm | parser

    # Happy path: attach the parsed value and an explicit "no error" marker.
    on_success = RunnablePassthrough.assign(
        parsed=itemgetter("raw") | parser,
        parsing_error=lambda _: None,
    )
    # Fallback path: parsing failed, so "parsed" is None and the exception is
    # injected under "parsing_error" by ``with_fallbacks``.
    on_failure = RunnablePassthrough.assign(parsed=lambda _: None)
    guarded_parser = on_success.with_fallbacks(
        [on_failure], exception_key="parsing_error"
    )
    return RunnableMap(raw=tool_bound_llm) | guarded_parser

View File

@ -235,3 +235,34 @@ def test_manual_tool_call_msg() -> None:
assert output.content
# Should not have called the tool again.
assert not output.tool_calls and not output.invalid_tool_calls
# Schema handed to ``with_structured_output`` by the structured-output tests in
# this file; those tests assert that the parsed result is an instance of it.
class AnswerWithJustification(BaseModel):
    """An answer to the user question along with justification for the answer."""

    # The answer to the user question (see the class docstring).
    answer: str
    # The justification given for that answer.
    justification: str
def test_chat_tongyi_with_structured_output() -> None:
    """Structured output with a Pydantic schema yields an instance of it."""
    question = "What weighs more a pound of bricks or a pound of feathers"
    chat_model = ChatTongyi()  # type: ignore
    structured_model = chat_model.with_structured_output(AnswerWithJustification)

    answer = structured_model.invoke(question)

    assert isinstance(answer, AnswerWithJustification)
def test_chat_tongyi_with_structured_output_include_raw() -> None:
    """With ``include_raw=True`` the result is a dict of raw and parsed output."""
    question = "What weighs more a pound of bricks or a pound of feathers"
    chat_model = ChatTongyi()  # type: ignore
    structured_model = chat_model.with_structured_output(
        AnswerWithJustification, include_raw=True
    )

    result = structured_model.invoke(question)

    # The wrapper returns a dict; "raw" carries the model message and "parsed"
    # the schema instance.
    assert isinstance(result, dict)
    raw_message = result.get("raw")
    assert isinstance(raw_message, AIMessage)
    parsed_answer = result.get("parsed")
    assert isinstance(parsed_answer, AnswerWithJustification)