Compare commits

...

7 Commits

Author SHA1 Message Date
vowelparrot
0a050ee8c0 Merge branch 'master' into vwp/parser__type 2023-04-28 00:54:03 -07:00
vowelparrot
c808602fc7 undo other pr 2023-04-21 17:13:12 -07:00
vowelparrot
695345ae29 update tests 2023-04-21 17:08:00 -07:00
vowelparrot
fa47676f13 Add _type properties to output parsers
Used for serialization. Also add test that recurses through
our subclasses to check they have them implemented
2023-04-21 17:03:52 -07:00
vowelparrot
52e3944b77 Merge branch 'master' into vwp/3297 2023-04-21 15:50:18 -07:00
vowelparrot
d39c3ca67f Merge branch 'master' into vwp/3297 2023-04-21 15:34:04 -07:00
vowelparrot
49a0f38c99 Structured Tool Bugfixes
- Proactively raise error if a tool subclasses BaseTool, defines its
own schema, but fails to add the type-hints
- fix the auto-inferred schema of the decorator to strip the
unneeded virtual kwargs from the schema dict
2023-04-21 15:05:41 -07:00
12 changed files with 85 additions and 3 deletions

View File

@@ -24,3 +24,7 @@ class ChatOutputParser(AgentOutputParser):
except Exception:
raise OutputParserException(f"Could not parse LLM output: {text}")
@property
def _type(self) -> str:
return "chat"

View File

@@ -24,3 +24,7 @@ class ConvoOutputParser(AgentOutputParser):
action = match.group(1)
action_input = match.group(2)
return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text)
@property
def _type(self) -> str:
return "conversational"

View File

@@ -31,3 +31,7 @@ class ConvoOutputParser(AgentOutputParser):
return AgentFinish({"output": action_input}, text)
else:
return AgentAction(action, action_input, text)
@property
def _type(self) -> str:
return "conversational_chat"

View File

@@ -27,3 +27,7 @@ class MRKLOutputParser(AgentOutputParser):
action = match.group(1).strip()
action_input = match.group(2)
return AgentAction(action, action_input.strip(" ").strip('"'), text)
@property
def _type(self) -> str:
return "mrkl"

View File

@@ -24,3 +24,7 @@ class ReActOutputParser(AgentOutputParser):
return AgentFinish({"output": action_input}, text)
else:
return AgentAction(action, action_input, text)
@property
def _type(self) -> str:
return "react"

View File

@@ -20,3 +20,7 @@ class SelfAskOutputParser(AgentOutputParser):
if " " == after_colon[0]:
after_colon = after_colon[1:]
return AgentAction("Intermediate Answer", after_colon, text)
@property
def _type(self) -> str:
return "self_ask"

View File

@@ -30,6 +30,10 @@ class APIRequesterOutputParser(BaseOutputParser):
return f"MESSAGE: {message_match.group(1).strip()}"
return "ERROR making request"
@property
def _type(self) -> str:
return "api_requester"
class APIRequesterChain(LLMChain):
"""Get the request parser."""

View File

@@ -30,6 +30,10 @@ class APIResponderOutputParser(BaseOutputParser):
else:
raise ValueError(f"No response found in output: {llm_output}.")
@property
def _type(self) -> str:
return "api_responder"
class APIResponderChain(LLMChain):
"""Get the response parser."""

View File

@@ -44,4 +44,4 @@ class OutputFixingParser(BaseOutputParser[T]):
@property
def _type(self) -> str:
return self.parser._type
return f"output_fixing_{self.parser._type}"

View File

@@ -78,7 +78,7 @@ class RetryOutputParser(BaseOutputParser[T]):
@property
def _type(self) -> str:
return self.parser._type
return f"retry_{self.parser._type}"
class RetryWithErrorOutputParser(BaseOutputParser[T]):
@@ -122,3 +122,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()
@property
def _type(self) -> str:
return "retry_with_error"

View File

@@ -383,7 +383,10 @@ class BaseOutputParser(BaseModel, ABC, Generic[T]):
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""

View File

@@ -0,0 +1,43 @@
"""Test the BaseOutputParser class and its sub-classes."""
from abc import ABC
from typing import List, Type
from unittest.mock import MagicMock
import pytest
from langchain.schema import BaseOutputParser
_PARSERS_TO_SKIP = {"FakeOutputParser", "BaseOutputParser"}
def non_abstract_subclasses(cls: Type[ABC]) -> List[Type]:
"""Recursively find all non-abstract subclasses of a class."""
subclasses = []
for subclass in cls.__subclasses__():
if not getattr(subclass, "__abstractmethods__", None):
if subclass.__name__ not in _PARSERS_TO_SKIP:
subclasses.append(subclass)
subclasses.extend(non_abstract_subclasses(subclass))
return subclasses
@pytest.mark.parametrize("cls", non_abstract_subclasses(BaseOutputParser))
def test_all_subclasses_implement_type(cls: Type[BaseOutputParser]) -> None:
try:
# Most parsers just return a string. MagicMock lets
# the parsers that wrap another parsers slide by
cls._type.fget(MagicMock()) # type: ignore
except NotImplementedError:
pytest.fail(f"_type property is not implemented in class {cls.__name__}")
def test_all_subclasses_implement_unique_type() -> None:
types = []
for cls in non_abstract_subclasses(BaseOutputParser):
try:
types.append(cls._type.fget(MagicMock()))
except NotImplementedError:
# This is handled in the previous test
pass
dups = set([t for t in types if types.count(t) > 1])
assert not dups, f"Duplicate types: {dups}"