Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-19 11:08:55 +00:00)
nc/runnable-dynamic-schemas-from-config (#12038)
This commit is contained in: commit 85bac75729 (parent 85eaa4ccee)
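Summary (from the diff, since the PR description was left empty): this change threads an optional RunnableConfig through schema resolution. The `input_schema` / `output_schema` properties on `Runnable` become thin wrappers over new `get_input_schema(config)` / `get_output_schema(config)` methods, every override across chains, prompts, tools, and the runnable combinators is converted from property to method, and configurable runnables (`DynamicRunnable`) resolve themselves via `_prepare(config)` first, so the reported schema can depend on the config the runnable will be invoked with.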
@@ -61,15 +61,17 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     chains and cannot return as rich of an output as `__call__`.
     """
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "ChainInput", **{k: (Any, None) for k in self.input_keys}
         )
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "ChainOutput", **{k: (Any, None) for k in self.output_keys}
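The pattern above, building a pydantic model on the fly from a list of key names, recurs throughout this diff. A minimal standalone sketch of it (plain pydantic v1, which langchain re-exports as langchain.pydantic_v1; the key names here are hypothetical):

from typing import Any

from pydantic import create_model  # pydantic v1 API, as used by langchain at this commit

input_keys = ["question", "context"]  # hypothetical Chain.input_keys
ChainInput = create_model("ChainInput", **{k: (Any, None) for k in input_keys})

print(ChainInput.schema())
# {'title': 'ChainInput', 'type': 'object',
#  'properties': {'question': {'title': 'Question'}, 'context': {'title': 'Context'}}}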
@@ -10,6 +10,7 @@ from langchain.callbacks.manager import (
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
 from langchain.pydantic_v1 import BaseModel, Field, create_model
+from langchain.schema.runnable.config import RunnableConfig
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
 
@@ -28,15 +29,17 @@ class BaseCombineDocumentsChain(Chain, ABC):
     input_key: str = "input_documents"  #: :meta private:
     output_key: str = "output_text"  #: :meta private:
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return create_model(
             "CombineDocumentsInput",
             **{self.input_key: (List[Document], None)},  # type: ignore[call-overload]
         )
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return create_model(
             "CombineDocumentsOutput",
             **{self.output_key: (str, None)},  # type: ignore[call-overload]
@@ -167,16 +170,18 @@ class AnalyzeDocumentChain(Chain):
         """
         return self.combine_docs_chain.output_keys
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return create_model(
             "AnalyzeDocumentChain",
             **{self.input_key: (str, None)},  # type: ignore[call-overload]
         )
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.combine_docs_chain.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.combine_docs_chain.get_output_schema(config)
 
     def _call(
         self,
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -10,6 +10,7 @@ from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
+from langchain.schema.runnable.config import RunnableConfig
 
 
 class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@@ -98,8 +99,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     return_intermediate_steps: bool = False
     """Return the results of the map steps in the output."""
 
-    @property
-    def output_schema(self) -> type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         if self.return_intermediate_steps:
             return create_model(
                 "MapReduceDocumentsOutput",
@@ -109,7 +111,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
                 },  # type: ignore[call-overload]
             )
 
-        return super().output_schema
+        return super().get_output_schema(config)
 
     @property
     def output_keys(self) -> List[str]:
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -10,6 +10,7 @@ from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.output_parsers.regex import RegexParser
 from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
+from langchain.schema.runnable.config import RunnableConfig
 
 
 class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@@ -77,8 +78,9 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
-    @property
-    def output_schema(self) -> type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         schema: Dict[str, Any] = {
             self.output_key: (str, None),
         }
@@ -22,6 +22,7 @@ from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
 from langchain.schema import BasePromptTemplate, BaseRetriever, Document
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import BaseMessage
+from langchain.schema.runnable.config import RunnableConfig
 from langchain.schema.vectorstore import VectorStore
 
 # Depending on the memory type and configuration, the chat history format may differ.
@@ -95,8 +96,9 @@ class BaseConversationalRetrievalChain(Chain):
         """Input keys."""
         return ["question", "chat_history"]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return InputType
 
     @property
@@ -45,8 +45,9 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
 
         return Union[StringPromptValue, ChatPromptValueConcrete]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "PromptInput",
@@ -162,6 +162,12 @@ class Runnable(Generic[Input, Output], ABC):
 
     @property
     def input_schema(self) -> Type[BaseModel]:
+        """The type of input this runnable accepts specified as a pydantic model."""
+        return self.get_input_schema()
+
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         """The type of input this runnable accepts specified as a pydantic model."""
         root_type = self.InputType
 
@@ -174,6 +180,12 @@ class Runnable(Generic[Input, Output], ABC):
 
     @property
     def output_schema(self) -> Type[BaseModel]:
+        """The type of output this runnable produces specified as a pydantic model."""
+        return self.get_output_schema()
+
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         """The type of output this runnable produces specified as a pydantic model."""
         root_type = self.OutputType
 
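These two hunks are the backwards-compatibility core of the change: `input_schema` and `output_schema` remain as properties but now delegate to the new config-aware methods, so existing callers keep working while subclasses override only `get_input_schema` / `get_output_schema`. A self-contained mock-up of the pattern (not the real Runnable class; the config shape and key names are illustrative):

from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, create_model


class MiniRunnable:
    @property
    def input_schema(self) -> Type[BaseModel]:
        # Old property API, kept for compatibility: no config available here.
        return self.get_input_schema()

    def get_input_schema(
        self, config: Optional[Dict[str, Any]] = None
    ) -> Type[BaseModel]:
        # New extension point: subclasses may consult the config.
        return create_model("MiniInput", text=(str, None))


class ConfigAware(MiniRunnable):
    def get_input_schema(
        self, config: Optional[Dict[str, Any]] = None
    ) -> Type[BaseModel]:
        # A config-aware override can change the schema per call.
        keys = ((config or {}).get("configurable") or {}).get("keys", ["text"])
        return create_model("MiniInput", **{k: (str, None) for k in keys})


r = ConfigAware()
assert "text" in r.input_schema.schema()["properties"]  # old property still works
schema = r.get_input_schema({"configurable": {"keys": ["a", "b"]}})
assert sorted(schema.schema()["properties"]) == ["a", "b"]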
@@ -1044,13 +1056,15 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.last.OutputType
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        return self.first.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.first.get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.last.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.last.get_output_schema(config)
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
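For the combinators the rewrite is mechanical: keep the same delegation, but thread the config through. A sequence still answers with its first step's input schema and its last step's output schema; the only difference is that the question is now asked with a config attached. A toy stand-in (not the real classes):

from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, create_model


class Step:
    def __init__(self, name: str) -> None:
        self.name = name

    def get_input_schema(
        self, config: Optional[Dict[str, Any]] = None
    ) -> Type[BaseModel]:
        return create_model(f"{self.name}Input", text=(str, None))

    def get_output_schema(
        self, config: Optional[Dict[str, Any]] = None
    ) -> Type[BaseModel]:
        return create_model(f"{self.name}Output", text=(str, None))


class Seq:
    def __init__(self, first: Step, last: Step) -> None:
        self.first, self.last = first, last

    def get_input_schema(
        self, config: Optional[Dict[str, Any]] = None
    ) -> Type[BaseModel]:
        return self.first.get_input_schema(config)  # delegate, forwarding config

    def get_output_schema(
        self, config: Optional[Dict[str, Any]] = None
    ) -> Type[BaseModel]:
        return self.last.get_output_schema(config)  # delegate, forwarding config


seq = Seq(Step("Prompt"), Step("LLM"))
assert seq.get_input_schema().schema()["title"] == "PromptInput"
assert seq.get_output_schema().schema()["title"] == "LLMOutput"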
@@ -1551,10 +1565,11 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
 
         return Any
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         if all(
-            s.input_schema.schema().get("type", "object") == "object"
+            s.get_input_schema(config).schema().get("type", "object") == "object"
            for s in self.steps.values()
         ):
             # This is correct, but pydantic typings/mypy don't think so.
@@ -1563,15 +1578,16 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
                 **{
                     k: (v.annotation, v.default)
                     for step in self.steps.values()
-                    for k, v in step.input_schema.__fields__.items()
+                    for k, v in step.get_input_schema(config).__fields__.items()
                     if k != "__root__"
                 },
             )
 
-        return super().input_schema
+        return super().get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "RunnableParallelOutput",
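The interesting part here is the merge: when every step accepts an object, RunnableParallel advertises a single input model whose fields are the union of every step's (now config-resolved) fields. A plain-pydantic sketch of that union, with two hypothetical step schemas standing in for step.get_input_schema(config) (relies on pydantic v1.10's ModelField.annotation, which the diff itself uses):

from pydantic import BaseModel, create_model


class StepAInput(BaseModel):  # stand-in for one step's resolved input schema
    name: str


class StepBInput(BaseModel):  # stand-in for another step's resolved input schema
    lang: str


merged = create_model(
    "RunnableParallelInput",
    **{
        k: (v.annotation, v.default)
        for step in (StepAInput, StepBInput)
        for k, v in step.__fields__.items()
        if k != "__root__"
    },
)
assert sorted(merged.schema()["properties"]) == ["lang", "name"]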
@@ -2040,8 +2056,9 @@ class RunnableLambda(Runnable[Input, Output]):
         except ValueError:
             return Any
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         """The pydantic schema for the input to this runnable."""
         func = getattr(self, "func", None) or getattr(self, "afunc")
 
@@ -2066,7 +2083,7 @@ class RunnableLambda(Runnable[Input, Output]):
                 **{key: (Any, None) for key in dict_keys},  # type: ignore
             )
 
-        return super().input_schema
+        return super().get_input_schema(config)
 
     @property
     def OutputType(self) -> Any:
@@ -2215,12 +2232,13 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
     def InputType(self) -> Any:
         return List[self.bound.InputType]  # type: ignore[name-defined]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return create_model(
             "RunnableEachInput",
             __root__=(
-                List[self.bound.input_schema],  # type: ignore[name-defined]
+                List[self.bound.get_input_schema(config)],  # type: ignore
                 None,
             ),
         )
@@ -2229,12 +2247,14 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
     def OutputType(self) -> Type[List[Output]]:
         return List[self.bound.OutputType]  # type: ignore[name-defined]
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        schema = self.bound.get_output_schema(config)
         return create_model(
             "RunnableEachOutput",
             __root__=(
-                List[self.bound.output_schema],  # type: ignore[name-defined]
+                List[schema],  # type: ignore
                 None,
             ),
         )
@@ -2332,13 +2352,15 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.bound.OutputType
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        return self.bound.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.bound.get_input_schema(merge_configs(self.config, config))
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.bound.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.bound.get_output_schema(merge_configs(self.config, config))
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
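merge_configs is the subtle bit: a RunnableBinding is what with_config(...) returns, and it carries a stored config, so schema queries must see that stored config combined with whatever the caller passes at query time. A hedged sketch of the intended precedence (assuming merge_configs layers later configs over earlier ones, as its use above implies):

from langchain.schema.runnable.config import merge_configs

stored = {"configurable": {"prompt_template": "Hello {name}"}}  # set via with_config
passed = {"configurable": {"prompt_template": "Hello {name} in {lang}"}}  # per-call

merged = merge_configs(stored, passed)
assert merged["configurable"]["prompt_template"] == "Hello {name} in {lang}"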
@@ -130,8 +130,9 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
         """The namespace of a RunnableBranch is the namespace of its default branch."""
         return cls.__module__.split(".")[:-1]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         runnables = (
             [self.default]
             + [r for _, r in self.branches]
@@ -139,10 +140,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
         )
 
         for runnable in runnables:
-            if runnable.input_schema.schema().get("type") is not None:
-                return runnable.input_schema
+            if runnable.get_input_schema(config).schema().get("type") is not None:
+                return runnable.get_input_schema(config)
 
-        return super().input_schema
+        return super().get_input_schema(config)
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
@@ -60,13 +60,15 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.default.OutputType
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        return self.default.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self._prepare(config).get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.default.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self._prepare(config).get_output_schema(config)
 
     @abstractmethod
     def _prepare(
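This is the hunk that delivers the PR title: a configurable runnable now resolves itself via _prepare(config) before answering schema questions, so the reported schema tracks the configured values rather than the defaults. A hedged end-to-end sketch mirroring the test changes below (assumes the configurable_fields / ConfigurableField API present on this branch):

from langchain.prompts import PromptTemplate
from langchain.schema.runnable import ConfigurableField

prompt = PromptTemplate.from_template("Hello {name}").configurable_fields(
    template=ConfigurableField(id="prompt_template", name="Prompt template")
)

# Default template: only {name} is reported.
assert sorted(prompt.input_schema.schema()["properties"]) == ["name"]

# Configured template: the schema now reflects {lang} and {name}.
configured = prompt.with_config(
    configurable={"prompt_template": "Hello {name} in {lang}"}
)
assert sorted(configured.input_schema.schema()["properties"]) == ["lang", "name"]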
@@ -53,13 +53,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.runnable.OutputType
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        return self.runnable.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.runnable.get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.runnable.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.runnable.get_output_schema(config)
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
@@ -268,19 +268,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
     def get_lc_namespace(cls) -> List[str]:
         return cls.__module__.split(".")[:-1]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        map_input_schema = self.mapper.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        map_input_schema = self.mapper.get_input_schema(config)
         if not map_input_schema.__custom_root_type__:
             # ie. it's a dict
             return map_input_schema
 
-        return super().input_schema
+        return super().get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        map_input_schema = self.mapper.input_schema
-        map_output_schema = self.mapper.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        map_input_schema = self.mapper.get_input_schema(config)
+        map_output_schema = self.mapper.get_output_schema(config)
         if (
             not map_input_schema.__custom_root_type__
             and not map_output_schema.__custom_root_type__
@@ -295,7 +297,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
                 },
             )
 
-        return super().output_schema
+        return super().get_output_schema(config)
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
@@ -187,8 +187,9 @@ class ChildTool(BaseTool):
 
     # --- Runnable ---
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         """The tool's input schema."""
         if self.args_schema is not None:
             return self.args_schema
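For tools, the config is accepted for interface uniformity, but an explicit args_schema wins when one is set. A sketch with a hypothetical tool, using the customary BaseTool subclassing of this era of langchain:

from typing import Type

from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.base import BaseTool


class SearchInput(BaseModel):
    query: str = Field(description="what to look up")


class SearchTool(BaseTool):
    name = "search"
    description = "Look things up."
    args_schema: Type[BaseModel] = SearchInput

    def _run(self, query: str) -> str:
        return f"results for {query}"

    async def _arun(self, query: str) -> str:
        raise NotImplementedError("search does not support async")


tool = SearchTool()
assert tool.get_input_schema() is SearchInput  # args_schema takes precedence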
@@ -800,6 +800,17 @@ def test_configurable_fields() -> None:
         text="Hello, John! John!"
     )
 
+    assert prompt_configurable.with_config(
+        configurable={"prompt_template": "Hello {name} in {lang}"}
+    ).input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {
+            "lang": {"title": "Lang", "type": "string"},
+            "name": {"title": "Name", "type": "string"},
+        },
+    }
+
     chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()
 
     assert chain_configurable.invoke({"name": "John"}) == "a"
@@ -834,13 +845,27 @@ def test_configurable_fields() -> None:
     assert (
         chain_configurable.with_config(
             configurable={
-                "prompt_template": "A very good morning to you, {name}!",
+                "prompt_template": "A very good morning to you, {name} {lang}!",
                 "llm_responses": ["c"],
             }
-        ).invoke({"name": "John"})
+        ).invoke({"name": "John", "lang": "en"})
         == "c"
     )
 
+    assert chain_configurable.with_config(
+        configurable={
+            "prompt_template": "A very good morning to you, {name} {lang}!",
+            "llm_responses": ["c"],
+        }
+    ).input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {
+            "lang": {"title": "Lang", "type": "string"},
+            "name": {"title": "Name", "type": "string"},
+        },
+    }
+
     chain_with_map_configurable: Runnable = prompt_configurable | {
         "llm1": fake_llm_configurable | StrOutputParser(),
         "llm2": fake_llm_configurable | StrOutputParser(),