mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 02:33:19 +00:00
nc/runnable-dynamic-schemas-from-config
This commit is contained in:
parent
d392e030be
commit
a46eef64a7
@ -61,15 +61,17 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
chains and cannot return as rich of an output as `__call__`.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"ChainInput", **{k: (Any, None) for k in self.input_keys}
|
||||
)
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
|
||||
|
@ -10,6 +10,7 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
@ -28,15 +29,17 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
input_key: str = "input_documents" #: :meta private:
|
||||
output_key: str = "output_text" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return create_model(
|
||||
"CombineDocumentsInput",
|
||||
**{self.input_key: (List[Document], None)}, # type: ignore[call-overload]
|
||||
)
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return create_model(
|
||||
"CombineDocumentsOutput",
|
||||
**{self.output_key: (str, None)}, # type: ignore[call-overload]
|
||||
@ -167,16 +170,18 @@ class AnalyzeDocumentChain(Chain):
|
||||
"""
|
||||
return self.combine_docs_chain.output_keys
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return create_model(
|
||||
"AnalyzeDocumentChain",
|
||||
**{self.input_key: (str, None)}, # type: ignore[call-overload]
|
||||
)
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.combine_docs_chain.output_schema
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.combine_docs_chain.get_output_schema(config)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
@ -10,6 +10,7 @@ from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
|
||||
|
||||
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
@ -98,8 +99,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
return_intermediate_steps: bool = False
|
||||
"""Return the results of the map steps in the output."""
|
||||
|
||||
@property
|
||||
def output_schema(self) -> type[BaseModel]:
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
if self.return_intermediate_steps:
|
||||
return create_model(
|
||||
"MapReduceDocumentsOutput",
|
||||
@ -109,7 +111,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
}, # type: ignore[call-overload]
|
||||
)
|
||||
|
||||
return super().output_schema
|
||||
return super().get_output_schema(config)
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
@ -10,6 +10,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
|
||||
|
||||
class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
@ -77,8 +78,9 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def output_schema(self) -> type[BaseModel]:
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
schema: Dict[str, Any] = {
|
||||
self.output_key: (str, None),
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.schema.vectorstore import VectorStore
|
||||
|
||||
# Depending on the memory type and configuration, the chat history format may differ.
|
||||
@ -95,8 +96,9 @@ class BaseConversationalRetrievalChain(Chain):
|
||||
"""Input keys."""
|
||||
return ["question", "chat_history"]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return InputType
|
||||
|
||||
@property
|
||||
|
@ -45,8 +45,9 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
|
||||
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"PromptInput",
|
||||
|
@ -162,6 +162,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
"""The type of input this runnable accepts specified as a pydantic model."""
|
||||
return self.get_input_schema()
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
"""The type of input this runnable accepts specified as a pydantic model."""
|
||||
root_type = self.InputType
|
||||
|
||||
@ -174,6 +180,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
"""The type of output this runnable produces specified as a pydantic model."""
|
||||
return self.get_output_schema()
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
"""The type of output this runnable produces specified as a pydantic model."""
|
||||
root_type = self.OutputType
|
||||
|
||||
@ -1044,13 +1056,15 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
def OutputType(self) -> Type[Output]:
|
||||
return self.last.OutputType
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return self.first.input_schema
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.first.get_input_schema(config)
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.last.output_schema
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.last.get_output_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
@ -1551,10 +1565,11 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
|
||||
return Any
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
if all(
|
||||
s.input_schema.schema().get("type", "object") == "object"
|
||||
s.get_input_schema(config).schema().get("type", "object") == "object"
|
||||
for s in self.steps.values()
|
||||
):
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
@ -1563,15 +1578,16 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for step in self.steps.values()
|
||||
for k, v in step.input_schema.__fields__.items()
|
||||
for k, v in step.get_input_schema(config).__fields__.items()
|
||||
if k != "__root__"
|
||||
},
|
||||
)
|
||||
|
||||
return super().input_schema
|
||||
return super().get_input_schema(config)
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableParallelOutput",
|
||||
@ -2040,8 +2056,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
except ValueError:
|
||||
return Any
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
"""The pydantic schema for the input to this runnable."""
|
||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||
|
||||
@ -2066,7 +2083,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
**{key: (Any, None) for key in dict_keys}, # type: ignore
|
||||
)
|
||||
|
||||
return super().input_schema
|
||||
return super().get_input_schema(config)
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
@ -2215,12 +2232,13 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
def InputType(self) -> Any:
|
||||
return List[self.bound.InputType] # type: ignore[name-defined]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return create_model(
|
||||
"RunnableEachInput",
|
||||
__root__=(
|
||||
List[self.bound.input_schema], # type: ignore[name-defined]
|
||||
List[self.bound.get_input_schema(config)], # type: ignore
|
||||
None,
|
||||
),
|
||||
)
|
||||
@ -2229,12 +2247,14 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
def OutputType(self) -> Type[List[Output]]:
|
||||
return List[self.bound.OutputType] # type: ignore[name-defined]
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
schema = self.bound.get_output_schema(config)
|
||||
return create_model(
|
||||
"RunnableEachOutput",
|
||||
__root__=(
|
||||
List[self.bound.output_schema], # type: ignore[name-defined]
|
||||
List[schema], # type: ignore
|
||||
None,
|
||||
),
|
||||
)
|
||||
@ -2332,13 +2352,15 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
def OutputType(self) -> Type[Output]:
|
||||
return self.bound.OutputType
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return self.bound.input_schema
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.bound.get_input_schema(merge_configs(self.config, config))
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.bound.output_schema
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.bound.get_output_schema(merge_configs(self.config, config))
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
|
@ -130,8 +130,9 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
runnables = (
|
||||
[self.default]
|
||||
+ [r for _, r in self.branches]
|
||||
@ -139,10 +140,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
|
||||
for runnable in runnables:
|
||||
if runnable.input_schema.schema().get("type") is not None:
|
||||
return runnable.input_schema
|
||||
if runnable.get_input_schema(config).schema().get("type") is not None:
|
||||
return runnable.get_input_schema(config)
|
||||
|
||||
return super().input_schema
|
||||
return super().get_input_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
|
@ -60,13 +60,15 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
def OutputType(self) -> Type[Output]:
|
||||
return self.default.OutputType
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return self.default.input_schema
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self._prepare(config).get_input_schema(config)
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.default.output_schema
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self._prepare(config).get_output_schema(config)
|
||||
|
||||
@abstractmethod
|
||||
def _prepare(
|
||||
|
@ -53,13 +53,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
def OutputType(self) -> Type[Output]:
|
||||
return self.runnable.OutputType
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return self.runnable.input_schema
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.runnable.get_input_schema(config)
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.runnable.output_schema
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.runnable.get_output_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
|
@ -268,19 +268,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.input_schema
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
if not map_input_schema.__custom_root_type__:
|
||||
# ie. it's a dict
|
||||
return map_input_schema
|
||||
|
||||
return super().input_schema
|
||||
return super().get_input_schema(config)
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.input_schema
|
||||
map_output_schema = self.mapper.output_schema
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
map_input_schema = self.mapper.get_input_schema(config)
|
||||
map_output_schema = self.mapper.get_output_schema(config)
|
||||
if (
|
||||
not map_input_schema.__custom_root_type__
|
||||
and not map_output_schema.__custom_root_type__
|
||||
@ -295,7 +297,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
},
|
||||
)
|
||||
|
||||
return super().output_schema
|
||||
return super().get_output_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
|
@ -187,8 +187,9 @@ class ChildTool(BaseTool):
|
||||
|
||||
# --- Runnable ---
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
"""The tool's input schema."""
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema
|
||||
|
@ -800,6 +800,17 @@ def test_configurable_fields() -> None:
|
||||
text="Hello, John! John!"
|
||||
)
|
||||
|
||||
assert prompt_configurable.with_config(
|
||||
configurable={"prompt_template": "Hello {name} in {lang}"}
|
||||
).input_schema.schema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"lang": {"title": "Lang", "type": "string"},
|
||||
"name": {"title": "Name", "type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()
|
||||
|
||||
assert chain_configurable.invoke({"name": "John"}) == "a"
|
||||
@ -834,13 +845,27 @@ def test_configurable_fields() -> None:
|
||||
assert (
|
||||
chain_configurable.with_config(
|
||||
configurable={
|
||||
"prompt_template": "A very good morning to you, {name}!",
|
||||
"prompt_template": "A very good morning to you, {name} {lang}!",
|
||||
"llm_responses": ["c"],
|
||||
}
|
||||
).invoke({"name": "John"})
|
||||
).invoke({"name": "John", "lang": "en"})
|
||||
== "c"
|
||||
)
|
||||
|
||||
assert chain_configurable.with_config(
|
||||
configurable={
|
||||
"prompt_template": "A very good morning to you, {name} {lang}!",
|
||||
"llm_responses": ["c"],
|
||||
}
|
||||
).input_schema.schema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"lang": {"title": "Lang", "type": "string"},
|
||||
"name": {"title": "Name", "type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
chain_with_map_configurable: Runnable = prompt_configurable | {
|
||||
"llm1": fake_llm_configurable | StrOutputParser(),
|
||||
"llm2": fake_llm_configurable | StrOutputParser(),
|
||||
|
Loading…
Reference in New Issue
Block a user