mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
Add _type for all parsers (#4189)
Used for serialization. Also add test that recurses through our subclasses to check they have them implemented Would fix https://github.com/hwchase17/langchain/issues/3217 Blocking: https://github.com/mlflow/mlflow/pull/8297 --------- Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
b21d7c138c
commit
812e5f43f5
@ -24,3 +24,7 @@ class ChatOutputParser(AgentOutputParser):
|
|||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise OutputParserException(f"Could not parse LLM output: {text}")
|
raise OutputParserException(f"Could not parse LLM output: {text}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "chat"
|
||||||
|
@ -24,3 +24,7 @@ class ConvoOutputParser(AgentOutputParser):
|
|||||||
action = match.group(1)
|
action = match.group(1)
|
||||||
action_input = match.group(2)
|
action_input = match.group(2)
|
||||||
return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text)
|
return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "conversational"
|
||||||
|
@ -31,3 +31,7 @@ class ConvoOutputParser(AgentOutputParser):
|
|||||||
return AgentFinish({"output": action_input}, text)
|
return AgentFinish({"output": action_input}, text)
|
||||||
else:
|
else:
|
||||||
return AgentAction(action, action_input, text)
|
return AgentAction(action, action_input, text)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "conversational_chat"
|
||||||
|
@ -27,3 +27,7 @@ class MRKLOutputParser(AgentOutputParser):
|
|||||||
action = match.group(1).strip()
|
action = match.group(1).strip()
|
||||||
action_input = match.group(2)
|
action_input = match.group(2)
|
||||||
return AgentAction(action, action_input.strip(" ").strip('"'), text)
|
return AgentAction(action, action_input.strip(" ").strip('"'), text)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "mrkl"
|
||||||
|
@ -24,3 +24,7 @@ class ReActOutputParser(AgentOutputParser):
|
|||||||
return AgentFinish({"output": action_input}, text)
|
return AgentFinish({"output": action_input}, text)
|
||||||
else:
|
else:
|
||||||
return AgentAction(action, action_input, text)
|
return AgentAction(action, action_input, text)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "react"
|
||||||
|
@ -20,3 +20,7 @@ class SelfAskOutputParser(AgentOutputParser):
|
|||||||
if " " == after_colon[0]:
|
if " " == after_colon[0]:
|
||||||
after_colon = after_colon[1:]
|
after_colon = after_colon[1:]
|
||||||
return AgentAction("Intermediate Answer", after_colon, text)
|
return AgentAction("Intermediate Answer", after_colon, text)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "self_ask"
|
||||||
|
@ -40,6 +40,10 @@ class StructuredChatOutputParser(AgentOutputParser):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise OutputParserException(f"Could not parse LLM output: {text}") from e
|
raise OutputParserException(f"Could not parse LLM output: {text}") from e
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "structured_chat"
|
||||||
|
|
||||||
|
|
||||||
class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
||||||
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
|
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
|
||||||
@ -76,3 +80,7 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
|||||||
return cls(base_parser=base_parser)
|
return cls(base_parser=base_parser)
|
||||||
else:
|
else:
|
||||||
return cls()
|
return cls()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "structured_chat_with_retries"
|
||||||
|
@ -31,6 +31,10 @@ class APIRequesterOutputParser(BaseOutputParser):
|
|||||||
return f"MESSAGE: {message_match.group(1).strip()}"
|
return f"MESSAGE: {message_match.group(1).strip()}"
|
||||||
return "ERROR making request"
|
return "ERROR making request"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "api_requester"
|
||||||
|
|
||||||
|
|
||||||
class APIRequesterChain(LLMChain):
|
class APIRequesterChain(LLMChain):
|
||||||
"""Get the request parser."""
|
"""Get the request parser."""
|
||||||
|
@ -31,6 +31,10 @@ class APIResponderOutputParser(BaseOutputParser):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"No response found in output: {llm_output}.")
|
raise ValueError(f"No response found in output: {llm_output}.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "api_responder"
|
||||||
|
|
||||||
|
|
||||||
class APIResponderChain(LLMChain):
|
class APIResponderChain(LLMChain):
|
||||||
"""Get the response parser."""
|
"""Get the response parser."""
|
||||||
|
@ -52,6 +52,10 @@ class BashOutputParser(BaseOutputParser):
|
|||||||
|
|
||||||
return code_blocks
|
return code_blocks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "bash"
|
||||||
|
|
||||||
|
|
||||||
PROMPT = PromptTemplate(
|
PROMPT = PromptTemplate(
|
||||||
input_variables=["question"],
|
input_variables=["question"],
|
||||||
|
@ -45,4 +45,4 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _type(self) -> str:
|
def _type(self) -> str:
|
||||||
return self.parser._type
|
return "output_fixing"
|
||||||
|
@ -78,7 +78,7 @@ class RetryOutputParser(BaseOutputParser[T]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _type(self) -> str:
|
def _type(self) -> str:
|
||||||
return self.parser._type
|
return "retry"
|
||||||
|
|
||||||
|
|
||||||
class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||||
@ -122,3 +122,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
|||||||
|
|
||||||
def get_format_instructions(self) -> str:
|
def get_format_instructions(self) -> str:
|
||||||
return self.parser.get_format_instructions()
|
return self.parser.get_format_instructions()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "retry_with_error"
|
||||||
|
@ -348,7 +348,10 @@ class BaseOutputParser(BaseModel, ABC, Generic[T]):
|
|||||||
@property
|
@property
|
||||||
def _type(self) -> str:
|
def _type(self) -> str:
|
||||||
"""Return the type key."""
|
"""Return the type key."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError(
|
||||||
|
f"_type property is not implemented in class {self.__class__.__name__}."
|
||||||
|
" This is required for serialization."
|
||||||
|
)
|
||||||
|
|
||||||
def dict(self, **kwargs: Any) -> Dict:
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
"""Return dictionary representation of output parser."""
|
"""Return dictionary representation of output parser."""
|
||||||
|
47
tests/unit_tests/output_parsers/test_base_output_parser.py
Normal file
47
tests/unit_tests/output_parsers/test_base_output_parser.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
"""Test the BaseOutputParser class and its sub-classes."""
|
||||||
|
from abc import ABC
|
||||||
|
from typing import List, Optional, Set, Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.schema import BaseOutputParser
|
||||||
|
|
||||||
|
|
||||||
|
def non_abstract_subclasses(
|
||||||
|
cls: Type[ABC], to_skip: Optional[Set] = None
|
||||||
|
) -> List[Type]:
|
||||||
|
"""Recursively find all non-abstract subclasses of a class."""
|
||||||
|
_to_skip = to_skip or set()
|
||||||
|
subclasses = []
|
||||||
|
for subclass in cls.__subclasses__():
|
||||||
|
if not getattr(subclass, "__abstractmethods__", None):
|
||||||
|
if subclass.__name__ not in _to_skip:
|
||||||
|
subclasses.append(subclass)
|
||||||
|
subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip))
|
||||||
|
return subclasses
|
||||||
|
|
||||||
|
|
||||||
|
_PARSERS_TO_SKIP = {"FakeOutputParser", "BaseOutputParser"}
|
||||||
|
_NON_ABSTRACT_PARSERS = non_abstract_subclasses(
|
||||||
|
BaseOutputParser, to_skip=_PARSERS_TO_SKIP
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS)
|
||||||
|
def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
|
||||||
|
try:
|
||||||
|
cls._type
|
||||||
|
except NotImplementedError:
|
||||||
|
pytest.fail(f"_type property is not implemented in class {cls.__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_subclasses_implement_unique_type() -> None:
|
||||||
|
types = []
|
||||||
|
for cls in _NON_ABSTRACT_PARSERS:
|
||||||
|
try:
|
||||||
|
types.append(cls._type)
|
||||||
|
except NotImplementedError:
|
||||||
|
# This is handled in the previous test
|
||||||
|
pass
|
||||||
|
dups = set([t for t in types if types.count(t) > 1])
|
||||||
|
assert not dups, f"Duplicate types: {dups}"
|
Loading…
Reference in New Issue
Block a user