nc/runnable-dynamic-schemas-from-config (#12038)

Nuno Campos, 2023-10-19 19:34:35 +01:00 (committed by GitHub)
parent 85eaa4ccee
commit 85bac75729
13 changed files with 151 additions and 82 deletions
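
In short, this commit turns `Runnable.input_schema` / `Runnable.output_schema` into thin properties over new `get_input_schema(config)` / `get_output_schema(config)` methods, and every subclass that used to override the properties now overrides the methods, threading the config through. Runnables with configurable fields can therefore report the schema of the configured instance rather than of the default. A minimal sketch of the resulting behavior, modeled on the updated unit test at the bottom of this diff (illustrative, not part of the commit):

    from langchain.prompts import PromptTemplate
    from langchain.schema.runnable import ConfigurableField

    # A prompt whose template string is swappable at runtime via config.
    prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
        template=ConfigurableField(id="prompt_template")
    )

    # Without a config, the schema reflects the default template.
    assert set(prompt.get_input_schema().schema()["properties"]) == {"name"}

    # With a config, the schema reflects the configured template's variables.
    configured_schema = prompt.get_input_schema(
        {"configurable": {"prompt_template": "Hello {name} in {lang}"}}
    )
    assert set(configured_schema.schema()["properties"]) == {"name", "lang"}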


@@ -61,15 +61,17 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     chains and cannot return as rich of an output as `__call__`.
     """
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "ChainInput", **{k: (Any, None) for k in self.input_keys}
         )
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "ChainOutput", **{k: (Any, None) for k in self.output_keys}


@@ -10,6 +10,7 @@ from langchain.callbacks.manager import (
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
 from langchain.pydantic_v1 import BaseModel, Field, create_model
+from langchain.schema.runnable.config import RunnableConfig
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@@ -28,15 +29,17 @@ class BaseCombineDocumentsChain(Chain, ABC):
     input_key: str = "input_documents"  #: :meta private:
     output_key: str = "output_text"  #: :meta private:
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return create_model(
             "CombineDocumentsInput",
             **{self.input_key: (List[Document], None)},  # type: ignore[call-overload]
         )
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return create_model(
             "CombineDocumentsOutput",
             **{self.output_key: (str, None)},  # type: ignore[call-overload]
@@ -167,16 +170,18 @@ class AnalyzeDocumentChain(Chain):
         """
         return self.combine_docs_chain.output_keys
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return create_model(
             "AnalyzeDocumentChain",
             **{self.input_key: (str, None)},  # type: ignore[call-overload]
         )
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.combine_docs_chain.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.combine_docs_chain.get_output_schema(config)
 
     def _call(
         self,


@@ -2,7 +2,7 @@
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -10,6 +10,7 @@ from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
+from langchain.schema.runnable.config import RunnableConfig
 
 
 class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@@ -98,8 +99,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     return_intermediate_steps: bool = False
     """Return the results of the map steps in the output."""
 
-    @property
-    def output_schema(self) -> type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         if self.return_intermediate_steps:
             return create_model(
                 "MapReduceDocumentsOutput",
@@ -109,7 +111,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
                 },  # type: ignore[call-overload]
             )
 
-        return super().output_schema
+        return super().get_output_schema(config)
 
     @property
     def output_keys(self) -> List[str]:


@@ -2,7 +2,7 @@
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -10,6 +10,7 @@ from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.output_parsers.regex import RegexParser
 from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
+from langchain.schema.runnable.config import RunnableConfig
 
 
 class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@@ -77,8 +78,9 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
-    @property
-    def output_schema(self) -> type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         schema: Dict[str, Any] = {
             self.output_key: (str, None),
         }


@@ -22,6 +22,7 @@ from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
 from langchain.schema import BasePromptTemplate, BaseRetriever, Document
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import BaseMessage
+from langchain.schema.runnable.config import RunnableConfig
 from langchain.schema.vectorstore import VectorStore
 
 # Depending on the memory type and configuration, the chat history format may differ.
@@ -95,8 +96,9 @@ class BaseConversationalRetrievalChain(Chain):
         """Input keys."""
         return ["question", "chat_history"]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return InputType
 
     @property


@@ -45,8 +45,9 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
         return Union[StringPromptValue, ChatPromptValueConcrete]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "PromptInput",


@@ -162,6 +162,12 @@ class Runnable(Generic[Input, Output], ABC):
 
     @property
     def input_schema(self) -> Type[BaseModel]:
+        """The type of input this runnable accepts specified as a pydantic model."""
+        return self.get_input_schema()
+
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         """The type of input this runnable accepts specified as a pydantic model."""
         root_type = self.InputType
 
@@ -174,6 +180,12 @@ class Runnable(Generic[Input, Output], ABC):
 
     @property
     def output_schema(self) -> Type[BaseModel]:
+        """The type of output this runnable produces specified as a pydantic model."""
+        return self.get_output_schema()
+
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         """The type of output this runnable produces specified as a pydantic model."""
         root_type = self.OutputType
 
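Note the back-compat pattern in the two hunks above: the `input_schema` / `output_schema` properties are kept and now just delegate to the new methods with no config, so existing `runnable.input_schema` call sites keep working unchanged, while subclasses override the method instead of the property, as every hunk below does. A hypothetical custom runnable written against the new API might look like this (the class name and config keys are invented for illustration):

    from typing import Any, Optional, Type

    from langchain.pydantic_v1 import BaseModel, create_model
    from langchain.schema.runnable import Runnable, RunnableConfig


    class PassthroughDict(Runnable[dict, dict]):
        """Hypothetical runnable whose input keys depend on the config."""

        def invoke(self, input: dict, config: Optional[RunnableConfig] = None) -> dict:
            return input

        def get_input_schema(
            self, config: Optional[RunnableConfig] = None
        ) -> Type[BaseModel]:
            # Derive the advertised keys from the config, with a fallback default.
            keys = ((config or {}).get("configurable") or {}).get("keys", ["text"])
            return create_model(  # type: ignore[call-overload]
                "PassthroughDictInput", **{k: (Any, None) for k in keys}
            )

With this, `PassthroughDict().input_schema` still works through the inherited property, while `get_input_schema(config)` can react to the config.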
@@ -1044,13 +1056,15 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.last.OutputType
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        return self.first.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.first.get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.last.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.last.get_output_schema(config)
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
@@ -1551,10 +1565,11 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         return Any
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         if all(
-            s.input_schema.schema().get("type", "object") == "object"
+            s.get_input_schema(config).schema().get("type", "object") == "object"
             for s in self.steps.values()
         ):
             # This is correct, but pydantic typings/mypy don't think so.
@@ -1563,15 +1578,16 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
                 **{
                     k: (v.annotation, v.default)
                     for step in self.steps.values()
-                    for k, v in step.input_schema.__fields__.items()
+                    for k, v in step.get_input_schema(config).__fields__.items()
                     if k != "__root__"
                 },
             )
 
-        return super().input_schema
+        return super().get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "RunnableParallelOutput",
@@ -2040,8 +2056,9 @@ class RunnableLambda(Runnable[Input, Output]):
         except ValueError:
             return Any
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         """The pydantic schema for the input to this runnable."""
 
         func = getattr(self, "func", None) or getattr(self, "afunc")
@@ -2066,7 +2083,7 @@ class RunnableLambda(Runnable[Input, Output]):
                 **{key: (Any, None) for key in dict_keys},  # type: ignore
             )
 
-        return super().input_schema
+        return super().get_input_schema(config)
 
     @property
     def OutputType(self) -> Any:
@@ -2215,12 +2232,13 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
     def InputType(self) -> Any:
         return List[self.bound.InputType]  # type: ignore[name-defined]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         return create_model(
             "RunnableEachInput",
             __root__=(
-                List[self.bound.input_schema],  # type: ignore[name-defined]
+                List[self.bound.get_input_schema(config)],  # type: ignore
                 None,
             ),
         )
@@ -2229,12 +2247,14 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
     def OutputType(self) -> Type[List[Output]]:
         return List[self.bound.OutputType]  # type: ignore[name-defined]
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        schema = self.bound.get_output_schema(config)
         return create_model(
             "RunnableEachOutput",
             __root__=(
-                List[self.bound.output_schema],  # type: ignore[name-defined]
+                List[schema],  # type: ignore
                 None,
             ),
         )
@@ -2332,13 +2352,15 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.bound.OutputType
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        return self.bound.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.bound.get_input_schema(merge_configs(self.config, config))
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.bound.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.bound.get_output_schema(merge_configs(self.config, config))
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:


@@ -130,8 +130,9 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
         """The namespace of a RunnableBranch is the namespace of its default branch."""
         return cls.__module__.split(".")[:-1]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         runnables = (
             [self.default]
             + [r for _, r in self.branches]
@@ -139,10 +140,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
         )
 
         for runnable in runnables:
-            if runnable.input_schema.schema().get("type") is not None:
-                return runnable.input_schema
+            if runnable.get_input_schema(config).schema().get("type") is not None:
+                return runnable.get_input_schema(config)
 
-        return super().input_schema
+        return super().get_input_schema(config)
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:


@@ -60,13 +60,15 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.default.OutputType
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        return self.default.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self._prepare(config).get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.default.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self._prepare(config).get_output_schema(config)
 
     @abstractmethod
     def _prepare(
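This hunk is the heart of the PR: `DynamicRunnable` (the base class behind `configurable_fields` / `configurable_alternatives`) previously always reported its default runnable's schema; now it applies the incoming config via `_prepare` and asks the resulting configured instance. Combined with the `RunnableBinding` change above, which folds the binding's stored config into the caller's via `merge_configs`, this is what makes `with_config(...)` affect the reported schema. A hedged trace, reusing `prompt` from the sketch near the top:

    configured = prompt.with_config(  # wraps the DynamicRunnable in a RunnableBinding
        configurable={"prompt_template": "Hello {name} in {lang}"}
    )
    # configured.input_schema now resolves roughly as:
    #   Runnable.input_schema            -> self.get_input_schema()
    #   RunnableBinding.get_input_schema -> bound.get_input_schema(merge_configs(self.config, None))
    #   DynamicRunnable.get_input_schema -> self._prepare(config).get_input_schema(config)
    # i.e. the schema of the prompt as configured, with both "name" and "lang".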


@@ -53,13 +53,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.runnable.OutputType
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        return self.runnable.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.runnable.get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        return self.runnable.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.runnable.get_output_schema(config)
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:


@@ -268,19 +268,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
     def get_lc_namespace(cls) -> List[str]:
         return cls.__module__.split(".")[:-1]
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
-        map_input_schema = self.mapper.input_schema
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        map_input_schema = self.mapper.get_input_schema(config)
         if not map_input_schema.__custom_root_type__:
             # ie. it's a dict
             return map_input_schema
 
-        return super().input_schema
+        return super().get_input_schema(config)
 
-    @property
-    def output_schema(self) -> Type[BaseModel]:
-        map_input_schema = self.mapper.input_schema
-        map_output_schema = self.mapper.output_schema
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        map_input_schema = self.mapper.get_input_schema(config)
+        map_output_schema = self.mapper.get_output_schema(config)
         if (
             not map_input_schema.__custom_root_type__
             and not map_output_schema.__custom_root_type__
@@ -295,7 +297,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
             },
         )
 
-        return super().output_schema
+        return super().get_output_schema(config)
 
     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
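For context on the `__custom_root_type__` checks above (pre-existing, but central to this hunk): pydantic v1 sets that flag on models built around a `__root__` field, so here it distinguishes dict-shaped schemas, whose fields can be inspected and merged, from root-wrapped ones, where the code falls back to the generic default. A quick illustration of the flag itself:

    from langchain.pydantic_v1 import create_model

    DictInput = create_model("DictInput", question=(str, None))
    RootInput = create_model("RootInput", __root__=(str, None))

    assert not DictInput.__custom_root_type__  # object-shaped: fields are inspectable
    assert RootInput.__custom_root_type__      # root-wrapped: not a plain dict schema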


@@ -187,8 +187,9 @@ class ChildTool(BaseTool):
 
     # --- Runnable ---
 
-    @property
-    def input_schema(self) -> Type[BaseModel]:
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
         """The tool's input schema."""
         if self.args_schema is not None:
             return self.args_schema


@@ -800,6 +800,17 @@ def test_configurable_fields() -> None:
         text="Hello, John! John!"
     )
 
+    assert prompt_configurable.with_config(
+        configurable={"prompt_template": "Hello {name} in {lang}"}
+    ).input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {
+            "lang": {"title": "Lang", "type": "string"},
+            "name": {"title": "Name", "type": "string"},
+        },
+    }
+
     chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()
 
     assert chain_configurable.invoke({"name": "John"}) == "a"
@@ -834,13 +845,27 @@ def test_configurable_fields() -> None:
     assert (
         chain_configurable.with_config(
             configurable={
-                "prompt_template": "A very good morning to you, {name}!",
+                "prompt_template": "A very good morning to you, {name} {lang}!",
                 "llm_responses": ["c"],
             }
-        ).invoke({"name": "John"})
+        ).invoke({"name": "John", "lang": "en"})
         == "c"
     )
 
+    assert chain_configurable.with_config(
+        configurable={
+            "prompt_template": "A very good morning to you, {name} {lang}!",
+            "llm_responses": ["c"],
+        }
+    ).input_schema.schema() == {
+        "title": "PromptInput",
+        "type": "object",
+        "properties": {
+            "lang": {"title": "Lang", "type": "string"},
+            "name": {"title": "Name", "type": "string"},
+        },
+    }
+
     chain_with_map_configurable: Runnable = prompt_configurable | {
         "llm1": fake_llm_configurable | StrOutputParser(),
         "llm2": fake_llm_configurable | StrOutputParser(),