Add input/output schemas to runnables (#11063)
This adds `input_schema` and `output_schema` properties to all runnables, which are Pydantic models for the input and output types respectively. These are inferred from the structure of the Runnable as much as possible; the only manual typing needed is to:

- optionally add type hints to lambdas (which get translated to input/output schemas)
- optionally add a type hint to RunnablePassthrough

These schemas can then be used to create JSON Schema descriptions of the input and output types; see the tests.

- [x] Ensure no InputType and OutputType in our classes use abstract base classes (replace with union of subclasses)
- [x] Implement in BaseChain and LLMChain
- [x] Implement in RunnableBranch
- [x] Implement in RunnableBinding, RunnableMap, RunnablePassthrough, RunnableEach, RunnableRouter
- [x] Implement in LLM, Prompt, Chat Model, Output Parser, Retriever
- [x] Implement in RunnableLambda from function signature
- [x] Implement in Tool
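A quick usage sketch (mirroring the assertions in the tests; output shown for a template with a single `name` variable):

```python
from langchain.prompts import PromptTemplate

# Any runnable now exposes Pydantic models describing its I/O.
prompt = PromptTemplate.from_template("Hello, {name}!")

# JSON Schema for the expected input dict, inferred from input_variables.
print(prompt.input_schema.schema())
# -> {'title': 'PromptInput', 'type': 'object', 'properties': {'name': {'title': 'Name'}}}
```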
This commit is contained in:
parent: b05bb9e136
commit: cfa2203c62
@@ -7,7 +7,7 @@ import warnings
 from abc import ABC, abstractmethod
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 import yaml
 
@@ -22,7 +22,13 @@ from langchain.callbacks.manager import (
 )
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
-from langchain.pydantic_v1 import Field, root_validator, validator
+from langchain.pydantic_v1 import (
+    BaseModel,
+    Field,
+    create_model,
+    root_validator,
+    validator,
+)
 from langchain.schema import RUN_KEY, BaseMemory, RunInfo
 from langchain.schema.runnable import Runnable, RunnableConfig
 
@@ -56,6 +62,20 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
     chains and cannot return as rich of an output as `__call__`.
     """
 
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        # This is correct, but pydantic typings/mypy don't think so.
+        return create_model(  # type: ignore[call-overload]
+            "ChainInput", **{k: (Any, None) for k in self.input_keys}
+        )
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        # This is correct, but pydantic typings/mypy don't think so.
+        return create_model(  # type: ignore[call-overload]
+            "ChainOutput", **{k: (Any, None) for k in self.output_keys}
+        )
+
     def invoke(
         self,
         input: Dict[str, Any],
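An editorial aside on what these dynamic models produce (a minimal sketch, assuming a chain whose `input_keys` is `["question"]`; the `(Any, None)` entries yield untyped optional fields):

```python
from typing import Any
from langchain.pydantic_v1 import create_model

# Mirrors Chain.input_schema for a chain with input_keys == ["question"].
ChainInput = create_model("ChainInput", **{k: (Any, None) for k in ["question"]})
print(ChainInput.schema())
# -> {'title': 'ChainInput', 'type': 'object', 'properties': {'question': {'title': 'Question'}}}
```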
@@ -1,7 +1,7 @@
 """Base interface for chains combining documents."""
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
 )
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
-from langchain.pydantic_v1 import Field
+from langchain.pydantic_v1 import BaseModel, Field, create_model
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
 
@@ -28,6 +28,20 @@ class BaseCombineDocumentsChain(Chain, ABC):
     input_key: str = "input_documents"  #: :meta private:
     output_key: str = "output_text"  #: :meta private:
 
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return create_model(
+            "CombineDocumentsInput",
+            **{self.input_key: (List[Document], None)},  # type: ignore[call-overload]
+        )
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return create_model(
+            "CombineDocumentsOutput",
+            **{self.output_key: (str, None)},  # type: ignore[call-overload]
+        )
+
     @property
     def input_keys(self) -> List[str]:
         """Expect input key.
@@ -153,6 +167,17 @@ class AnalyzeDocumentChain(Chain):
         """
         return self.combine_docs_chain.output_keys
 
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return create_model(
+            "AnalyzeDocumentChain",
+            **{self.input_key: (str, None)},  # type: ignore[call-overload]
+        )
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.combine_docs_chain.output_schema
+
     def _call(
         self,
         inputs: Dict[str, str],
@@ -9,7 +9,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
 
 
 class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@@ -98,6 +98,19 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     return_intermediate_steps: bool = False
     """Return the results of the map steps in the output."""
 
+    @property
+    def output_schema(self) -> type[BaseModel]:
+        if self.return_intermediate_steps:
+            return create_model(
+                "MapReduceDocumentsOutput",
+                **{
+                    self.output_key: (str, None),
+                    "intermediate_steps": (List[str], None),
+                },  # type: ignore[call-overload]
+            )
+
+        return super().output_schema
+
     @property
     def output_keys(self) -> List[str]:
         """Expect input key.
@@ -9,7 +9,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.output_parsers.regex import RegexParser
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
 
 
 class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@@ -77,6 +77,18 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
+    @property
+    def output_schema(self) -> type[BaseModel]:
+        schema: Dict[str, Any] = {
+            self.output_key: (str, None),
+        }
+        if self.return_intermediate_steps:
+            schema["intermediate_steps"] = (List[str], None)
+        if self.metadata_keys:
+            schema.update({key: (Any, None) for key in self.metadata_keys})
+
+        return create_model("MapRerankOutput", **schema)
+
     @property
     def output_keys(self) -> List[str]:
         """Expect input key.
@@ -11,6 +11,7 @@ from typing import (
     List,
     Optional,
     Sequence,
+    Union,
     cast,
 )
 
@@ -37,9 +38,14 @@ from langchain.schema import (
 from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
 from langchain.schema.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     BaseMessageChunk,
+    ChatMessageChunk,
+    FunctionMessageChunk,
     HumanMessage,
+    HumanMessageChunk,
+    SystemMessageChunk,
 )
 from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
@@ -107,6 +113,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
 
     # --- Runnable methods ---
 
+    @property
+    def OutputType(self) -> Any:
+        """Get the output type for this runnable."""
+        return Union[
+            HumanMessageChunk,
+            AIMessageChunk,
+            ChatMessageChunk,
+            FunctionMessageChunk,
+            SystemMessageChunk,
+        ]
+
     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
             return input
@@ -38,6 +38,8 @@ from langchain.schema.messages import (
     BaseMessageChunk,
     ChatMessage,
     ChatMessageChunk,
+    FunctionMessage,
+    FunctionMessageChunk,
     HumanMessage,
     HumanMessageChunk,
     SystemMessage,
@@ -53,39 +55,6 @@ class ChatLiteLLMException(Exception):
     """Error with the `LiteLLM I/O` library"""
 
 
-def _truncate_at_stop_tokens(
-    text: str,
-    stop: Optional[List[str]],
-) -> str:
-    """Truncates text at the earliest stop token found."""
-    if stop is None:
-        return text
-
-    for stop_token in stop:
-        stop_token_idx = text.find(stop_token)
-        if stop_token_idx != -1:
-            text = text[:stop_token_idx]
-    return text
-
-
-class FunctionMessage(BaseMessage):
-    """Message for passing the result of executing a function back to a model."""
-
-    name: str
-    """The name of the function that was executed."""
-
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "function"
-
-
-class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
-    """Message Chunk for passing the result of executing a function back to a model."""
-
-    pass
-
-
 def _create_retry_decorator(
     llm: ChatLiteLLM,
     run_manager: Optional[
@@ -199,6 +199,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 
     # --- Runnable methods ---
 
+    @property
+    def OutputType(self) -> Type[str]:
+        """Get the output type for this runnable."""
+        return str
+
     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
             return input
@@ -28,6 +28,7 @@ from langchain.schema import (
 )
 from langchain.schema.messages import (
     AIMessage,
+    AnyMessage,
    BaseMessage,
     ChatMessage,
     HumanMessage,
@@ -280,7 +281,7 @@ class ChatPromptValue(PromptValue):
     A type of a prompt value that is built from messages.
     """
 
-    messages: List[BaseMessage]
+    messages: Sequence[BaseMessage]
     """List of messages."""
 
     def to_string(self) -> str:
@@ -289,7 +290,14 @@ class ChatPromptValue(PromptValue):
 
     def to_messages(self) -> List[BaseMessage]:
         """Return prompt as a list of messages."""
-        return self.messages
+        return list(self.messages)
+
+
+class ChatPromptValueConcrete(ChatPromptValue):
+    """Chat prompt value which explicitly lists out the message types it accepts.
+    For use in external schemas."""
+
+    messages: Sequence[AnyMessage]
 
 
 class BaseChatPromptTemplate(BasePromptTemplate, ABC):
@@ -13,8 +13,10 @@ from typing import (
     Union,
 )
 
+from typing_extensions import TypeAlias
+
 from langchain.load.serializable import Serializable
-from langchain.schema.messages import BaseMessage, get_buffer_string
+from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
 from langchain.schema.output import LLMResult
 from langchain.schema.prompt import PromptValue
 from langchain.schema.runnable import Runnable
@@ -70,6 +72,21 @@ class BaseLanguageModel(
     Each of these has an equivalent asynchronous method.
     """
 
+    @property
+    def InputType(self) -> TypeAlias:
+        """Get the input type for this runnable."""
+        from langchain.prompts.base import StringPromptValue
+        from langchain.prompts.chat import ChatPromptValueConcrete
+
+        # This is a version of LanguageModelInput which replaces the abstract
+        # base class BaseMessage with a union of its subclasses, which makes
+        # for a much better schema.
+        return Union[
+            str,
+            Union[StringPromptValue, ChatPromptValueConcrete],
+            List[AnyMessage],
+        ]
+
     @abstractmethod
     def generate_prompt(
         self,
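Why the union of subclasses matters (an illustrative sketch, not part of the diff): a Pydantic root model over a union of concrete message classes yields an `anyOf` schema with one entry per message type, whereas the abstract `BaseMessage` alone cannot enumerate any concrete shapes:

```python
from typing import Union
from langchain.pydantic_v1 import create_model
from langchain.schema.messages import AIMessage, HumanMessage

# A root model over concrete subclasses lists each shape in "definitions".
MessageInput = create_model(
    "MessageInput", __root__=(Union[AIMessage, HumanMessage], None)
)
print(sorted(MessageInput.schema()["definitions"].keys()))
# -> ['AIMessage', 'HumanMessage']
```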
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
-from abc import abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
+
+from typing_extensions import Literal
 
 from langchain.load.serializable import Serializable
-from langchain.pydantic_v1 import Field
+from langchain.pydantic_v1 import Extra, Field
 
 if TYPE_CHECKING:
     from langchain.prompts.chat import ChatPromptTemplate
@@ -69,10 +70,10 @@ class BaseMessage(Serializable):
     additional_kwargs: dict = Field(default_factory=dict)
     """Any additional information."""
 
-    @property
-    @abstractmethod
-    def type(self) -> str:
-        """Type of the Message, used for serialization."""
+    type: str
+
+    class Config:
+        extra = Extra.allow
 
     @classmethod
     def is_lc_serializable(cls) -> bool:
@@ -147,10 +148,10 @@ class HumanMessage(BaseMessage):
     conversation.
     """
 
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "human"
+    type: Literal["human"] = "human"
+
+
+HumanMessage.update_forward_refs()
 
 
 class HumanMessageChunk(HumanMessage, BaseMessageChunk):
@@ -167,10 +168,10 @@ class AIMessage(BaseMessage):
     conversation.
     """
 
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "ai"
+    type: Literal["ai"] = "ai"
+
+
+AIMessage.update_forward_refs()
 
 
 class AIMessageChunk(AIMessage, BaseMessageChunk):
@@ -199,10 +200,10 @@ class SystemMessage(BaseMessage):
     of input messages.
     """
 
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "system"
+    type: Literal["system"] = "system"
+
+
+SystemMessage.update_forward_refs()
 
 
 class SystemMessageChunk(SystemMessage, BaseMessageChunk):
@@ -217,10 +218,10 @@ class FunctionMessage(BaseMessage):
     name: str
     """The name of the function that was executed."""
 
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "function"
+    type: Literal["function"] = "function"
+
+
+FunctionMessage.update_forward_refs()
 
 
 class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
@@ -250,10 +251,10 @@ class ChatMessage(BaseMessage):
     role: str
     """The speaker / role of the Message."""
 
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "chat"
+    type: Literal["chat"] = "chat"
+
+
+ChatMessage.update_forward_refs()
 
 
 class ChatMessageChunk(ChatMessage, BaseMessageChunk):
@@ -277,6 +278,9 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
         return super().__add__(other)
 
 
+AnyMessage = Union[AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage]
+
+
 def _message_to_dict(message: BaseMessage) -> dict:
     return {"type": message.type, "data": message.dict()}
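The practical effect of turning `type` into a `Literal` field (a small sketch): the tag now appears in each message's Pydantic schema and serialized output, which is what lets `AnyMessage` act as a tagged union in external schemas:

```python
from langchain.schema.messages import HumanMessage

msg = HumanMessage(content="hi")
print(msg.type)            # "human" — a plain Literal field now, not an abstract property
print(msg.dict()["type"])  # "human" — included when the message is serialized
```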
@@ -14,8 +14,10 @@ from typing import (
     Union,
 )
 
+from typing_extensions import get_args
+
 from langchain.load.serializable import Serializable
-from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import AnyMessage, BaseMessage
 from langchain.schema.output import ChatGeneration, Generation
 from langchain.schema.prompt import PromptValue
 from langchain.schema.runnable import Runnable, RunnableConfig
@@ -58,6 +60,16 @@ class BaseGenerationOutputParser(
 ):
     """Base class to parse the output of an LLM call."""
 
+    @property
+    def InputType(self) -> Any:
+        return Union[str, AnyMessage]
+
+    @property
+    def OutputType(self) -> type[T]:
+        # even though mypy complains this isn't valid,
+        # it is good enough for pydantic to build the schema from
+        return T  # type: ignore[misc]
+
     def invoke(
         self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
     ) -> T:
@@ -129,6 +141,22 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
                 return "boolean_output_parser"
     """  # noqa: E501
 
+    @property
+    def InputType(self) -> Any:
+        return Union[str, AnyMessage]
+
+    @property
+    def OutputType(self) -> type[T]:
+        for cls in self.__class__.__orig_bases__:  # type: ignore[attr-defined]
+            type_args = get_args(cls)
+            if type_args and len(type_args) == 1:
+                return type_args[0]
+
+        raise TypeError(
+            f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
+            "Override the OutputType property to specify the output type."
+        )
+
     def invoke(
         self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
     ) -> T:
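How the `__orig_bases__` inference works (an illustrative standalone sketch, independent of langchain): when a subclass parameterizes a generic base, the concrete type argument is recoverable at runtime:

```python
from typing import Generic, List, TypeVar
from typing_extensions import get_args

T = TypeVar("T")


class Parser(Generic[T]):
    @property
    def OutputType(self) -> type:
        # Inspect the generic base (e.g. Parser[List[str]]) this class
        # was parameterized with and pull out its single type argument.
        for cls in self.__class__.__orig_bases__:
            type_args = get_args(cls)
            if type_args and len(type_args) == 1:
                return type_args[0]
        raise TypeError("no inferable OutputType")


class CommaSeparated(Parser[List[str]]):
    pass


print(CommaSeparated().OutputType)  # typing.List[str]
```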
@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
 import yaml
 
 from langchain.load.serializable import Serializable
-from langchain.pydantic_v1 import Field, root_validator
+from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
 from langchain.schema.document import Document
 from langchain.schema.output_parser import BaseOutputParser
 from langchain.schema.prompt import PromptValue
@@ -36,6 +36,20 @@ class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
 
         arbitrary_types_allowed = True
 
+    @property
+    def OutputType(self) -> Any:
+        from langchain.prompts.base import StringPromptValue
+        from langchain.prompts.chat import ChatPromptValueConcrete
+
+        return Union[StringPromptValue, ChatPromptValueConcrete]
+
+    @property
+    def input_schema(self) -> type[BaseModel]:
+        # This is correct, but pydantic typings/mypy don't think so.
+        return create_model(  # type: ignore[call-overload]
+            "PromptInput", **{k: (Any, None) for k in self.input_variables}
+        )
+
     def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
         return self._call_with_config(
             lambda inner_input: self.format_prompt(
@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
 from concurrent.futures import FIRST_COMPLETED, wait
 from functools import partial
 from itertools import tee
+from operator import itemgetter
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -27,6 +28,8 @@ from typing import (
     cast,
 )
 
+from typing_extensions import get_args
+
 if TYPE_CHECKING:
     from langchain.callbacks.manager import (
         AsyncCallbackManagerForChainRun,
@@ -37,7 +40,7 @@ if TYPE_CHECKING:
 
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
-from langchain.pydantic_v1 import Field
+from langchain.pydantic_v1 import BaseModel, Field, create_model
 from langchain.schema.runnable.config import (
     RunnableConfig,
     acall_func_with_variable_args,
@@ -55,6 +58,7 @@ from langchain.schema.runnable.utils import (
     accepts_config,
     accepts_run_manager,
     gather_with_concurrency,
+    get_function_first_arg_dict_keys,
 )
 from langchain.utils.aiter import atee, py_anext
 from langchain.utils.iter import safetee
@@ -66,6 +70,52 @@ class Runnable(Generic[Input, Output], ABC):
     """A Runnable is a unit of work that can be invoked, batched, streamed, or
     transformed."""
 
+    @property
+    def InputType(self) -> Type[Input]:
+        for cls in self.__class__.__orig_bases__:  # type: ignore[attr-defined]
+            type_args = get_args(cls)
+            if type_args and len(type_args) == 2:
+                return type_args[0]
+
+        raise TypeError(
+            f"Runnable {self.__class__.__name__} doesn't have an inferable InputType. "
+            "Override the InputType property to specify the input type."
+        )
+
+    @property
+    def OutputType(self) -> Type[Output]:
+        for cls in self.__class__.__orig_bases__:  # type: ignore[attr-defined]
+            type_args = get_args(cls)
+            if type_args and len(type_args) == 2:
+                return type_args[1]
+
+        raise TypeError(
+            f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
+            "Override the OutputType property to specify the output type."
+        )
+
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        root_type = self.InputType
+
+        if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
+            return root_type
+
+        return create_model(
+            self.__class__.__name__ + "Input", __root__=(root_type, None)
+        )
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        root_type = self.OutputType
+
+        if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
+            return root_type
+
+        return create_model(
+            self.__class__.__name__ + "Output", __root__=(root_type, None)
+        )
+
     def __or__(
         self,
         other: Union[
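A short illustration of the `__root__` fallback above (a sketch mirroring the FakeRunnable assertions in the tests later in this diff): a non-BaseModel type such as `str` gets wrapped in a root model whose JSON Schema is just that scalar type:

```python
from langchain.pydantic_v1 import create_model

# Equivalent of input_schema for a Runnable[str, int] named "FakeRunnable".
FakeRunnableInput = create_model("FakeRunnableInput", __root__=(str, None))
print(FakeRunnableInput.schema())
# -> {'title': 'FakeRunnableInput', 'type': 'string'}
```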
@@ -849,6 +899,20 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
         """The namespace of a RunnableBranch is the namespace of its default branch."""
         return cls.__module__.split(".")[:-1]
 
+    @property
+    def input_schema(self) -> type[BaseModel]:
+        runnables = (
+            [self.default]
+            + [r for _, r in self.branches]
+            + [r for r, _ in self.branches]
+        )
+
+        for runnable in runnables:
+            if runnable.input_schema.schema().get("type") is not None:
+                return runnable.input_schema
+
+        return super().input_schema
+
     def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
         """First evaluates the condition, then delegates to the true or false branch."""
         config = ensure_config(config)
@@ -953,6 +1017,22 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
     class Config:
         arbitrary_types_allowed = True
 
+    @property
+    def InputType(self) -> Type[Input]:
+        return self.runnable.InputType
+
+    @property
+    def OutputType(self) -> Type[Output]:
+        return self.runnable.OutputType
+
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return self.runnable.input_schema
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.runnable.output_schema
+
     @classmethod
     def is_lc_serializable(cls) -> bool:
         return True
@@ -1202,6 +1282,22 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
     class Config:
         arbitrary_types_allowed = True
 
+    @property
+    def InputType(self) -> Type[Input]:
+        return self.first.InputType
+
+    @property
+    def OutputType(self) -> Type[Output]:
+        return self.last.OutputType
+
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return self.first.input_schema
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.last.output_schema
+
     def __or__(
         self,
         other: Union[
@@ -1692,6 +1788,37 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
     class Config:
         arbitrary_types_allowed = True
 
+    @property
+    def InputType(self) -> Any:
+        for step in self.steps.values():
+            if step.InputType:
+                return step.InputType
+
+        return Any
+
+    @property
+    def input_schema(self) -> type[BaseModel]:
+        if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()):
+            # This is correct, but pydantic typings/mypy don't think so.
+            return create_model(  # type: ignore[call-overload]
+                "RunnableMapInput",
+                **{
+                    k: (v.type_, v.default)
+                    for step in self.steps.values()
+                    for k, v in step.input_schema.__fields__.items()
+                },
+            )
+
+        return super().input_schema
+
+    @property
+    def output_schema(self) -> type[BaseModel]:
+        # This is correct, but pydantic typings/mypy don't think so.
+        return create_model(  # type: ignore[call-overload]
+            "RunnableMapOutput",
+            **{k: (v.OutputType, None) for k, v in self.steps.items()},
+        )
+
     def invoke(
         self, input: Input, config: Optional[RunnableConfig] = None
     ) -> Dict[str, Any]:
@@ -1942,6 +2069,59 @@ class RunnableLambda(Runnable[Input, Output]):
                 f"Instead got an unsupported type: {type(func)}"
             )
 
+    @property
+    def InputType(self) -> Any:
+        func = getattr(self, "func", None) or getattr(self, "afunc")
+        try:
+            params = inspect.signature(func).parameters
+            first_param = next(iter(params.values()), None)
+            if first_param and first_param.annotation != inspect.Parameter.empty:
+                return first_param.annotation
+            else:
+                return Any
+        except ValueError:
+            return Any
+
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        func = getattr(self, "func", None) or getattr(self, "afunc")
+
+        if isinstance(func, itemgetter):
+            # This is terrible, but afaict it's not possible to access _items
+            # on itemgetter objects, so we have to parse the repr
+            items = str(func).replace("operator.itemgetter(", "")[:-1].split(", ")
+            if all(
+                item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
+            ):
+                # It's a dict, lol
+                return create_model(
+                    "RunnableLambdaInput",
+                    **{item[1:-1]: (Any, None) for item in items},  # type: ignore
+                )
+            else:
+                return create_model("RunnableLambdaInput", __root__=(List[Any], None))
+
+        if dict_keys := get_function_first_arg_dict_keys(func):
+            return create_model(
+                "RunnableLambdaInput",
+                **{key: (Any, None) for key in dict_keys},  # type: ignore
+            )
+
+        return super().input_schema
+
+    @property
+    def OutputType(self) -> Any:
+        func = getattr(self, "func", None) or getattr(self, "afunc")
+        try:
+            sig = inspect.signature(func)
+            return (
+                sig.return_annotation
+                if sig.return_annotation != inspect.Signature.empty
+                else Any
+            )
+        except ValueError:
+            return Any
+
     def __eq__(self, other: Any) -> bool:
         if isinstance(other, RunnableLambda):
             if hasattr(self, "func") and hasattr(other, "func"):
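For context on the itemgetter repr parsing above (a standalone sketch): CPython renders an `itemgetter` with its arguments, which is what the property strips and splits to recover the keys:

```python
from operator import itemgetter

getter = itemgetter("city", "language")
print(str(getter))  # "operator.itemgetter('city', 'language')"

# The same parsing input_schema performs: strip the wrapper, split the args.
items = str(getter).replace("operator.itemgetter(", "")[:-1].split(", ")
print([item[1:-1] for item in items])  # ['city', 'language']
```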
@@ -2068,6 +2248,34 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
     class Config:
         arbitrary_types_allowed = True
 
+    @property
+    def InputType(self) -> Any:
+        return List[self.bound.InputType]  # type: ignore[name-defined]
+
+    @property
+    def input_schema(self) -> type[BaseModel]:
+        return create_model(
+            "RunnableEachInput",
+            __root__=(
+                List[self.bound.input_schema],  # type: ignore[name-defined]
+                None,
+            ),
+        )
+
+    @property
+    def OutputType(self) -> type[List[Output]]:
+        return List[self.bound.OutputType]  # type: ignore[name-defined]
+
+    @property
+    def output_schema(self) -> type[BaseModel]:
+        return create_model(
+            "RunnableEachOutput",
+            __root__=(
+                List[self.bound.output_schema],  # type: ignore[name-defined]
+                None,
+            ),
+        )
+
     @classmethod
     def is_lc_serializable(cls) -> bool:
         return True
@@ -2124,6 +2332,22 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
     class Config:
         arbitrary_types_allowed = True
 
+    @property
+    def InputType(self) -> type[Input]:
+        return self.bound.InputType
+
+    @property
+    def OutputType(self) -> type[Output]:
+        return self.bound.OutputType
+
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return self.bound.input_schema
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.bound.output_schema
+
     @classmethod
     def is_lc_serializable(cls) -> bool:
         return True
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, AsyncIterator, Iterator, List, Optional
+from typing import Any, AsyncIterator, Iterator, List, Optional, Type
 
 from langchain.load.serializable import Serializable
 from langchain.schema.runnable.base import Input, Runnable
@@ -20,6 +20,8 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
     A runnable that passes through the input.
     """
 
+    input_type: Optional[Type[Input]] = None
+
     @classmethod
     def is_lc_serializable(cls) -> bool:
         return True
@@ -28,6 +30,14 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
     def get_lc_namespace(cls) -> List[str]:
         return cls.__module__.split(".")[:-1]
 
+    @property
+    def InputType(self) -> Any:
+        return self.input_type or Any
+
+    @property
+    def OutputType(self) -> Any:
+        return self.input_type or Any
+
     def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
         return self._call_with_config(identity, input, config)
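Usage of the new manual hint (a sketch mirroring `seq_w_map` in the tests below; without a hint the passthrough's schema falls back to an untyped root model):

```python
from langchain.schema.runnable import RunnablePassthrough

# Untyped: InputType is Any, so the root model carries no "type".
print(RunnablePassthrough().input_schema.schema())
# -> {'title': 'RunnablePassthroughInput'}

# With the hint the schema becomes concrete.
print(RunnablePassthrough(input_type=str).input_schema.schema())
# -> {'title': 'RunnablePassthroughInput', 'type': 'string'}
```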
@@ -4,16 +4,16 @@ from typing import (
     Any,
     AsyncIterator,
     Callable,
-    Generic,
     Iterator,
     List,
     Mapping,
     Optional,
-    TypedDict,
     Union,
     cast,
 )
 
+from typing_extensions import TypedDict
+
 from langchain.load.serializable import Serializable
 from langchain.schema.runnable.base import (
     Input,
@@ -43,21 +43,17 @@ class RouterInput(TypedDict):
     input: Any
 
 
-class RouterRunnable(
-    Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
-):
+class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
     """
     A runnable that routes to a set of runnables based on Input['key'].
     Returns the output of the selected runnable.
     """
 
-    runnables: Mapping[str, Runnable[Input, Output]]
+    runnables: Mapping[str, Runnable[Any, Output]]
 
     def __init__(
         self,
-        runnables: Mapping[
-            str, Union[Runnable[Input, Output], Callable[[Input], Output]]
-        ],
+        runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
     ) -> None:
         super().__init__(
             runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+import ast
 import asyncio
+import inspect
+import textwrap
 from inspect import signature
-from typing import Any, Callable, Coroutine, TypeVar, Union
+from typing import Any, Callable, Coroutine, List, Optional, Set, TypeVar, Union
 
 Input = TypeVar("Input")
 # Output type should implement __concat__, as eg str, list, dict do
@@ -35,3 +38,61 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
         return signature(callable).parameters.get("config") is not None
     except ValueError:
         return False
+
+
+class IsLocalDict(ast.NodeVisitor):
+    def __init__(self, name: str, keys: Set[str]) -> None:
+        self.name = name
+        self.keys = keys
+
+    def visit_Subscript(self, node: ast.Subscript) -> Any:
+        if (
+            isinstance(node.ctx, ast.Load)
+            and isinstance(node.value, ast.Name)
+            and node.value.id == self.name
+            and isinstance(node.slice, ast.Constant)
+            and isinstance(node.slice.value, str)
+        ):
+            # we've found a subscript access on the name we're looking for
+            self.keys.add(node.slice.value)
+
+    def visit_Call(self, node: ast.Call) -> Any:
+        if (
+            isinstance(node.func, ast.Attribute)
+            and isinstance(node.func.value, ast.Name)
+            and node.func.value.id == self.name
+            and node.func.attr == "get"
+            and len(node.args) in (1, 2)
+            and isinstance(node.args[0], ast.Constant)
+            and isinstance(node.args[0].value, str)
+        ):
+            # we've found a .get() call on the name we're looking for
+            self.keys.add(node.args[0].value)
+
+
+class IsFunctionArgDict(ast.NodeVisitor):
+    def __init__(self) -> None:
+        self.keys: Set[str] = set()
+
+    def visit_Lambda(self, node: ast.Lambda) -> Any:
+        input_arg_name = node.args.args[0].arg
+        IsLocalDict(input_arg_name, self.keys).visit(node.body)
+
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
+        input_arg_name = node.args.args[0].arg
+        IsLocalDict(input_arg_name, self.keys).visit(node)
+
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
+        input_arg_name = node.args.args[0].arg
+        IsLocalDict(input_arg_name, self.keys).visit(node)
+
+
+def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
+    try:
+        code = inspect.getsource(func)
+        tree = ast.parse(textwrap.dedent(code))
+        visitor = IsFunctionArgDict()
+        visitor.visit(tree)
+        return list(visitor.keys) if visitor.keys else None
+    except (TypeError, OSError):
+        return None
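What this buys in practice (a sketch using the helper exactly as defined above): the AST visitors spot string-keyed accesses on a function's first argument and report them, which `RunnableLambda.input_schema` turns into fields:

```python
from langchain.schema.runnable.utils import get_function_first_arg_dict_keys


def get_city(input):
    # input["city"] is a subscript access the IsLocalDict visitor detects.
    return input["city"]


print(get_function_first_arg_dict_keys(get_city))  # ['city']
```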
@@ -187,6 +187,14 @@ class ChildTool(BaseTool):
 
     # --- Runnable ---
 
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        """The tool's input schema."""
+        if self.args_schema is not None:
+            return self.args_schema
+        else:
+            return create_schema_from_function(self.name, self._run)
+
     def invoke(
         self,
         input: Union[str, Dict],
File diff suppressed because one or more lines are too long
@@ -1,3 +1,4 @@
+import sys
 from operator import itemgetter
 from typing import Any, Dict, List, Optional, Sequence, Union, cast
 from uuid import UUID
@@ -12,6 +13,8 @@ from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
 from langchain.callbacks.tracers.schemas import Run
 from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
+from langchain.chains.question_answering import load_qa_chain
+from langchain.chains.summarize import load_summarize_chain
 from langchain.chat_models.fake import FakeListChatModel
 from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
 from langchain.load.dump import dumpd, dumps
@@ -43,6 +46,7 @@ from langchain.schema.runnable import (
     RunnableSequence,
     RunnableWithFallbacks,
 )
+from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
 
 
 class FakeTracer(BaseTracer):
@@ -115,6 +119,412 @@ class FakeRetriever(BaseRetriever):
         return [Document(page_content="foo"), Document(page_content="bar")]
 
 
+def test_schemas(snapshot: SnapshotAssertion) -> None:
+    fake = FakeRunnable()  # str -> int
+
+    assert fake.input_schema.schema() == {
+        "title": "FakeRunnableInput",
+        "type": "string",
+    }
+    assert fake.output_schema.schema() == {
+        "title": "FakeRunnableOutput",
+        "type": "integer",
+    }
+
+    fake_bound = FakeRunnable().bind(a="b")  # str -> int
+
+    assert fake_bound.input_schema.schema() == {
+        "title": "FakeRunnableInput",
+        "type": "string",
+    }
+    assert fake_bound.output_schema.schema() == {
+        "title": "FakeRunnableOutput",
+        "type": "integer",
+    }
+
+    fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,))  # str -> int
+
+    assert fake_w_fallbacks.input_schema.schema() == {
+        "title": "FakeRunnableInput",
+        "type": "string",
+    }
+    assert fake_w_fallbacks.output_schema.schema() == {
+        "title": "FakeRunnableOutput",
+        "type": "integer",
+    }
+
+    def typed_lambda_impl(x: str) -> int:
+        return len(x)
+
+    typed_lambda = RunnableLambda(typed_lambda_impl)  # str -> int
+
+    assert typed_lambda.input_schema.schema() == {
+        "title": "RunnableLambdaInput",
+        "type": "string",
+    }
+    assert typed_lambda.output_schema.schema() == {
+        "title": "RunnableLambdaOutput",
+        "type": "integer",
+    }
+
+    async def typed_async_lambda_impl(x: str) -> int:
+        return len(x)
+
+    typed_async_lambda: Runnable = RunnableLambda(typed_async_lambda_impl)  # str -> int
+
+    assert typed_async_lambda.input_schema.schema() == {
+        "title": "RunnableLambdaInput",
+        "type": "string",
+    }
+    assert typed_async_lambda.output_schema.schema() == {
+        "title": "RunnableLambdaOutput",
+        "type": "integer",
+    }
+
+    fake_ret = FakeRetriever()  # str -> List[Document]
+
+    assert fake_ret.input_schema.schema() == {
+        "title": "FakeRetrieverInput",
+        "type": "string",
+    }
+    assert fake_ret.output_schema.schema() == {
+        "title": "FakeRetrieverOutput",
+        "type": "array",
+        "items": {"$ref": "#/definitions/Document"},
+        "definitions": {
+            "Document": {
+                "title": "Document",
+                "description": "Class for storing a piece of text and associated metadata.",  # noqa: E501
+                "type": "object",
+                "properties": {
+                    "page_content": {"title": "Page Content", "type": "string"},
+                    "metadata": {"title": "Metadata", "type": "object"},
+                },
+                "required": ["page_content"],
+            }
+        },
+    }
+
+    fake_llm = FakeListLLM(responses=["a"])  # str -> List[List[str]]
+
+    assert fake_llm.input_schema.schema() == snapshot
+    assert fake_llm.output_schema.schema() == {
+        "title": "FakeListLLMOutput",
+        "type": "string",
+    }
+
+    fake_chat = FakeListChatModel(responses=["a"])  # str -> List[List[str]]
+
+    assert fake_chat.input_schema.schema() == snapshot
+    assert fake_chat.output_schema.schema() == snapshot
+
+    prompt = PromptTemplate.from_template("Hello, {name}!")
+
+    assert prompt.input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {"name": {"title": "Name"}},
+    }
+    assert prompt.output_schema.schema() == snapshot
+
+    prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()
+
+    assert prompt_mapper.input_schema.schema() == {
+        "definitions": {
+            "PromptInput": {
+                "properties": {"name": {"title": "Name"}},
+                "title": "PromptInput",
+                "type": "object",
+            }
+        },
+        "items": {"$ref": "#/definitions/PromptInput"},
+        "type": "array",
+        "title": "RunnableEachInput",
+    }
+    assert prompt_mapper.output_schema.schema() == snapshot
+
+    list_parser = CommaSeparatedListOutputParser()
+
+    assert list_parser.input_schema.schema() == snapshot
+    assert list_parser.output_schema.schema() == {
+        "title": "CommaSeparatedListOutputParserOutput",
+        "type": "array",
+        "items": {"type": "string"},
+    }
+
+    seq = prompt | fake_llm | list_parser
+
+    assert seq.input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {"name": {"title": "Name"}},
+    }
+    assert seq.output_schema.schema() == {
+        "type": "array",
+        "items": {"type": "string"},
+        "title": "CommaSeparatedListOutputParserOutput",
+    }
+
+    router: Runnable = RouterRunnable({})
+
+    assert router.input_schema.schema() == {
+        "title": "RouterRunnableInput",
+        "$ref": "#/definitions/RouterInput",
+        "definitions": {
+            "RouterInput": {
+                "title": "RouterInput",
+                "type": "object",
+                "properties": {
+                    "key": {"title": "Key", "type": "string"},
+                    "input": {"title": "Input"},
+                },
+                "required": ["key", "input"],
+            }
+        },
+    }
+    assert router.output_schema.schema() == {"title": "RouterRunnableOutput"}
+
+    seq_w_map: Runnable = (
+        prompt
+        | fake_llm
+        | {
+            "original": RunnablePassthrough(input_type=str),
+            "as_list": list_parser,
+            "length": typed_lambda_impl,
+        }
+    )
+
+    assert seq_w_map.input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {"name": {"title": "Name"}},
+    }
+    assert seq_w_map.output_schema.schema() == {
+        "title": "RunnableMapOutput",
+        "type": "object",
+        "properties": {
+            "original": {"title": "Original", "type": "string"},
+            "length": {"title": "Length", "type": "integer"},
+            "as_list": {
+                "title": "As List",
+                "type": "array",
+                "items": {"type": "string"},
+            },
+        },
+    }
+
+    json_list_keys_tool = JsonListKeysTool(spec=JsonSpec(dict_={}))
+
+    assert json_list_keys_tool.input_schema.schema() == {
+        "title": "json_spec_list_keysSchema",
+        "type": "object",
+        "properties": {"tool_input": {"title": "Tool Input", "type": "string"}},
+        "required": ["tool_input"],
+    }
+    assert json_list_keys_tool.output_schema.schema() == {
+        "title": "JsonListKeysToolOutput"
+    }
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
+)
+def test_lambda_schemas() -> None:
+    first_lambda = lambda x: x["hello"]  # noqa: E731
+    assert RunnableLambda(first_lambda).input_schema.schema() == {
+        "title": "RunnableLambdaInput",
+        "type": "object",
+        "properties": {"hello": {"title": "Hello"}},
+    }
+
+    second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"])  # noqa: E731
+    assert RunnableLambda(
+        second_lambda,  # type: ignore[arg-type]
+    ).input_schema.schema() == {
+        "title": "RunnableLambdaInput",
+        "type": "object",
+        "properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
+    }
+
+    def get_value(input):  # type: ignore[no-untyped-def]
+        return input["variable_name"]
+
+    assert RunnableLambda(get_value).input_schema.schema() == {
+        "title": "RunnableLambdaInput",
+        "type": "object",
+        "properties": {"variable_name": {"title": "Variable Name"}},
+    }
+
+    async def aget_value(input):  # type: ignore[no-untyped-def]
+        return (input["variable_name"], input.get("another"))
+
+    assert RunnableLambda(aget_value).input_schema.schema() == {
+        "title": "RunnableLambdaInput",
+        "type": "object",
+        "properties": {
+            "another": {"title": "Another"},
+            "variable_name": {"title": "Variable Name"},
+        },
+    }
+
+    async def aget_values(input):  # type: ignore[no-untyped-def]
+        return {
+            "hello": input["variable_name"],
+            "bye": input["variable_name"],
+            "byebye": input["yo"],
+        }
+
+    assert RunnableLambda(aget_values).input_schema.schema() == {
+        "title": "RunnableLambdaInput",
+        "type": "object",
+        "properties": {
+            "variable_name": {"title": "Variable Name"},
+            "yo": {"title": "Yo"},
+        },
+    }
+
+
+def test_schema_complex_seq() -> None:
+    prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
+    prompt2 = ChatPromptTemplate.from_template(
+        "what country is the city {city} in? respond in {language}"
+    )
+
+    model = FakeListChatModel(responses=[""])
+
+    chain1 = prompt1 | model | StrOutputParser()
+
+    chain2: Runnable = (
+        {"city": chain1, "language": itemgetter("language")}
+        | prompt2
+        | model
+        | StrOutputParser()
+    )
+
+    assert chain2.input_schema.schema() == {
+        "title": "RunnableMapInput",
+        "type": "object",
+        "properties": {
+            "person": {"title": "Person"},
+            "language": {"title": "Language"},
+        },
+    }
+
+    assert chain2.output_schema.schema() == {
+        "title": "StrOutputParserOutput",
+        "type": "string",
+    }
+
+
+def test_schema_chains() -> None:
+    model = FakeListChatModel(responses=[""])
+
+    stuff_chain = load_summarize_chain(model)
+
+    assert stuff_chain.input_schema.schema() == {
+        "title": "CombineDocumentsInput",
+        "type": "object",
+        "properties": {
+            "input_documents": {
+                "title": "Input Documents",
+                "type": "array",
+                "items": {"$ref": "#/definitions/Document"},
+            }
+        },
+        "definitions": {
+            "Document": {
+                "title": "Document",
+                "description": "Class for storing a piece of text and associated metadata.",  # noqa: E501
+                "type": "object",
+                "properties": {
+                    "page_content": {"title": "Page Content", "type": "string"},
+                    "metadata": {"title": "Metadata", "type": "object"},
+                },
+                "required": ["page_content"],
+            }
+        },
+    }
+    assert stuff_chain.output_schema.schema() == {
+        "title": "CombineDocumentsOutput",
+        "type": "object",
+        "properties": {"output_text": {"title": "Output Text", "type": "string"}},
+    }
+
+    mapreduce_chain = load_summarize_chain(
+        model, "map_reduce", return_intermediate_steps=True
+    )
+
+    assert mapreduce_chain.input_schema.schema() == {
+        "title": "CombineDocumentsInput",
+        "type": "object",
+        "properties": {
+            "input_documents": {
+                "title": "Input Documents",
+                "type": "array",
+                "items": {"$ref": "#/definitions/Document"},
+            }
+        },
+        "definitions": {
+            "Document": {
+                "title": "Document",
+                "description": "Class for storing a piece of text and associated metadata.",  # noqa: E501
+                "type": "object",
+                "properties": {
+                    "page_content": {"title": "Page Content", "type": "string"},
+                    "metadata": {"title": "Metadata", "type": "object"},
+                },
+                "required": ["page_content"],
+            }
+        },
+    }
+    assert mapreduce_chain.output_schema.schema() == {
+        "title": "MapReduceDocumentsOutput",
+        "type": "object",
+        "properties": {
+            "output_text": {"title": "Output Text", "type": "string"},
+            "intermediate_steps": {
+                "title": "Intermediate Steps",
+                "type": "array",
+                "items": {"type": "string"},
+            },
+        },
+    }
+
+    maprerank_chain = load_qa_chain(model, "map_rerank", metadata_keys=["hello"])
+
+    assert maprerank_chain.input_schema.schema() == {
+        "title": "CombineDocumentsInput",
+        "type": "object",
+        "properties": {
+            "input_documents": {
+                "title": "Input Documents",
+                "type": "array",
+                "items": {"$ref": "#/definitions/Document"},
+            }
+        },
+        "definitions": {
+            "Document": {
+                "title": "Document",
+                "description": "Class for storing a piece of text and associated metadata.",  # noqa: E501
+                "type": "object",
+                "properties": {
+                    "page_content": {"title": "Page Content", "type": "string"},
+                    "metadata": {"title": "Metadata", "type": "object"},
+                },
+                "required": ["page_content"],
+            }
+        },
+    }
+    assert maprerank_chain.output_schema.schema() == {
+        "title": "MapRerankOutput",
+        "type": "object",
+        "properties": {
+            "output_text": {"title": "Output Text", "type": "string"},
+            "hello": {"title": "Hello"},
+        },
+    }
+
+
 @pytest.mark.asyncio
 async def test_with_config(mocker: MockerFixture) -> None:
     fake = FakeRunnable()
@@ -2160,6 +2570,7 @@ def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
         assert isinstance(body, Runnable)
 
     assert isinstance(runnable.default, Runnable)
+    assert runnable.input_schema.schema() == {"title": "RunnableBranchInput"}
 
 
 def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None: