diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index 8551706c4d3..062df8ba967 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -136,10 +136,52 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): class PydanticOutputFunctionsParser(OutputFunctionsParser): - """Parse an output as a pydantic object.""" + """Parse an output as a pydantic object. + + This parser is used to parse the output of a ChatModel that uses + OpenAI function format to invoke functions. + + The parser extracts the function call invocation and matches + them to the pydantic schema provided. + + An exception will be raised if the function call does not match + the provided schema. + + Example: + + ... code-block:: python + + message = AIMessage( + content="This is a test message", + additional_kwargs={ + "function_call": { + "name": "cookie", + "arguments": json.dumps({"name": "value", "age": 10}), + } + }, + ) + chat_generation = ChatGeneration(message=message) + + class Cookie(BaseModel): + name: str + age: int + + class Dog(BaseModel): + species: str + + # Full output + parser = PydanticOutputFunctionsParser( + pydantic_schema={"cookie": Cookie, "dog": Dog} + ) + result = parser.parse_result([chat_generation]) + """ pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]] - """The pydantic schema to parse the output with.""" + """The pydantic schema to parse the output with. + + If multiple schemas are provided, then the function name will be used to + determine which schema to use. + """ @root_validator(pre=True) def validate_schema(cls, values: Dict) -> Dict: diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py b/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py index 6af2be93c31..ed1bebf6df5 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py @@ -1,3 +1,4 @@ +import json from typing import Any, Dict import pytest @@ -7,7 +8,9 @@ from langchain_core.outputs import ChatGeneration from langchain.output_parsers.openai_functions import ( JsonOutputFunctionsParser, + PydanticOutputFunctionsParser, ) +from langchain.pydantic_v1 import BaseModel def test_json_output_function_parser() -> None: @@ -134,3 +137,61 @@ def test_exceptions_raised_while_parsing(bad_message: BaseMessage) -> None: with pytest.raises(OutputParserException): JsonOutputFunctionsParser().parse_result([chat_generation]) + + +def test_pydantic_output_functions_parser() -> None: + """Test pydantic output functions parser.""" + message = AIMessage( + content="This is a test message", + additional_kwargs={ + "function_call": { + "name": "function_name", + "arguments": json.dumps({"name": "value", "age": 10}), + } + }, + ) + chat_generation = ChatGeneration(message=message) + + class Model(BaseModel): + """Test model.""" + + name: str + age: int + + # Full output + parser = PydanticOutputFunctionsParser(pydantic_schema=Model) + result = parser.parse_result([chat_generation]) + assert result == Model(name="value", age=10) + + +def test_pydantic_output_functions_parser_multiple_schemas() -> None: + """Test that the parser works if providing multiple pydantic schemas.""" + + message = AIMessage( + content="This is a test message", + additional_kwargs={ + "function_call": { + "name": "cookie", + "arguments": json.dumps({"name": "value", "age": 10}), + } + }, + ) + chat_generation = ChatGeneration(message=message) + + class Cookie(BaseModel): + """Test model.""" + + name: str + age: int + + class Dog(BaseModel): + """Test model.""" + + species: str + + # Full output + parser = PydanticOutputFunctionsParser( + pydantic_schema={"cookie": Cookie, "dog": Dog} + ) + result = parser.parse_result([chat_generation]) + assert result == Cookie(name="value", age=10)