diff --git a/libs/community/langchain_community/chat_models/tongyi.py b/libs/community/langchain_community/chat_models/tongyi.py index 7941aa07264..e453e12c1c5 100644 --- a/libs/community/langchain_community/chat_models/tongyi.py +++ b/libs/community/langchain_community/chat_models/tongyi.py @@ -4,6 +4,7 @@ import asyncio import functools import json import logging +from operator import itemgetter from typing import ( Any, AsyncIterator, @@ -40,7 +41,10 @@ from langchain_core.messages import ( ToolMessage, ToolMessageChunk, ) +from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, make_invalid_tool_call, parse_tool_call, ) @@ -50,7 +54,7 @@ from langchain_core.outputs import ( ChatResult, ) from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_core.utils.function_calling import convert_to_openai_tool @@ -372,6 +376,33 @@ class ChatTongyi(BaseChatModel): } ] + Structured output: + .. code-block:: python + + from typing import Optional + + from langchain_core.pydantic_v1 import BaseModel, Field + + + class Joke(BaseModel): + '''Joke to tell user.''' + + setup: str = Field(description="The setup of the joke") + punchline: str = Field(description="The punchline to the joke") + rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10") + + + structured_chat = tongyi_chat.with_structured_output(Joke) + structured_chat.invoke("Tell me a joke about cats") + + .. code-block:: python + + Joke( + setup='Why did the cat join the band?', + punchline='Because it wanted to be a solo purr-sonality!', + rating=None + ) + Response metadata .. 
code-block:: python @@ -791,3 +822,70 @@ class ChatTongyi(BaseChatModel): formatted_tools = [convert_to_openai_tool(tool) for tool in tools] return super().bind(tools=formatted_tools, **kwargs) + + def with_structured_output( + self, + schema: Union[Dict, Type[BaseModel]], + *, + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: The output schema as a dict or a Pydantic class. If a Pydantic class + then the model output will be an object of that class. If a dict then + the model output will be a dict. With a Pydantic class the returned + attributes will be validated, whereas with a dict they will not be. If + `schema` is a dict, then it must match the + OpenAI function-calling spec. + include_raw: If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + Returns: + A Runnable that takes any ChatModel input and returns as output: + + If include_raw is True then a dict with keys: + raw: BaseMessage + parsed: Optional[_DictOrPydantic] + parsing_error: Optional[BaseException] + + If include_raw is False then just _DictOrPydantic is returned, + where _DictOrPydantic depends on the schema: + + If schema is a Pydantic class then _DictOrPydantic is the Pydantic + class. + + If schema is a dict then _DictOrPydantic is a dict. 
+ + """ + if kwargs: + raise ValueError(f"Received unsupported arguments {kwargs}") + is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel) + llm = self.bind_tools([schema]) + if is_pydantic_schema: + output_parser: OutputParserLike = PydanticToolsParser( + tools=[schema], # type: ignore[list-item] + first_tool_only=True, # type: ignore[list-item] + ) + else: + key_name = convert_to_openai_tool(schema)["function"]["name"] + output_parser = JsonOutputKeyToolsParser( + key_name=key_name, first_tool_only=True + ) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | output_parser diff --git a/libs/community/tests/integration_tests/chat_models/test_tongyi.py b/libs/community/tests/integration_tests/chat_models/test_tongyi.py index 8e0f6c16b54..a395e800c9b 100644 --- a/libs/community/tests/integration_tests/chat_models/test_tongyi.py +++ b/libs/community/tests/integration_tests/chat_models/test_tongyi.py @@ -235,3 +235,34 @@ def test_manual_tool_call_msg() -> None: assert output.content # Should not have called the tool again. 
assert not output.tool_calls and not output.invalid_tool_calls + + +class AnswerWithJustification(BaseModel): + """An answer to the user question along with justification for the answer.""" + + answer: str + justification: str + + +def test_chat_tongyi_with_structured_output() -> None: + """Test ChatTongyi with structured output.""" + llm = ChatTongyi() # type: ignore + structured_llm = llm.with_structured_output(AnswerWithJustification) + response = structured_llm.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) + assert isinstance(response, AnswerWithJustification) + + +def test_chat_tongyi_with_structured_output_include_raw() -> None: + """Test ChatTongyi with structured output.""" + llm = ChatTongyi() # type: ignore + structured_llm = llm.with_structured_output( + AnswerWithJustification, include_raw=True + ) + response = structured_llm.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) + assert isinstance(response, dict) + assert isinstance(response.get("raw"), AIMessage) + assert isinstance(response.get("parsed"), AnswerWithJustification)