mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-18 09:01:03 +00:00
core[patch], langchain[patch], templates: move openai functions parsers to core (#18060)

This commit is contained in:
parent
96bff0ed5d
commit
767523f364
220
libs/core/langchain_core/output_parsers/openai_functions.py
Normal file
220
libs/core/langchain_core/output_parsers/openai_functions.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
import copy
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
|
import jsonpatch # type: ignore[import]
|
||||||
|
|
||||||
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.output_parsers import (
|
||||||
|
BaseCumulativeTransformOutputParser,
|
||||||
|
BaseGenerationOutputParser,
|
||||||
|
)
|
||||||
|
from langchain_core.output_parsers.json import parse_partial_json
|
||||||
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||||
|
|
||||||
|
|
||||||
|
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||||
|
"""Parse an output that is one of sets of values."""
|
||||||
|
|
||||||
|
args_only: bool = True
|
||||||
|
"""Whether to only return the arguments to the function call."""
|
||||||
|
|
||||||
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||||
|
generation = result[0]
|
||||||
|
if not isinstance(generation, ChatGeneration):
|
||||||
|
raise OutputParserException(
|
||||||
|
"This output parser can only be used with a chat generation."
|
||||||
|
)
|
||||||
|
message = generation.message
|
||||||
|
try:
|
||||||
|
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
|
||||||
|
except KeyError as exc:
|
||||||
|
raise OutputParserException(f"Could not parse function call: {exc}")
|
||||||
|
|
||||||
|
if self.args_only:
|
||||||
|
return func_call["arguments"]
|
||||||
|
return func_call
|
||||||
|
|
||||||
|
|
||||||
|
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||||
|
"""Parse an output as the Json object."""
|
||||||
|
|
||||||
|
strict: bool = False
|
||||||
|
"""Whether to allow non-JSON-compliant strings.
|
||||||
|
|
||||||
|
See: https://docs.python.org/3/library/json.html#encoders-and-decoders
|
||||||
|
|
||||||
|
Useful when the parsed output may include unicode characters or new lines.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args_only: bool = True
|
||||||
|
"""Whether to only return the arguments to the function call."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "json_functions"
|
||||||
|
|
||||||
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||||
|
return jsonpatch.make_patch(prev, next).patch
|
||||||
|
|
||||||
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||||
|
if len(result) != 1:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Expected exactly one result, but got {len(result)}"
|
||||||
|
)
|
||||||
|
generation = result[0]
|
||||||
|
if not isinstance(generation, ChatGeneration):
|
||||||
|
raise OutputParserException(
|
||||||
|
"This output parser can only be used with a chat generation."
|
||||||
|
)
|
||||||
|
message = generation.message
|
||||||
|
try:
|
||||||
|
function_call = message.additional_kwargs["function_call"]
|
||||||
|
except KeyError as exc:
|
||||||
|
if partial:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise OutputParserException(f"Could not parse function call: {exc}")
|
||||||
|
try:
|
||||||
|
if partial:
|
||||||
|
try:
|
||||||
|
if self.args_only:
|
||||||
|
return parse_partial_json(
|
||||||
|
function_call["arguments"], strict=self.strict
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
**function_call,
|
||||||
|
"arguments": parse_partial_json(
|
||||||
|
function_call["arguments"], strict=self.strict
|
||||||
|
),
|
||||||
|
}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
if self.args_only:
|
||||||
|
try:
|
||||||
|
return json.loads(
|
||||||
|
function_call["arguments"], strict=self.strict
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, TypeError) as exc:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Could not parse function call data: {exc}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
**function_call,
|
||||||
|
"arguments": json.loads(
|
||||||
|
function_call["arguments"], strict=self.strict
|
||||||
|
),
|
||||||
|
}
|
||||||
|
except (json.JSONDecodeError, TypeError) as exc:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Could not parse function call data: {exc}"
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# This method would be called by the default implementation of `parse_result`
|
||||||
|
# but we're overriding that method so it's not needed.
|
||||||
|
def parse(self, text: str) -> Any:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||||
|
"""Parse an output as the element of the Json object."""
|
||||||
|
|
||||||
|
key_name: str
|
||||||
|
"""The name of the key to return."""
|
||||||
|
|
||||||
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||||
|
res = super().parse_result(result, partial=partial)
|
||||||
|
if partial and res is None:
|
||||||
|
return None
|
||||||
|
return res.get(self.key_name) if partial else res[self.key_name]
|
||||||
|
|
||||||
|
|
||||||
|
class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||||
|
"""Parse an output as a pydantic object.
|
||||||
|
|
||||||
|
This parser is used to parse the output of a ChatModel that uses
|
||||||
|
OpenAI function format to invoke functions.
|
||||||
|
|
||||||
|
The parser extracts the function call invocation and matches
|
||||||
|
them to the pydantic schema provided.
|
||||||
|
|
||||||
|
An exception will be raised if the function call does not match
|
||||||
|
the provided schema.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
... code-block:: python
|
||||||
|
|
||||||
|
message = AIMessage(
|
||||||
|
content="This is a test message",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {
|
||||||
|
"name": "cookie",
|
||||||
|
"arguments": json.dumps({"name": "value", "age": 10}),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chat_generation = ChatGeneration(message=message)
|
||||||
|
|
||||||
|
class Cookie(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
|
||||||
|
class Dog(BaseModel):
|
||||||
|
species: str
|
||||||
|
|
||||||
|
# Full output
|
||||||
|
parser = PydanticOutputFunctionsParser(
|
||||||
|
pydantic_schema={"cookie": Cookie, "dog": Dog}
|
||||||
|
)
|
||||||
|
result = parser.parse_result([chat_generation])
|
||||||
|
"""
|
||||||
|
|
||||||
|
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
|
||||||
|
"""The pydantic schema to parse the output with.
|
||||||
|
|
||||||
|
If multiple schemas are provided, then the function name will be used to
|
||||||
|
determine which schema to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def validate_schema(cls, values: Dict) -> Dict:
|
||||||
|
schema = values["pydantic_schema"]
|
||||||
|
if "args_only" not in values:
|
||||||
|
values["args_only"] = isinstance(schema, type) and issubclass(
|
||||||
|
schema, BaseModel
|
||||||
|
)
|
||||||
|
elif values["args_only"] and isinstance(schema, Dict):
|
||||||
|
raise ValueError(
|
||||||
|
"If multiple pydantic schemas are provided then args_only should be"
|
||||||
|
" False."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||||
|
_result = super().parse_result(result)
|
||||||
|
if self.args_only:
|
||||||
|
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
||||||
|
else:
|
||||||
|
fn_name = _result["name"]
|
||||||
|
_args = _result["arguments"]
|
||||||
|
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore # noqa: E501
|
||||||
|
return pydantic_args
|
||||||
|
|
||||||
|
|
||||||
|
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
||||||
|
"""Parse an output as an attribute of a pydantic object."""
|
||||||
|
|
||||||
|
attr_name: str
|
||||||
|
"""The name of the attribute to return."""
|
||||||
|
|
||||||
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||||
|
result = super().parse_result(result)
|
||||||
|
return getattr(result, self.attr_name)
|
@ -2,15 +2,15 @@ import json
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from langchain_core.outputs import ChatGeneration
|
from langchain_core.output_parsers.openai_functions import (
|
||||||
|
|
||||||
from langchain.output_parsers.openai_functions import (
|
|
||||||
JsonOutputFunctionsParser,
|
JsonOutputFunctionsParser,
|
||||||
PydanticOutputFunctionsParser,
|
PydanticOutputFunctionsParser,
|
||||||
)
|
)
|
||||||
from langchain.pydantic_v1 import BaseModel
|
from langchain_core.outputs import ChatGeneration
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
|
|
||||||
def test_json_output_function_parser() -> None:
|
def test_json_output_function_parser() -> None:
|
@ -14,6 +14,9 @@ from langchain_core.language_models import BaseLanguageModel
|
|||||||
from langchain_core.output_parsers import (
|
from langchain_core.output_parsers import (
|
||||||
BaseLLMOutputParser,
|
BaseLLMOutputParser,
|
||||||
)
|
)
|
||||||
|
from langchain_core.output_parsers.openai_functions import (
|
||||||
|
PydanticAttrOutputFunctionsParser,
|
||||||
|
)
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.utils.function_calling import (
|
from langchain_core.utils.function_calling import (
|
||||||
@ -27,9 +30,6 @@ from langchain.chains.structured_output.base import (
|
|||||||
create_structured_output_runnable,
|
create_structured_output_runnable,
|
||||||
get_openai_output_parser,
|
get_openai_output_parser,
|
||||||
)
|
)
|
||||||
from langchain.output_parsers.openai_functions import (
|
|
||||||
PydanticAttrOutputFunctionsParser,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_openai_output_parser",
|
"get_openai_output_parser",
|
||||||
|
@ -2,14 +2,12 @@ from typing import Iterator, List
|
|||||||
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.output_parsers.openai_functions import PydanticOutputFunctionsParser
|
||||||
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.openai_functions.utils import get_llm_kwargs
|
from langchain.chains.openai_functions.utils import get_llm_kwargs
|
||||||
from langchain.output_parsers.openai_functions import (
|
|
||||||
PydanticOutputFunctionsParser,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FactWithEvidence(BaseModel):
|
class FactWithEvidence(BaseModel):
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.output_parsers.openai_functions import (
|
||||||
|
JsonKeyOutputFunctionsParser,
|
||||||
|
PydanticAttrOutputFunctionsParser,
|
||||||
|
)
|
||||||
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
|
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
@ -11,10 +15,6 @@ from langchain.chains.openai_functions.utils import (
|
|||||||
_resolve_schema_references,
|
_resolve_schema_references,
|
||||||
get_llm_kwargs,
|
get_llm_kwargs,
|
||||||
)
|
)
|
||||||
from langchain.output_parsers.openai_functions import (
|
|
||||||
JsonKeyOutputFunctionsParser,
|
|
||||||
PydanticAttrOutputFunctionsParser,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_extraction_function(entity_schema: dict) -> dict:
|
def _get_extraction_function(entity_schema: dict) -> dict:
|
||||||
|
@ -10,6 +10,7 @@ from langchain_community.chat_models import ChatOpenAI
|
|||||||
from langchain_community.utilities.openapi import OpenAPISpec
|
from langchain_community.utilities.openapi import OpenAPISpec
|
||||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||||
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
|
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
|
||||||
from langchain_core.utils.input import get_colored_text
|
from langchain_core.utils.input import get_colored_text
|
||||||
from requests import Response
|
from requests import Response
|
||||||
@ -17,7 +18,6 @@ from requests import Response
|
|||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.sequential import SequentialChain
|
from langchain.chains.sequential import SequentialChain
|
||||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
|
||||||
from langchain.tools import APIOperation
|
from langchain.tools import APIOperation
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -3,16 +3,16 @@ from typing import Any, List, Optional, Type, Union
|
|||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from langchain_core.output_parsers import BaseLLMOutputParser
|
from langchain_core.output_parsers import BaseLLMOutputParser
|
||||||
|
from langchain_core.output_parsers.openai_functions import (
|
||||||
|
OutputFunctionsParser,
|
||||||
|
PydanticOutputFunctionsParser,
|
||||||
|
)
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.openai_functions.utils import get_llm_kwargs
|
from langchain.chains.openai_functions.utils import get_llm_kwargs
|
||||||
from langchain.output_parsers.openai_functions import (
|
|
||||||
OutputFunctionsParser,
|
|
||||||
PydanticOutputFunctionsParser,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AnswerWithSources(BaseModel):
|
class AnswerWithSources(BaseModel):
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.output_parsers.openai_functions import (
|
||||||
|
JsonOutputFunctionsParser,
|
||||||
|
PydanticOutputFunctionsParser,
|
||||||
|
)
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
|
from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
|
||||||
from langchain.output_parsers.openai_functions import (
|
|
||||||
JsonOutputFunctionsParser,
|
|
||||||
PydanticOutputFunctionsParser,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tagging_function(schema: dict) -> dict:
|
def _get_tagging_function(schema: dict) -> dict:
|
||||||
|
@ -6,6 +6,11 @@ from langchain_core.output_parsers import (
|
|||||||
BaseOutputParser,
|
BaseOutputParser,
|
||||||
JsonOutputParser,
|
JsonOutputParser,
|
||||||
)
|
)
|
||||||
|
from langchain_core.output_parsers.openai_functions import (
|
||||||
|
JsonOutputFunctionsParser,
|
||||||
|
PydanticAttrOutputFunctionsParser,
|
||||||
|
PydanticOutputFunctionsParser,
|
||||||
|
)
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
@ -19,11 +24,6 @@ from langchain.output_parsers import (
|
|||||||
PydanticOutputParser,
|
PydanticOutputParser,
|
||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
)
|
)
|
||||||
from langchain.output_parsers.openai_functions import (
|
|
||||||
JsonOutputFunctionsParser,
|
|
||||||
PydanticAttrOutputFunctionsParser,
|
|
||||||
PydanticOutputFunctionsParser,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_openai_fn_runnable(
|
def create_openai_fn_runnable(
|
||||||
|
@ -1,219 +1,13 @@
|
|||||||
import copy
|
from langchain_core.output_parsers.openai_functions import (
|
||||||
import json
|
JsonKeyOutputFunctionsParser,
|
||||||
from typing import Any, Dict, List, Optional, Type, Union
|
JsonOutputFunctionsParser,
|
||||||
|
PydanticAttrOutputFunctionsParser,
|
||||||
import jsonpatch
|
PydanticOutputFunctionsParser,
|
||||||
from langchain_core.exceptions import OutputParserException
|
|
||||||
from langchain_core.output_parsers import (
|
|
||||||
BaseCumulativeTransformOutputParser,
|
|
||||||
BaseGenerationOutputParser,
|
|
||||||
)
|
)
|
||||||
from langchain_core.output_parsers.json import parse_partial_json
|
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
"PydanticOutputFunctionsParser",
|
||||||
"""Parse an output that is one of sets of values."""
|
"PydanticAttrOutputFunctionsParser",
|
||||||
|
"JsonOutputFunctionsParser",
|
||||||
args_only: bool = True
|
"JsonKeyOutputFunctionsParser",
|
||||||
"""Whether to only return the arguments to the function call."""
|
]
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
||||||
generation = result[0]
|
|
||||||
if not isinstance(generation, ChatGeneration):
|
|
||||||
raise OutputParserException(
|
|
||||||
"This output parser can only be used with a chat generation."
|
|
||||||
)
|
|
||||||
message = generation.message
|
|
||||||
try:
|
|
||||||
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
|
|
||||||
except KeyError as exc:
|
|
||||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
|
||||||
|
|
||||||
if self.args_only:
|
|
||||||
return func_call["arguments"]
|
|
||||||
return func_call
|
|
||||||
|
|
||||||
|
|
||||||
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|
||||||
"""Parse an output as the Json object."""
|
|
||||||
|
|
||||||
strict: bool = False
|
|
||||||
"""Whether to allow non-JSON-compliant strings.
|
|
||||||
|
|
||||||
See: https://docs.python.org/3/library/json.html#encoders-and-decoders
|
|
||||||
|
|
||||||
Useful when the parsed output may include unicode characters or new lines.
|
|
||||||
"""
|
|
||||||
|
|
||||||
args_only: bool = True
|
|
||||||
"""Whether to only return the arguments to the function call."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _type(self) -> str:
|
|
||||||
return "json_functions"
|
|
||||||
|
|
||||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
|
||||||
return jsonpatch.make_patch(prev, next).patch
|
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
||||||
if len(result) != 1:
|
|
||||||
raise OutputParserException(
|
|
||||||
f"Expected exactly one result, but got {len(result)}"
|
|
||||||
)
|
|
||||||
generation = result[0]
|
|
||||||
if not isinstance(generation, ChatGeneration):
|
|
||||||
raise OutputParserException(
|
|
||||||
"This output parser can only be used with a chat generation."
|
|
||||||
)
|
|
||||||
message = generation.message
|
|
||||||
try:
|
|
||||||
function_call = message.additional_kwargs["function_call"]
|
|
||||||
except KeyError as exc:
|
|
||||||
if partial:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
|
||||||
try:
|
|
||||||
if partial:
|
|
||||||
try:
|
|
||||||
if self.args_only:
|
|
||||||
return parse_partial_json(
|
|
||||||
function_call["arguments"], strict=self.strict
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
**function_call,
|
|
||||||
"arguments": parse_partial_json(
|
|
||||||
function_call["arguments"], strict=self.strict
|
|
||||||
),
|
|
||||||
}
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
if self.args_only:
|
|
||||||
try:
|
|
||||||
return json.loads(
|
|
||||||
function_call["arguments"], strict=self.strict
|
|
||||||
)
|
|
||||||
except (json.JSONDecodeError, TypeError) as exc:
|
|
||||||
raise OutputParserException(
|
|
||||||
f"Could not parse function call data: {exc}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
**function_call,
|
|
||||||
"arguments": json.loads(
|
|
||||||
function_call["arguments"], strict=self.strict
|
|
||||||
),
|
|
||||||
}
|
|
||||||
except (json.JSONDecodeError, TypeError) as exc:
|
|
||||||
raise OutputParserException(
|
|
||||||
f"Could not parse function call data: {exc}"
|
|
||||||
)
|
|
||||||
except KeyError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# This method would be called by the default implementation of `parse_result`
|
|
||||||
# but we're overriding that method so it's not needed.
|
|
||||||
def parse(self, text: str) -> Any:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
|
||||||
"""Parse an output as the element of the Json object."""
|
|
||||||
|
|
||||||
key_name: str
|
|
||||||
"""The name of the key to return."""
|
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
||||||
res = super().parse_result(result, partial=partial)
|
|
||||||
if partial and res is None:
|
|
||||||
return None
|
|
||||||
return res.get(self.key_name) if partial else res[self.key_name]
|
|
||||||
|
|
||||||
|
|
||||||
class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|
||||||
"""Parse an output as a pydantic object.
|
|
||||||
|
|
||||||
This parser is used to parse the output of a ChatModel that uses
|
|
||||||
OpenAI function format to invoke functions.
|
|
||||||
|
|
||||||
The parser extracts the function call invocation and matches
|
|
||||||
them to the pydantic schema provided.
|
|
||||||
|
|
||||||
An exception will be raised if the function call does not match
|
|
||||||
the provided schema.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
... code-block:: python
|
|
||||||
|
|
||||||
message = AIMessage(
|
|
||||||
content="This is a test message",
|
|
||||||
additional_kwargs={
|
|
||||||
"function_call": {
|
|
||||||
"name": "cookie",
|
|
||||||
"arguments": json.dumps({"name": "value", "age": 10}),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
chat_generation = ChatGeneration(message=message)
|
|
||||||
|
|
||||||
class Cookie(BaseModel):
|
|
||||||
name: str
|
|
||||||
age: int
|
|
||||||
|
|
||||||
class Dog(BaseModel):
|
|
||||||
species: str
|
|
||||||
|
|
||||||
# Full output
|
|
||||||
parser = PydanticOutputFunctionsParser(
|
|
||||||
pydantic_schema={"cookie": Cookie, "dog": Dog}
|
|
||||||
)
|
|
||||||
result = parser.parse_result([chat_generation])
|
|
||||||
"""
|
|
||||||
|
|
||||||
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
|
|
||||||
"""The pydantic schema to parse the output with.
|
|
||||||
|
|
||||||
If multiple schemas are provided, then the function name will be used to
|
|
||||||
determine which schema to use.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def validate_schema(cls, values: Dict) -> Dict:
|
|
||||||
schema = values["pydantic_schema"]
|
|
||||||
if "args_only" not in values:
|
|
||||||
values["args_only"] = isinstance(schema, type) and issubclass(
|
|
||||||
schema, BaseModel
|
|
||||||
)
|
|
||||||
elif values["args_only"] and isinstance(schema, Dict):
|
|
||||||
raise ValueError(
|
|
||||||
"If multiple pydantic schemas are provided then args_only should be"
|
|
||||||
" False."
|
|
||||||
)
|
|
||||||
return values
|
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
||||||
_result = super().parse_result(result)
|
|
||||||
if self.args_only:
|
|
||||||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
|
||||||
else:
|
|
||||||
fn_name = _result["name"]
|
|
||||||
_args = _result["arguments"]
|
|
||||||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore # noqa: E501
|
|
||||||
return pydantic_args
|
|
||||||
|
|
||||||
|
|
||||||
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
|
||||||
"""Parse an output as an attribute of a pydantic object."""
|
|
||||||
|
|
||||||
attr_name: str
|
|
||||||
"""The name of the attribute to return."""
|
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
||||||
result = super().parse_result(result)
|
|
||||||
return getattr(result, self.attr_name)
|
|
||||||
|
@ -2,12 +2,11 @@ from operator import itemgetter
|
|||||||
from typing import Any, Callable, List, Mapping, Optional, Union
|
from typing import Any, Callable, List, Mapping, Optional, Union
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||||
from langchain_core.runnables import RouterRunnable, Runnable
|
from langchain_core.runnables import RouterRunnable, Runnable
|
||||||
from langchain_core.runnables.base import RunnableBindingBase
|
from langchain_core.runnables.base import RunnableBindingBase
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIFunction(TypedDict):
|
class OpenAIFunction(TypedDict):
|
||||||
"""A function description for ChatOpenAI"""
|
"""A function description for ChatOpenAI"""
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
from typing import Any, AsyncIterator, Iterator
|
from typing import Any, AsyncIterator, Iterator
|
||||||
|
|
||||||
from langchain_core.messages import AIMessageChunk
|
from langchain_core.messages import AIMessageChunk
|
||||||
|
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
|
||||||
|
|
||||||
GOOD_JSON = """```json
|
GOOD_JSON = """```json
|
||||||
{
|
{
|
||||||
|
@ -5,7 +5,6 @@ from typing import List, Optional
|
|||||||
from langchain import hub
|
from langchain import hub
|
||||||
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
|
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -14,6 +13,7 @@ from langchain.schema import (
|
|||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
)
|
)
|
||||||
from langchain_community.chat_models import ChatOpenAI
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser
|
|
||||||
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
|
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
|
||||||
|
from langchain_core.output_parsers.openai_functions import JsonKeyOutputFunctionsParser
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_experimental.llms.anthropic_functions import AnthropicFunctions
|
from langchain_experimental.llms.anthropic_functions import AnthropicFunctions
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from langchain.output_parsers.openai_functions import PydanticAttrOutputFunctionsParser
|
|
||||||
from langchain.retrievers import (
|
from langchain.retrievers import (
|
||||||
ArxivRetriever,
|
ArxivRetriever,
|
||||||
KayAiRetriever,
|
KayAiRetriever,
|
||||||
@ -11,6 +10,9 @@ from langchain.retrievers import (
|
|||||||
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
|
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
|
||||||
from langchain_community.chat_models import ChatOpenAI
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
from langchain_core.output_parsers import StrOutputParser
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
|
from langchain_core.output_parsers.openai_functions import (
|
||||||
|
PydanticAttrOutputFunctionsParser,
|
||||||
|
)
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
|
Loading…
Reference in New Issue
Block a user