Compare commits

...

4 Commits

Author SHA1 Message Date
Tat Dat Duong
f3b1823a49 Fix test 2024-01-08 17:54:35 +01:00
Tat Dat Duong
2fcf3eda5a Mark PydanticToolsParser as non-serializable 2024-01-08 17:24:06 +01:00
Tat Dat Duong
26924e04c7 Mark JsonOutputKeyToolsParser as serializable, fix bug 2024-01-08 17:19:18 +01:00
Tat Dat Duong
61b1e74fdc Enable serialization for JSON output OAI parsers 2024-01-08 17:17:30 +01:00
3 changed files with 55 additions and 1 deletions

View File

@@ -167,6 +167,35 @@ SERIALIZABLE_MAPPING = {
"regex",
"RegexParser",
),
(
"langchain",
"output_parsers",
"openai_functions",
"JsonKeyOutputFunctionsParser",
): (
"langchain",
"output_parsers",
"openai_functions",
"JsonKeyOutputFunctionsParser",
),
("langchain", "output_parsers", "openai_functions", "JsonOutputFunctionsParser"): (
"langchain",
"output_parsers",
"openai_functions",
"JsonOutputFunctionsParser",
),
("langchain", "output_parsers", "openai_tools", "JsonOutputKeyToolsParser"): (
"langchain",
"output_parsers",
"openai_tools",
"JsonOutputKeyToolsParser",
),
("langchain", "output_parsers", "openai_tools", "JsonOutputToolsParser"): (
"langchain",
"output_parsers",
"openai_tools",
"JsonOutputToolsParser",
),
("langchain", "schema", "runnable", "DynamicRunnable"): (
"langchain_core",
"runnables",

View File

@@ -51,6 +51,11 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
args_only: bool = True
"""Whether to only return the arguments to the function call."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@property
def _type(self) -> str:
return "json_functions"
@@ -129,6 +134,11 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
res = super().parse_result(result, partial=partial)
if partial and res is None:

View File

@@ -13,6 +13,11 @@ from langchain_core.pydantic_v1 import BaseModel
class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
"""Parse tools from OpenAI response."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
@@ -45,14 +50,24 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
key_name: str
"""The type of tools to return."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
results = super().parse_result(result)
return [res["args"] for res in results if results["type"] == self.key_name]
return [res["args"] for res in results if res["type"] == self.key_name]
class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return False
tools: List[Type[BaseModel]]
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: