mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
Add _type for all parsers (#4189)
Used for serialization. Also add test that recurses through our subclasses to check they have them implemented Would fix https://github.com/hwchase17/langchain/issues/3217 Blocking: https://github.com/mlflow/mlflow/pull/8297 --------- Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
b21d7c138c
commit
812e5f43f5
@ -24,3 +24,7 @@ class ChatOutputParser(AgentOutputParser):
|
||||
|
||||
except Exception:
|
||||
raise OutputParserException(f"Could not parse LLM output: {text}")
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "chat"
|
||||
|
@ -24,3 +24,7 @@ class ConvoOutputParser(AgentOutputParser):
|
||||
action = match.group(1)
|
||||
action_input = match.group(2)
|
||||
return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "conversational"
|
||||
|
@ -31,3 +31,7 @@ class ConvoOutputParser(AgentOutputParser):
|
||||
return AgentFinish({"output": action_input}, text)
|
||||
else:
|
||||
return AgentAction(action, action_input, text)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "conversational_chat"
|
||||
|
@ -27,3 +27,7 @@ class MRKLOutputParser(AgentOutputParser):
|
||||
action = match.group(1).strip()
|
||||
action_input = match.group(2)
|
||||
return AgentAction(action, action_input.strip(" ").strip('"'), text)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "mrkl"
|
||||
|
@ -24,3 +24,7 @@ class ReActOutputParser(AgentOutputParser):
|
||||
return AgentFinish({"output": action_input}, text)
|
||||
else:
|
||||
return AgentAction(action, action_input, text)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "react"
|
||||
|
@ -20,3 +20,7 @@ class SelfAskOutputParser(AgentOutputParser):
|
||||
if " " == after_colon[0]:
|
||||
after_colon = after_colon[1:]
|
||||
return AgentAction("Intermediate Answer", after_colon, text)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "self_ask"
|
||||
|
@ -40,6 +40,10 @@ class StructuredChatOutputParser(AgentOutputParser):
|
||||
except Exception as e:
|
||||
raise OutputParserException(f"Could not parse LLM output: {text}") from e
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "structured_chat"
|
||||
|
||||
|
||||
class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
||||
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
|
||||
@ -76,3 +80,7 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
||||
return cls(base_parser=base_parser)
|
||||
else:
|
||||
return cls()
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "structured_chat_with_retries"
|
||||
|
@ -31,6 +31,10 @@ class APIRequesterOutputParser(BaseOutputParser):
|
||||
return f"MESSAGE: {message_match.group(1).strip()}"
|
||||
return "ERROR making request"
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "api_requester"
|
||||
|
||||
|
||||
class APIRequesterChain(LLMChain):
|
||||
"""Get the request parser."""
|
||||
|
@ -31,6 +31,10 @@ class APIResponderOutputParser(BaseOutputParser):
|
||||
else:
|
||||
raise ValueError(f"No response found in output: {llm_output}.")
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "api_responder"
|
||||
|
||||
|
||||
class APIResponderChain(LLMChain):
|
||||
"""Get the response parser."""
|
||||
|
@ -52,6 +52,10 @@ class BashOutputParser(BaseOutputParser):
|
||||
|
||||
return code_blocks
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "bash"
|
||||
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
|
@ -45,4 +45,4 @@ class OutputFixingParser(BaseOutputParser[T]):
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return self.parser._type
|
||||
return "output_fixing"
|
||||
|
@ -78,7 +78,7 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return self.parser._type
|
||||
return "retry"
|
||||
|
||||
|
||||
class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
@ -122,3 +122,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return self.parser.get_format_instructions()
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "retry_with_error"
|
||||
|
@ -227,25 +227,25 @@ class BaseChatMessageHistory(ABC):
|
||||
class FileChatMessageHistory(BaseChatMessageHistory):
|
||||
storage_path: str
|
||||
session_id: str
|
||||
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
|
||||
messages = json.loads(f.read())
|
||||
return messages_from_dict(messages)
|
||||
|
||||
return messages_from_dict(messages)
|
||||
|
||||
def add_user_message(self, message: str):
|
||||
message_ = HumanMessage(content=message)
|
||||
messages = self.messages.append(_message_to_dict(_message))
|
||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||
json.dump(f, messages)
|
||||
|
||||
|
||||
def add_ai_message(self, message: str):
|
||||
message_ = AIMessage(content=message)
|
||||
messages = self.messages.append(_message_to_dict(_message))
|
||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||
json.dump(f, messages)
|
||||
|
||||
|
||||
def clear(self):
|
||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||
f.write("[]")
|
||||
@ -348,7 +348,10 @@ class BaseOutputParser(BaseModel, ABC, Generic[T]):
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the type key."""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(
|
||||
f"_type property is not implemented in class {self.__class__.__name__}."
|
||||
" This is required for serialization."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
|
47
tests/unit_tests/output_parsers/test_base_output_parser.py
Normal file
47
tests/unit_tests/output_parsers/test_base_output_parser.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""Test the BaseOutputParser class and its sub-classes."""
|
||||
from abc import ABC
|
||||
from typing import List, Optional, Set, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
def non_abstract_subclasses(
|
||||
cls: Type[ABC], to_skip: Optional[Set] = None
|
||||
) -> List[Type]:
|
||||
"""Recursively find all non-abstract subclasses of a class."""
|
||||
_to_skip = to_skip or set()
|
||||
subclasses = []
|
||||
for subclass in cls.__subclasses__():
|
||||
if not getattr(subclass, "__abstractmethods__", None):
|
||||
if subclass.__name__ not in _to_skip:
|
||||
subclasses.append(subclass)
|
||||
subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip))
|
||||
return subclasses
|
||||
|
||||
|
||||
_PARSERS_TO_SKIP = {"FakeOutputParser", "BaseOutputParser"}
|
||||
_NON_ABSTRACT_PARSERS = non_abstract_subclasses(
|
||||
BaseOutputParser, to_skip=_PARSERS_TO_SKIP
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS)
|
||||
def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
|
||||
try:
|
||||
cls._type
|
||||
except NotImplementedError:
|
||||
pytest.fail(f"_type property is not implemented in class {cls.__name__}")
|
||||
|
||||
|
||||
def test_all_subclasses_implement_unique_type() -> None:
|
||||
types = []
|
||||
for cls in _NON_ABSTRACT_PARSERS:
|
||||
try:
|
||||
types.append(cls._type)
|
||||
except NotImplementedError:
|
||||
# This is handled in the previous test
|
||||
pass
|
||||
dups = set([t for t in types if types.count(t) > 1])
|
||||
assert not dups, f"Duplicate types: {dups}"
|
Loading…
Reference in New Issue
Block a user