Compare commits

...

2 Commits

Author SHA1 Message Date
Bagatur
d3043a627e tests 2024-02-26 16:26:47 -08:00
Bagatur
60bd43053d core[patch]: support pydantic v2 context 2024-02-26 16:05:19 -08:00
4 changed files with 98 additions and 25 deletions

View File

@@ -1,6 +1,6 @@
import copy
import json
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union, cast
import jsonpatch # type: ignore[import]
@@ -11,7 +11,7 @@ from langchain_core.output_parsers import (
)
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.pydantic_v1 import Field, root_validator
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
@@ -177,21 +177,19 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
result = parser.parse_result([chat_generation])
"""
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
pydantic_schema: Union[Type, Dict[str, Type]]
"""The pydantic schema to parse the output with.
If multiple schemas are provided, then the function name will be used to
determine which schema to use.
"""
context: dict = Field(default_factory=dict)
@root_validator(pre=True)
def validate_schema(cls, values: Dict) -> Dict:
schema = values["pydantic_schema"]
if "args_only" not in values:
values["args_only"] = isinstance(schema, type) and issubclass(
schema, BaseModel
)
elif values["args_only"] and isinstance(schema, Dict):
values["args_only"] = values.get("args_only", not isinstance(schema, Dict))
if values["args_only"] and isinstance(schema, Dict):
raise ValueError(
"If multiple pydantic schemas are provided then args_only should be"
" False."
@@ -201,12 +199,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
pydantic_cls = cast(Type, self.pydantic_schema)
args = _result
else:
fn_name = _result["name"]
_args = _result["arguments"]
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore # noqa: E501
return pydantic_args
pydantic_cls = cast(Dict, self.pydantic_schema)[fn_name]
args = _result["arguments"]
if hasattr(pydantic_cls, "model_validate_json"):
return pydantic_cls.model_validate_json(args, context=self.context)
else:
return pydantic_cls.parse_raw(args)
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):

View File

@@ -7,7 +7,7 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import Field
class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
@@ -109,15 +109,20 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response."""
tools: List[Type[BaseModel]]
tools: List[Type]
context: dict = Field(default_factory=dict)
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
parsed_result = super().parse_result(result, partial=partial)
name_dict = {tool.__name__: tool for tool in self.tools}
if self.first_tool_only:
return (
name_dict[parsed_result["type"]](**parsed_result["args"])
if parsed_result
else None
)
return [name_dict[res["type"]](**res["args"]) for res in parsed_result]
return self._parse_tool_call(parsed_result) if parsed_result else None
else:
return [self._parse_tool_call(tool_call) for tool_call in parsed_result]
def _parse_tool_call(self, tool_call: dict) -> Any:
name_dict = {tool.__name__: tool for tool in self.tools}
pydantic_cls = name_dict[tool_call["type"]]
if hasattr(pydantic_cls, "model_validate"):
return pydantic_cls.model_validate(tool_call["args"], context=self.context)
else:
return pydantic_cls.parse_obj(tool_call["args"])

View File

@@ -4,23 +4,29 @@ from typing import Any, List, Type
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.pydantic_v1 import Field, ValidationError
class PydanticOutputParser(JsonOutputParser):
"""Parse an output using a pydantic model."""
pydantic_object: Type[BaseModel]
pydantic_object: Type
"""The pydantic model to parse.
Attention: To avoid potential compatibility issues, it's recommended to use
pydantic <2 or leverage the v1 namespace in pydantic >= 2.
"""
context: dict = Field(default_factory=dict)
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
json_object = super().parse_result(result)
try:
return self.pydantic_object.parse_obj(json_object)
if hasattr(self.pydantic_object, "model_validate"):
return self.pydantic_object.model_validate(
json_object, context=self.context
)
else:
return self.pydantic_object.parse_obj(json_object)
except ValidationError as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
@@ -46,7 +52,7 @@ class PydanticOutputParser(JsonOutputParser):
return "pydantic"
@property
def OutputType(self) -> Type[BaseModel]:
def OutputType(self) -> Type:
"""Return the pydantic model."""
return self.pydantic_object

View File

@@ -0,0 +1,59 @@
import json
from typing import List
import pytest
from langchain_core.messages import AIMessage
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, BaseModel
@pytest.fixture
def generations() -> List[Generation]:
tool_call = {
"id": "",
"type": "function",
"function": {"name": "Foo", "arguments": json.dumps({"bloop": 2})},
}
return [
ChatGeneration(
message=AIMessage("", additional_kwargs={"tool_calls": [tool_call]})
)
]
@pytest.mark.skipif(
_PYDANTIC_MAJOR_VERSION != 1,
reason=f"test only runs on pydantic v1, current version {_PYDANTIC_MAJOR_VERSION}",
)
def test_tools_parser_context_pydantic_v1(generations: List[Generation]) -> None:
class Foo(BaseModel):
bloop: int
parser = PydanticToolsParser(tools=[Foo], context={"baz": "bar"})
assert parser.parse_result(generations) == [Foo(bloop=2)]
@pytest.mark.skipif(
_PYDANTIC_MAJOR_VERSION != 2,
reason=f"test only runs on pydantic v2, current version {_PYDANTIC_MAJOR_VERSION}",
)
def test_tools_parser_context_pydantic_v2(generations: List[Generation]) -> None:
from pydantic import BaseModel as BaseModelV2
from pydantic import ValidationInfo, model_validator
class Foo(BaseModelV2):
bloop: int
@model_validator(mode="before")
def validate_env(cls, values: dict, info: ValidationInfo) -> dict:
context = info.context or {}
if context.get("baz") == "bar":
raise ValueError()
return values
parser = PydanticToolsParser(tools=[Foo], context={"baz": "bar"})
with pytest.raises(ValueError):
parser.parse_result(generations)