diff --git a/langchain/schema.py b/langchain/schema.py
deleted file mode 100644
index 4a04bd04c5f..00000000000
--- a/langchain/schema.py
+++ /dev/null
@@ -1,401 +0,0 @@
-"""Common schema objects."""
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from typing import (
-    Any,
-    Dict,
-    Generic,
-    List,
-    NamedTuple,
-    Optional,
-    Sequence,
-    TypeVar,
-    Union,
-)
-
-from pydantic import BaseModel, Extra, Field, root_validator
-
-
-def get_buffer_string(
-    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
-) -> str:
-    """Get buffer string of messages."""
-    string_messages = []
-    for m in messages:
-        if isinstance(m, HumanMessage):
-            role = human_prefix
-        elif isinstance(m, AIMessage):
-            role = ai_prefix
-        elif isinstance(m, SystemMessage):
-            role = "System"
-        elif isinstance(m, ChatMessage):
-            role = m.role
-        else:
-            raise ValueError(f"Got unsupported message type: {m}")
-        string_messages.append(f"{role}: {m.content}")
-    return "\n".join(string_messages)
-
-
-class AgentAction(NamedTuple):
-    """Agent's action to take."""
-
-    tool: str
-    tool_input: Union[str, dict]
-    log: str
-
-
-class AgentFinish(NamedTuple):
-    """Agent's return value."""
-
-    return_values: dict
-    log: str
-
-
-class Generation(BaseModel):
-    """Output of a single generation."""
-
-    text: str
-    """Generated text output."""
-
-    generation_info: Optional[Dict[str, Any]] = None
-    """Raw generation info response from the provider"""
-    """May include things like reason for finishing (e.g. in OpenAI)"""
-    # TODO: add log probs
-
-
-class BaseMessage(BaseModel):
-    """Message object."""
-
-    content: str
-    additional_kwargs: dict = Field(default_factory=dict)
-
-    @property
-    @abstractmethod
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-
-
-class HumanMessage(BaseMessage):
-    """Type of message that is spoken by the human."""
-
-    example: bool = False
-
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "human"
-
-
-class AIMessage(BaseMessage):
-    """Type of message that is spoken by the AI."""
-
-    example: bool = False
-
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "ai"
-
-
-class SystemMessage(BaseMessage):
-    """Type of message that is a system message."""
-
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "system"
-
-
-class ChatMessage(BaseMessage):
-    """Type of message with arbitrary speaker."""
-
-    role: str
-
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "chat"
-
-
-def _message_to_dict(message: BaseMessage) -> dict:
-    return {"type": message.type, "data": message.dict()}
-
-
-def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
-    return [_message_to_dict(m) for m in messages]
-
-
-def _message_from_dict(message: dict) -> BaseMessage:
-    _type = message["type"]
-    if _type == "human":
-        return HumanMessage(**message["data"])
-    elif _type == "ai":
-        return AIMessage(**message["data"])
-    elif _type == "system":
-        return SystemMessage(**message["data"])
-    elif _type == "chat":
-        return ChatMessage(**message["data"])
-    else:
-        raise ValueError(f"Got unexpected type: {_type}")
-
-
-def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
-    return [_message_from_dict(m) for m in messages]
-
-
-class ChatGeneration(Generation):
-    """Output of a single generation."""
-
-    text = ""
-    message: BaseMessage
-
-    @root_validator
-    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
-        values["text"] = values["message"].content
-        return values
-
-
-class ChatResult(BaseModel):
-    """Class that contains all relevant information for a Chat Result."""
-
-    generations: List[ChatGeneration]
-    """List of the things generated."""
-    llm_output: Optional[dict] = None
-    """For arbitrary LLM provider specific output."""
-
-
-class LLMResult(BaseModel):
-    """Class that contains all relevant information for an LLM Result."""
-
-    generations: List[List[Generation]]
-    """List of the things generated. This is List[List[]] because
-    each input could have multiple generations."""
-    llm_output: Optional[dict] = None
-    """For arbitrary LLM provider specific output."""
-
-
-class PromptValue(BaseModel, ABC):
-    @abstractmethod
-    def to_string(self) -> str:
-        """Return prompt as string."""
-
-    @abstractmethod
-    def to_messages(self) -> List[BaseMessage]:
-        """Return prompt as messages."""
-
-
-class BaseMemory(BaseModel, ABC):
-    """Base interface for memory in chains."""
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-    @property
-    @abstractmethod
-    def memory_variables(self) -> List[str]:
-        """Input keys this memory class will load dynamically."""
-
-    @abstractmethod
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        """Return key-value pairs given the text input to the chain.
-
-        If None, return all memories
-        """
-
-    @abstractmethod
-    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
-        """Save the context of this model run to memory."""
-
-    @abstractmethod
-    def clear(self) -> None:
-        """Clear memory contents."""
-
-
-class BaseChatMessageHistory(ABC):
-    """Base interface for chat message history
-    See `ChatMessageHistory` for default implementation.
-    """
-
-    """
-    Example:
-        .. code-block:: python
-
-            class FileChatMessageHistory(BaseChatMessageHistory):
-                storage_path: str
-                session_id: str
-
-                @property
-                def messages(self):
-                    with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
-                        messages = json.loads(f.read())
-                    return messages_from_dict(messages)
-
-                def add_message(self, message: BaseMessage) -> None:
-                    messages = self.messages.append(_message_to_dict(message))
-                    with open(os.path.join(storage_path, session_id), 'w') as f:
-                        json.dump(f, messages)
-
-                def clear(self):
-                    with open(os.path.join(storage_path, session_id), 'w') as f:
-                        f.write("[]")
-    """
-
-    messages: List[BaseMessage]
-
-    def add_user_message(self, message: str) -> None:
-        """Add a user message to the store"""
-        self.add_message(HumanMessage(content=message))
-
-    def add_ai_message(self, message: str) -> None:
-        """Add an AI message to the store"""
-        self.add_message(AIMessage(content=message))
-
-    def add_message(self, message: BaseMessage) -> None:
-        """Add a self-created message to the store"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def clear(self) -> None:
-        """Remove all messages from the store"""
-
-
-class Document(BaseModel):
-    """Interface for interacting with a document."""
-
-    page_content: str
-    metadata: dict = Field(default_factory=dict)
-
-
-class BaseRetriever(ABC):
-    @abstractmethod
-    def get_relevant_documents(self, query: str) -> List[Document]:
-        """Get documents relevant for a query.
-
-        Args:
-            query: string to find relevant documents for
-
-        Returns:
-            List of relevant documents
-        """
-
-    @abstractmethod
-    async def aget_relevant_documents(self, query: str) -> List[Document]:
-        """Get documents relevant for a query.
-
-        Args:
-            query: string to find relevant documents for
-
-        Returns:
-            List of relevant documents
-        """
-
-
-# For backwards compatibility
-
-
-Memory = BaseMemory
-
-T = TypeVar("T")
-
-
-class BaseOutputParser(BaseModel, ABC, Generic[T]):
-    """Class to parse the output of an LLM call.
-
-    Output parsers help structure language model responses.
-    """
-
-    @abstractmethod
-    def parse(self, text: str) -> T:
-        """Parse the output of an LLM call.
-
-        A method which takes in a string (assumed output of a language model )
-        and parses it into some structure.
-
-        Args:
-            text: output of language model
-
-        Returns:
-            structured output
-        """
-
-    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
-        """Optional method to parse the output of an LLM call with a prompt.
-
-        The prompt is largely provided in the event the OutputParser wants
-        to retry or fix the output in some way, and needs information from
-        the prompt to do so.
-
-        Args:
-            completion: output of language model
-            prompt: prompt value
-
-        Returns:
-            structured output
-        """
-        return self.parse(completion)
-
-    def get_format_instructions(self) -> str:
-        """Instructions on how the LLM output should be formatted."""
-        raise NotImplementedError
-
-    @property
-    def _type(self) -> str:
-        """Return the type key."""
-        raise NotImplementedError(
-            f"_type property is not implemented in class {self.__class__.__name__}."
-            " This is required for serialization."
-        )
-
-    def dict(self, **kwargs: Any) -> Dict:
-        """Return dictionary representation of output parser."""
-        output_parser_dict = super().dict()
-        output_parser_dict["_type"] = self._type
-        return output_parser_dict
-
-
-class OutputParserException(ValueError):
-    """Exception that output parsers should raise to signify a parsing error.
-
-    This exists to differentiate parsing errors from other code or execution errors
-    that also may arise inside the output parser. OutputParserExceptions will be
-    available to catch and handle in ways to fix the parsing error, while other
-    errors will be raised.
-    """
-
-    def __init__(
-        self,
-        error: Any,
-        observation: str | None = None,
-        llm_output: str | None = None,
-        send_to_llm: bool = False,
-    ):
-        super(OutputParserException, self).__init__(error)
-        if send_to_llm:
-            if observation is None or llm_output is None:
-                raise ValueError(
-                    "Arguments 'observation' & 'llm_output'"
-                    " are required if 'send_to_llm' is True"
-                )
-        self.observation = observation
-        self.llm_output = llm_output
-        self.send_to_llm = send_to_llm
-
-
-class BaseDocumentTransformer(ABC):
-    """Base interface for transforming documents."""
-
-    @abstractmethod
-    def transform_documents(
-        self, documents: Sequence[Document], **kwargs: Any
-    ) -> Sequence[Document]:
-        """Transform a list of documents."""
-
-    @abstractmethod
-    async def atransform_documents(
-        self, documents: Sequence[Document], **kwargs: Any
-    ) -> Sequence[Document]:
-        """Asynchronously transform a list of documents."""
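
The deletion above is a move, not a removal: every name the old module defined reappears in one of the new `langchain/schema/` submodules below, and the package `__init__.py` re-exports them all. A quick sanity-check sketch of what that preserves for downstream code (the aliased names are illustrative only):

```python
# Both the old flat import path and the new submodule paths should resolve to
# the same objects once this diff lands, so existing call sites keep working.
from langchain.schema import Document, HumanMessage
from langchain.schema.documents import Document as DocumentViaSubmodule
from langchain.schema.messages import HumanMessage as HumanMessageViaSubmodule

# The package __init__ imports from these same submodules, so identity holds.
assert Document is DocumentViaSubmodule
assert HumanMessage is HumanMessageViaSubmodule
```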
diff --git a/langchain/schema/__init__.py b/langchain/schema/__init__.py
new file mode 100644
index 00000000000..c8464079b3a
--- /dev/null
+++ b/langchain/schema/__init__.py
@@ -0,0 +1,48 @@
+from langchain.schema.agents import AgentAction, AgentFinish
+from langchain.schema.chat_message_history import BaseChatMessageHistory
+from langchain.schema.documents import BaseDocumentTransformer, Document
+from langchain.schema.memory import BaseMemory, Memory
+from langchain.schema.messages import (
+    AIMessage,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
+    _message_from_dict,
+    _message_to_dict,
+    get_buffer_string,
+    messages_from_dict,
+    messages_to_dict,
+)
+from langchain.schema.output_parser import BaseOutputParser, OutputParserException
+from langchain.schema.outputs import ChatGeneration, ChatResult, Generation, LLMResult
+from langchain.schema.prompts import PromptValue
+from langchain.schema.retriever import BaseRetriever
+
+__all__ = [
+    "AgentAction",
+    "AgentFinish",
+    "BaseMessage",
+    "BaseChatMessageHistory",
+    "Document",
+    "BaseDocumentTransformer",
+    "Memory",
+    "BaseMemory",
+    "ChatResult",
+    "ChatGeneration",
+    "ChatMessage",
+    "HumanMessage",
+    "AIMessage",
+    "SystemMessage",
+    "get_buffer_string",
+    "messages_to_dict",
+    "messages_from_dict",
+    "BaseOutputParser",
+    "OutputParserException",
+    "PromptValue",
+    "Generation",
+    "LLMResult",
+    "BaseRetriever",
+    "_message_to_dict",
+    "_message_from_dict",
+]
diff --git a/langchain/schema/agents.py b/langchain/schema/agents.py
new file mode 100644
index 00000000000..418144ccfa2
--- /dev/null
+++ b/langchain/schema/agents.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from typing import NamedTuple, Union
+
+
+class AgentAction(NamedTuple):
+    """Agent's action to take."""
+
+    tool: str
+    tool_input: Union[str, dict]
+    log: str
+
+
+class AgentFinish(NamedTuple):
+    """Agent's return value."""
+
+    return_values: dict
+    log: str
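
These two NamedTuples are the whole agent protocol at this layer: an executor loops on `AgentAction` values until the agent emits an `AgentFinish`. A usage sketch (the tool name and log text here are made up for illustration):

```python
from langchain.schema import AgentAction, AgentFinish

# One intermediate step: call a tool named "search" with a string input.
action = AgentAction(
    tool="search",
    tool_input="capital of France",
    log="Thought: I should look this up.\nAction: search",
)

# NamedTuples also unpack positionally, which executor loops can rely on.
tool, tool_input, log = action

# The terminal step: the final answer lives under return_values.
finish = AgentFinish(return_values={"output": "Paris"}, log="Final answer: Paris")
assert finish.return_values["output"] == "Paris"
```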
diff --git a/langchain/schema/chat_message_history.py b/langchain/schema/chat_message_history.py
new file mode 100644
index 00000000000..b5bb81be9a6
--- /dev/null
+++ b/langchain/schema/chat_message_history.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import List
+
+from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
+
+
+class BaseChatMessageHistory(ABC):
+    """Base interface for chat message history.
+    See `ChatMessageHistory` for the default implementation.
+    """
+
+    """
+    Example:
+        .. code-block:: python
+
+            class FileChatMessageHistory(BaseChatMessageHistory):
+                storage_path: str
+                session_id: str
+
+                @property
+                def messages(self):
+                    with open(os.path.join(self.storage_path, self.session_id), 'r', encoding='utf-8') as f:
+                        messages = json.loads(f.read())
+                    return messages_from_dict(messages)
+
+                def add_message(self, message: BaseMessage) -> None:
+                    messages = messages_to_dict(self.messages) + [_message_to_dict(message)]
+                    with open(os.path.join(self.storage_path, self.session_id), 'w') as f:
+                        json.dump(messages, f)
+
+                def clear(self):
+                    with open(os.path.join(self.storage_path, self.session_id), 'w') as f:
+                        f.write("[]")
+    """
+
+    messages: List[BaseMessage]
+
+    def add_user_message(self, message: str) -> None:
+        """Add a user message to the store"""
+        self.add_message(HumanMessage(content=message))
+
+    def add_ai_message(self, message: str) -> None:
+        """Add an AI message to the store"""
+        self.add_message(AIMessage(content=message))
+
+    def add_message(self, message: BaseMessage) -> None:
+        """Add a self-created message to the store"""
+        raise NotImplementedError
+
+    @abstractmethod
+    def clear(self) -> None:
+        """Remove all messages from the store"""
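
Only `clear` is declared abstract, but `add_message` must be overridden too or the `add_user_message`/`add_ai_message` helpers will raise `NotImplementedError`. The smallest useful subclass is list-backed; a hypothetical example, not part of this diff:

```python
from typing import List

from langchain.schema import BaseChatMessageHistory, BaseMessage


class InMemoryChatMessageHistory(BaseChatMessageHistory):
    """Hypothetical list-backed history, handy for tests and examples."""

    def __init__(self) -> None:
        self.messages: List[BaseMessage] = []

    def add_message(self, message: BaseMessage) -> None:
        self.messages.append(message)

    def clear(self) -> None:
        self.messages = []


history = InMemoryChatMessageHistory()
history.add_user_message("hi!")  # stored as a HumanMessage
history.add_ai_message("hello")  # stored as an AIMessage
assert [m.type for m in history.messages] == ["human", "ai"]
```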
diff --git a/langchain/schema/documents.py b/langchain/schema/documents.py
new file mode 100644
index 00000000000..e28df9aab18
--- /dev/null
+++ b/langchain/schema/documents.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, Sequence
+
+from pydantic import BaseModel, Field
+
+
+class BaseDocumentTransformer(ABC):
+    """Base interface for transforming documents."""
+
+    @abstractmethod
+    def transform_documents(
+        self, documents: Sequence[Document], **kwargs: Any
+    ) -> Sequence[Document]:
+        """Transform a list of documents."""
+
+    @abstractmethod
+    async def atransform_documents(
+        self, documents: Sequence[Document], **kwargs: Any
+    ) -> Sequence[Document]:
+        """Asynchronously transform a list of documents."""
+
+
+class Document(BaseModel):
+    """Interface for interacting with a document."""
+
+    page_content: str
+    metadata: dict = Field(default_factory=dict)
diff --git a/langchain/schema/memory.py b/langchain/schema/memory.py
new file mode 100644
index 00000000000..956d7c153d6
--- /dev/null
+++ b/langchain/schema/memory.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+
+from pydantic import BaseModel, Extra
+
+
+class BaseMemory(BaseModel, ABC):
+    """Base interface for memory in chains."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @property
+    @abstractmethod
+    def memory_variables(self) -> List[str]:
+        """Input keys this memory class will load dynamically."""
+
+    @abstractmethod
+    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Return key-value pairs given the text input to the chain.
+
+        If None, return all memories
+        """
+
+    @abstractmethod
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+        """Save the context of this model run to memory."""
+
+    @abstractmethod
+    def clear(self) -> None:
+        """Clear memory contents."""
+
+
+Memory = BaseMemory
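
`BaseMemory` is the contract chains use to read and write conversational state: `memory_variables` names the keys `load_memory_variables` will return, and `save_context` is called after each run. A minimal conforming implementation might look like the sketch below (hypothetical, not part of this diff); note that because `BaseMemory` is a pydantic model with `Extra.forbid`, state must be declared as a field rather than set ad hoc:

```python
from typing import Any, Dict, List

from langchain.schema import BaseMemory


class LastTurnMemory(BaseMemory):
    """Hypothetical memory exposing only the previous output as 'history'."""

    buffer: str = ""  # declared as a pydantic field; Extra.forbid rejects stray attrs

    @property
    def memory_variables(self) -> List[str]:
        return ["history"]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {"history": self.buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        # Keep whichever single output value the chain produced.
        self.buffer = next(iter(outputs.values()), "")

    def clear(self) -> None:
        self.buffer = ""


memory = LastTurnMemory()
memory.save_context({"input": "hi"}, {"output": "hello"})
assert memory.load_memory_variables({}) == {"history": "hello"}
```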
diff --git a/langchain/schema/messages.py b/langchain/schema/messages.py
new file mode 100644
index 00000000000..c83906d2116
--- /dev/null
+++ b/langchain/schema/messages.py
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import List
+
+from pydantic import BaseModel, Field
+
+
+class BaseMessage(BaseModel):
+    """Message object."""
+
+    content: str
+    additional_kwargs: dict = Field(default_factory=dict)
+
+    @property
+    @abstractmethod
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+
+
+class HumanMessage(BaseMessage):
+    """Type of message that is spoken by the human."""
+
+    example: bool = False
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "human"
+
+
+class AIMessage(BaseMessage):
+    """Type of message that is spoken by the AI."""
+
+    example: bool = False
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "ai"
+
+
+class SystemMessage(BaseMessage):
+    """Type of message that is a system message."""
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "system"
+
+
+class ChatMessage(BaseMessage):
+    """Type of message with arbitrary speaker."""
+
+    role: str
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "chat"
+
+
+def get_buffer_string(
+    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
+) -> str:
+    """Get buffer string of messages."""
+    string_messages = []
+    for m in messages:
+        if isinstance(m, HumanMessage):
+            role = human_prefix
+        elif isinstance(m, AIMessage):
+            role = ai_prefix
+        elif isinstance(m, SystemMessage):
+            role = "System"
+        elif isinstance(m, ChatMessage):
+            role = m.role
+        else:
+            raise ValueError(f"Got unsupported message type: {m}")
+        string_messages.append(f"{role}: {m.content}")
+    return "\n".join(string_messages)
+
+
+def _message_to_dict(message: BaseMessage) -> dict:
+    return {"type": message.type, "data": message.dict()}
+
+
+def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
+    return [_message_to_dict(m) for m in messages]
+
+
+def _message_from_dict(message: dict) -> BaseMessage:
+    _type = message["type"]
+    if _type == "human":
+        return HumanMessage(**message["data"])
+    elif _type == "ai":
+        return AIMessage(**message["data"])
+    elif _type == "system":
+        return SystemMessage(**message["data"])
+    elif _type == "chat":
+        return ChatMessage(**message["data"])
+    else:
+        raise ValueError(f"Got unexpected type: {_type}")
+
+
+def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
+    return [_message_from_dict(m) for m in messages]
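
These helpers are what the chat-history docstring example above leans on: `messages_to_dict` serializes via each message's `type` property, and `messages_from_dict` dispatches on that key to rebuild the right class. A round-trip sketch:

```python
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
    messages_from_dict,
    messages_to_dict,
)

messages = [
    SystemMessage(content="You are terse."),
    HumanMessage(content="hi"),
    AIMessage(content="hello"),
]

# Render as a prompt-ready transcript:
#   System: You are terse.
#   Human: hi
#   AI: hello
print(get_buffer_string(messages))

# Serialization round-trips through plain dicts keyed on the `type` property.
restored = messages_from_dict(messages_to_dict(messages))
assert [m.type for m in restored] == ["system", "human", "ai"]
```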
diff --git a/langchain/schema/output_parser.py b/langchain/schema/output_parser.py
new file mode 100644
index 00000000000..e759e1c8541
--- /dev/null
+++ b/langchain/schema/output_parser.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, TypeVar
+
+from pydantic import BaseModel
+
+from langchain.schema.prompts import PromptValue
+
+T = TypeVar("T")
+
+
+class BaseOutputParser(BaseModel, ABC, Generic[T]):
+    """Class to parse the output of an LLM call.
+
+    Output parsers help structure language model responses.
+    """
+
+    @abstractmethod
+    def parse(self, text: str) -> T:
+        """Parse the output of an LLM call.
+
+        A method which takes in a string (assumed output of a language model)
+        and parses it into some structure.
+
+        Args:
+            text: output of language model
+
+        Returns:
+            structured output
+        """
+
+    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
+        """Optional method to parse the output of an LLM call with a prompt.
+
+        The prompt is largely provided in the event the OutputParser wants
+        to retry or fix the output in some way, and needs information from
+        the prompt to do so.
+
+        Args:
+            completion: output of language model
+            prompt: prompt value
+
+        Returns:
+            structured output
+        """
+        return self.parse(completion)
+
+    def get_format_instructions(self) -> str:
+        """Instructions on how the LLM output should be formatted."""
+        raise NotImplementedError
+
+    @property
+    def _type(self) -> str:
+        """Return the type key."""
+        raise NotImplementedError(
+            f"_type property is not implemented in class {self.__class__.__name__}."
+            " This is required for serialization."
+        )
+
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of output parser."""
+        output_parser_dict = super().dict()
+        output_parser_dict["_type"] = self._type
+        return output_parser_dict
+
+
+class OutputParserException(ValueError):
+    """Exception that output parsers should raise to signify a parsing error.
+
+    This exists to differentiate parsing errors from other code or execution errors
+    that may also arise inside the output parser. An OutputParserException can be
+    caught and handled in ways that fix the parsing error, while other
+    errors will be raised as usual.
+    """
+
+    def __init__(
+        self,
+        error: Any,
+        observation: str | None = None,
+        llm_output: str | None = None,
+        send_to_llm: bool = False,
+    ):
+        super(OutputParserException, self).__init__(error)
+        if send_to_llm:
+            if observation is None or llm_output is None:
+                raise ValueError(
+                    "Arguments 'observation' & 'llm_output'"
+                    " are required if 'send_to_llm' is True"
+                )
+        self.observation = observation
+        self.llm_output = llm_output
+        self.send_to_llm = send_to_llm
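
A concrete parser only has to implement `parse`; raising `OutputParserException` rather than a bare `ValueError` is what lets retry/fixing wrappers tell a recoverable parse failure apart from a genuine bug. A hypothetical example (the parser name and format are made up):

```python
from typing import List

from langchain.schema import BaseOutputParser, OutputParserException


class CommaSeparatedParser(BaseOutputParser):
    """Hypothetical parser turning 'a, b, c' into ['a', 'b', 'c']."""

    def parse(self, text: str) -> List[str]:
        if not text.strip():
            # Signal a recoverable parse failure, not a code bug.
            raise OutputParserException(f"Expected comma-separated values, got: {text!r}")
        return [part.strip() for part in text.split(",")]

    def get_format_instructions(self) -> str:
        return "Reply with comma-separated values, e.g. `foo, bar`."

    @property
    def _type(self) -> str:
        return "comma_separated"  # illustrative type key for serialization


assert CommaSeparatedParser().parse("foo, bar") == ["foo", "bar"]
```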
diff --git a/langchain/schema/outputs.py b/langchain/schema/outputs.py
new file mode 100644
index 00000000000..daf90f0c082
--- /dev/null
+++ b/langchain/schema/outputs.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, root_validator
+
+from langchain.schema.messages import BaseMessage
+
+
+class Generation(BaseModel):
+    """Output of a single generation."""
+
+    text: str
+    """Generated text output."""
+
+    generation_info: Optional[Dict[str, Any]] = None
+    """Raw generation info response from the provider.
+    May include things like reason for finishing (e.g. in OpenAI)."""
+    # TODO: add log probs
+
+
+class ChatGeneration(Generation):
+    """Output of a single chat generation."""
+
+    text = ""
+    message: BaseMessage
+
+    @root_validator
+    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        values["text"] = values["message"].content
+        return values
+
+
+class ChatResult(BaseModel):
+    """Class that contains all relevant information for a Chat Result."""
+
+    generations: List[ChatGeneration]
+    """List of the things generated."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""
+
+
+class LLMResult(BaseModel):
+    """Class that contains all relevant information for an LLM Result."""
+
+    generations: List[List[Generation]]
+    """List of the things generated. This is List[List[]] because
+    each input could have multiple generations."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""
diff --git a/langchain/schema/prompts.py b/langchain/schema/prompts.py
new file mode 100644
index 00000000000..4fcb5c51a87
--- /dev/null
+++ b/langchain/schema/prompts.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import List
+
+from pydantic import BaseModel
+
+from langchain.schema.messages import BaseMessage
+
+
+class PromptValue(BaseModel, ABC):
+    @abstractmethod
+    def to_string(self) -> str:
+        """Return prompt as string."""
+
+    @abstractmethod
+    def to_messages(self) -> List[BaseMessage]:
+        """Return prompt as messages."""
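
Two shapes worth noting here: `ChatGeneration.text` is derived from `message.content` by the root validator, and `LLMResult.generations` is a list of lists, one inner list per input prompt. A small sketch (the `llm_output` keys are illustrative, not a fixed schema):

```python
from langchain.schema import AIMessage, ChatGeneration, Generation, LLMResult

# The root validator copies message.content into text at construction time.
gen = ChatGeneration(message=AIMessage(content="Paris"))
assert gen.text == "Paris"

# Two prompts in, one generation each: the outer list index is the prompt.
result = LLMResult(
    generations=[[Generation(text="Paris")], [Generation(text="4")]],
    llm_output={"token_usage": {}},  # provider-specific; keys here are made up
)
assert result.generations[1][0].text == "4"
```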
diff --git a/langchain/schema/retriever.py b/langchain/schema/retriever.py
new file mode 100644
index 00000000000..6cf3fe164c9
--- /dev/null
+++ b/langchain/schema/retriever.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import List
+
+from langchain.schema.documents import Document
+
+
+class BaseRetriever(ABC):
+    @abstractmethod
+    def get_relevant_documents(self, query: str) -> List[Document]:
+        """Get documents relevant for a query.
+
+        Args:
+            query: string to find relevant documents for
+
+        Returns:
+            List of relevant documents
+        """
+
+    @abstractmethod
+    async def aget_relevant_documents(self, query: str) -> List[Document]:
+        """Get documents relevant for a query.
+
+        Args:
+            query: string to find relevant documents for
+
+        Returns:
+            List of relevant documents
+        """
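
Both methods are abstract, so even a purely synchronous retriever must supply `aget_relevant_documents`; delegating to the sync path is the simplest compliant option. A hypothetical in-memory example:

```python
from typing import List

from langchain.schema import BaseRetriever, Document


class KeywordRetriever(BaseRetriever):
    """Hypothetical retriever doing case-insensitive substring matching."""

    def __init__(self, docs: List[Document]) -> None:
        self.docs = docs

    def get_relevant_documents(self, query: str) -> List[Document]:
        return [d for d in self.docs if query.lower() in d.page_content.lower()]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        # No real async work to do here, so delegate to the sync method.
        return self.get_relevant_documents(query)


docs = [Document(page_content="Schema refactor notes"), Document(page_content="Other")]
assert KeywordRetriever(docs).get_relevant_documents("schema")[0] is docs[0]
```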