core[patch], langchain[patch], templates: move openai functions parsers to core (#18060)

![Screenshot 2024-02-23 at 7 48 03
PM](https://github.com/langchain-ai/langchain/assets/22008038/e5540c4d-0020-4ece-869f-ae19db2a1f3f)
This commit is contained in:
Bagatur
2024-02-26 11:12:53 -08:00
committed by GitHub
parent 96bff0ed5d
commit 767523f364
15 changed files with 264 additions and 252 deletions

View File

@@ -0,0 +1,220 @@
import copy
import json
from typing import Any, Dict, List, Optional, Type, Union
import jsonpatch # type: ignore[import]
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import (
BaseCumulativeTransformOutputParser,
BaseGenerationOutputParser,
)
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, root_validator
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
    """Extract the ``function_call`` payload from a chat generation.

    The payload is read from ``additional_kwargs["function_call"]`` of the
    first generation's message. Depending on ``args_only`` either the whole
    payload or only its ``"arguments"`` entry is returned.
    """

    args_only: bool = True
    """Whether to only return the arguments to the function call."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Return the function call (or only its arguments) of ``result[0]``.

        Args:
            result: Generations to parse; only the first element is inspected.
            partial: Unused by this parser.

        Returns:
            ``function_call["arguments"]`` when ``args_only`` is true,
            otherwise a deep copy of the whole ``function_call`` payload.

        Raises:
            OutputParserException: If the first generation is not a
                ``ChatGeneration`` or its message carries no
                ``function_call`` entry.
        """
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            raise OutputParserException(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        try:
            # Deep copy so that callers can mutate the returned payload
            # without touching the original message.
            func_call = copy.deepcopy(message.additional_kwargs["function_call"])
        except KeyError as exc:
            # Chain explicitly so the traceback shows the missing key as the
            # direct cause of the parser error.
            raise OutputParserException(
                f"Could not parse function call: {exc}"
            ) from exc
        if self.args_only:
            return func_call["arguments"]
        return func_call
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
    """Parse the arguments of a function call as JSON.

    With ``partial=True`` the arguments string may be incomplete; missing or
    undecodable data then yields ``None`` instead of an exception.
    """

    strict: bool = False
    """Value of the ``strict`` flag forwarded to the JSON decoder.

    When False (the default here) control characters such as raw new lines
    are accepted inside JSON strings; when True they are rejected.

    See: https://docs.python.org/3/library/json.html#encoders-and-decoders

    Non-strict parsing is useful when the parsed output may include new lines
    or other control characters.
    """

    args_only: bool = True
    """Whether to only return the arguments to the function call."""

    @property
    def _type(self) -> str:
        return "json_functions"

    def _diff(self, prev: Optional[Any], next: Any) -> Any:
        # Return the patch operations jsonpatch computes between both values.
        return jsonpatch.make_patch(prev, next).patch

    def _load_arguments(self, function_call: Any, loader: Any) -> Any:
        """Decode ``function_call["arguments"]`` with ``loader``.

        Returns the decoded arguments alone when ``args_only`` is true,
        otherwise a copy of ``function_call`` whose ``"arguments"`` entry is
        replaced by the decoded value. Exceptions raised by the lookup or by
        ``loader`` propagate to the caller.
        """
        if self.args_only:
            return loader(function_call["arguments"], strict=self.strict)
        return {
            **function_call,
            "arguments": loader(function_call["arguments"], strict=self.strict),
        }

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse the function call of a single chat generation.

        Args:
            result: A list holding exactly one ``ChatGeneration``.
            partial: Whether the arguments may be an incomplete JSON string.

        Returns:
            The decoded arguments, or the function call with decoded
            arguments when ``args_only`` is false. ``None`` is returned when
            a ``KeyError`` occurs while decoding (e.g. no ``"arguments"``
            entry) and, in partial mode, when there is no function call yet
            or the arguments cannot be decoded.

        Raises:
            OutputParserException: If ``result`` does not hold exactly one
                chat generation, or (non-partial mode only) the function call
                is missing or its arguments are not a valid JSON string.
        """
        if len(result) != 1:
            raise OutputParserException(
                f"Expected exactly one result, but got {len(result)}"
            )
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            raise OutputParserException(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        try:
            function_call = message.additional_kwargs["function_call"]
        except KeyError as exc:
            if partial:
                return None
            raise OutputParserException(
                f"Could not parse function call: {exc}"
            ) from exc
        try:
            if partial:
                try:
                    return self._load_arguments(function_call, parse_partial_json)
                except json.JSONDecodeError:
                    return None
            try:
                return self._load_arguments(function_call, json.loads)
            except (json.JSONDecodeError, TypeError) as exc:
                raise OutputParserException(
                    f"Could not parse function call data: {exc}"
                ) from exc
        except KeyError:
            # Raised e.g. when the function call has no "arguments" entry.
            return None

    # This method would be called by the default implementation of `parse_result`
    # but we're overriding that method so it's not needed.
    def parse(self, text: str) -> Any:
        raise NotImplementedError()
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
    """Parse an output as the element of the Json object."""

    key_name: str
    """The name of the key to return."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Return the value stored under ``key_name`` in the parsed output.

        Args:
            result: Generations forwarded to the parent parser.
            partial: Whether the output may still be incomplete.

        Returns:
            The value for ``key_name``. In partial mode ``None`` is returned
            when nothing could be parsed yet or the key is not present.

        Raises:
            OutputParserException: In non-partial mode, when the parent
                parser produced no object to look the key up in.
            KeyError: In non-partial mode, when ``key_name`` is absent from
                the parsed object.
        """
        res = super().parse_result(result, partial=partial)
        if res is None:
            if partial:
                return None
            # The parent returns None when decoding hits a KeyError (e.g. no
            # "arguments" entry); report that as a parser error instead of
            # failing with a TypeError on the subscript below.
            raise OutputParserException(
                f"Could not extract key {self.key_name!r}: "
                "the function call produced no parsed output."
            )
        return res.get(self.key_name) if partial else res[self.key_name]
class PydanticOutputFunctionsParser(OutputFunctionsParser):
    """Parse an output as a pydantic object.

    This parser is used to parse the output of a ChatModel that uses
    OpenAI function format to invoke functions.

    The parser extracts the function call invocation and matches
    them to the pydantic schema provided.

    An exception will be raised if the function call does not match
    the provided schema.

    Example:
        .. code-block:: python

            message = AIMessage(
                content="This is a test message",
                additional_kwargs={
                    "function_call": {
                        "name": "cookie",
                        "arguments": json.dumps({"name": "value", "age": 10}),
                    }
                },
            )
            chat_generation = ChatGeneration(message=message)

            class Cookie(BaseModel):
                name: str
                age: int

            class Dog(BaseModel):
                species: str

            # Full output
            parser = PydanticOutputFunctionsParser(
                pydantic_schema={"cookie": Cookie, "dog": Dog}
            )
            result = parser.parse_result([chat_generation])
    """

    pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
    """The pydantic schema to parse the output with.

    If multiple schemas are provided, then the function name will be used to
    determine which schema to use.
    """

    @root_validator(pre=True)
    def validate_schema(cls, values: Dict) -> Dict:
        """Derive or validate ``args_only`` from the kind of schema given.

        When ``args_only`` is not supplied it defaults to True for a single
        ``BaseModel`` subclass and to False otherwise. Supplying
        ``args_only=True`` together with a dict of schemas is rejected,
        because the function name is needed to select the schema.
        """
        # Use .get so that a missing ``pydantic_schema`` is reported by the
        # regular required-field validation rather than as a bare KeyError.
        schema = values.get("pydantic_schema")
        if "args_only" not in values:
            values["args_only"] = isinstance(schema, type) and issubclass(
                schema, BaseModel
            )
        elif values["args_only"] and isinstance(schema, dict):
            raise ValueError(
                "If multiple pydantic schemas are provided then args_only should be"
                " False."
            )
        return values

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse the function call arguments into a pydantic object.

        Args:
            result: Generations forwarded to the parent parser.
            partial: Unused by this parser.

        Returns:
            An instance of the matching pydantic schema built from the raw
            arguments string via ``parse_raw``.
        """
        _result = super().parse_result(result)
        if self.args_only:
            return self.pydantic_schema.parse_raw(_result)  # type: ignore
        fn_name = _result["name"]
        _args = _result["arguments"]
        if isinstance(self.pydantic_schema, dict):
            # Several schemas: the function name selects which one applies.
            schema = self.pydantic_schema[fn_name]
        else:
            # A single schema class used with args_only=False: there is
            # nothing to look up, so parse with that class directly.
            schema = self.pydantic_schema
        return schema.parse_raw(_args)  # type: ignore
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
    """Return one attribute of the pydantic object parsed from the output."""

    attr_name: str
    """The name of the attribute to return."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse with the parent parser and return its ``attr_name`` attribute.

        ``partial`` is not forwarded to the parent parser.
        """
        parsed_object = super().parse_result(result)
        return getattr(parsed_object, self.attr_name)

View File

@@ -0,0 +1,197 @@
import json
from typing import Any, Dict
import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel
def test_json_output_function_parser() -> None:
    """Check that the JSON function parser parses in non-strict mode by default."""
    # The arguments string holds a raw newline, which only a non-strict JSON
    # decoder accepts. Strings are immutable, so reusing this value for the
    # final comparison cannot hide a mutation of the message.
    raw_arguments = '{"arg1": "code\ncode"}'
    message = AIMessage(
        content="This is a test message",
        additional_kwargs={
            "function_call": {
                "name": "function_name",
                "arguments": raw_arguments,
            }
        },
    )
    chat_generation = ChatGeneration(message=message)

    # Full output first, then arguments only.
    expectations = [
        (False, {"arguments": {"arg1": "code\ncode"}, "name": "function_name"}),
        (True, {"arg1": "code\ncode"}),
    ]
    for args_only, expected in expectations:
        parser = JsonOutputFunctionsParser(args_only=args_only)
        result = parser.parse_result([chat_generation])
        assert result == expected

    # Parsing must leave the original message untouched.
    assert message.additional_kwargs == {
        "function_call": {
            "name": "function_name",
            "arguments": raw_arguments,
        }
    }
@pytest.mark.parametrize(
    "config",
    [
        {
            "args_only": False,
            "strict": False,
            "args": '{"arg1": "value1"}',
            "result": {"arguments": {"arg1": "value1"}, "name": "function_name"},
            "exception": None,
        },
        {
            "args_only": True,
            "strict": False,
            "args": '{"arg1": "value1"}',
            "result": {"arg1": "value1"},
            "exception": None,
        },
        {
            "args_only": True,
            "strict": False,
            "args": '{"code": "print(2+\n2)"}',
            "result": {"code": "print(2+\n2)"},
            "exception": None,
        },
        {
            "args_only": True,
            "strict": False,
            "args": '{"code": "你好)"}',
            "result": {"code": "你好)"},
            "exception": None,
        },
        {
            "args_only": True,
            "strict": True,
            "args": '{"code": "print(2+\n2)"}',
            "exception": OutputParserException,
        },
    ],
)
def test_json_output_function_parser_strictness(config: Dict[str, Any]) -> None:
    """Exercise the parser with the JSON ``strict`` flag both on and off."""
    function_call = {"name": "function_name", "arguments": config["args"]}
    message = AIMessage(
        content="This is a test message",
        additional_kwargs={"function_call": function_call},
    )
    generations = [ChatGeneration(message=message)]
    parser = JsonOutputFunctionsParser(
        strict=config["strict"], args_only=config["args_only"]
    )

    expected_exception = config["exception"]
    if expected_exception is None:
        # Success case: the parsed value must match the configured result.
        assert parser.parse_result(generations) == config["result"]
        return

    # Failure case: parsing must raise the configured exception type.
    with pytest.raises(expected_exception):
        parser.parse_result(generations)
@pytest.mark.parametrize(
    "bad_message",
    [
        # Human message has no function call
        HumanMessage(content="This is a test message"),
        # AIMessage has no function call information.
        AIMessage(content="This is a test message", additional_kwargs={}),
        # Bad function call information (arguments should be a string)
        AIMessage(
            content="This is a test message",
            additional_kwargs={
                "function_call": {"name": "function_name", "arguments": {}}
            },
        ),
        # Bad function call information (arguments should be proper json)
        AIMessage(
            content="This is a test message",
            additional_kwargs={
                "function_call": {"name": "function_name", "arguments": "noqweqwe"}
            },
        ),
    ],
)
def test_exceptions_raised_while_parsing(bad_message: BaseMessage) -> None:
    """Check that malformed messages make the JSON parser raise."""
    generations = [ChatGeneration(message=bad_message)]
    with pytest.raises(OutputParserException):
        JsonOutputFunctionsParser().parse_result(generations)
def test_pydantic_output_functions_parser() -> None:
    """Parse function call arguments into a single pydantic schema."""

    class Model(BaseModel):
        """Test model."""

        name: str
        age: int

    serialized_args = json.dumps({"name": "value", "age": 10})
    message = AIMessage(
        content="This is a test message",
        additional_kwargs={
            "function_call": {
                "name": "function_name",
                "arguments": serialized_args,
            }
        },
    )
    generations = [ChatGeneration(message=message)]

    # Full output
    parser = PydanticOutputFunctionsParser(pydantic_schema=Model)
    assert parser.parse_result(generations) == Model(name="value", age=10)
def test_pydantic_output_functions_parser_multiple_schemas() -> None:
    """Check that the function name selects among several pydantic schemas."""

    class Cookie(BaseModel):
        """Test model."""

        name: str
        age: int

    class Dog(BaseModel):
        """Test model."""

        species: str

    serialized_args = json.dumps({"name": "value", "age": 10})
    message = AIMessage(
        content="This is a test message",
        additional_kwargs={
            "function_call": {
                "name": "cookie",
                "arguments": serialized_args,
            }
        },
    )
    generations = [ChatGeneration(message=message)]

    # Full output
    schemas = {"cookie": Cookie, "dog": Dog}
    parser = PydanticOutputFunctionsParser(pydantic_schema=schemas)
    assert parser.parse_result(generations) == Cookie(name="value", age=10)