diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index 99ebfd769b0..c39abc8e76a 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -337,7 +337,7 @@ class Runnable(Generic[Input, Output], ABC):
 
         return create_model(
             self.get_name("Output"),
-            __root__=(root_type, None),
+            __root__=root_type,
         )
 
     @property
@@ -383,7 +383,7 @@ class Runnable(Generic[Input, Output], ABC):
             self.get_name("Config"),
             **({"configurable": (configurable, None)} if configurable else {}),
             **{
-                field_name: (field_type, None)
+                field_name: field_type
                 for field_name, field_type in RunnableConfig.__annotations__.items()
                 if field_name in [i for i in include if i != "configurable"]
             },
@@ -577,7 +577,7 @@ class Runnable(Generic[Input, Output], ABC):
         """
         from langchain_core.runnables.passthrough import RunnableAssign
 
-        return self | RunnableAssign(RunnableParallel(kwargs))
+        return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
 
     """ --- Public API --- """
 
@@ -2395,7 +2395,7 @@ def _seq_input_schema(
         return first.get_input_schema(config)
     elif isinstance(first, RunnableAssign):
         next_input_schema = _seq_input_schema(steps[1:], config)
-        if not next_input_schema.__custom_root_type__:
+        if not issubclass(next_input_schema, RootModel):
             # it's a dict as expected
             return create_model(  # type: ignore[call-overload]
                 "RunnableSequenceInput",
@@ -2422,7 +2422,7 @@ def _seq_output_schema(
     elif isinstance(last, RunnableAssign):
         mapper_output_schema = last.mapper.get_output_schema(config)
         prev_output_schema = _seq_output_schema(steps[:-1], config)
-        if not prev_output_schema.__custom_root_type__:
+        if not issubclass(prev_output_schema, RootModel):
             # it's a dict as expected
             return create_model(  # type: ignore[call-overload]
                 "RunnableSequenceOutput",
@@ -2439,7 +2439,7 @@ def _seq_output_schema(
         )
     elif isinstance(last, RunnablePick):
         prev_output_schema = _seq_output_schema(steps[:-1], config)
-        if not prev_output_schema.__custom_root_type__:
+        if not issubclass(prev_output_schema, RootModel):
             # it's a dict as expected
             if isinstance(last.keys, list):
                 return create_model(  # type: ignore[call-overload]
@@ -3407,11 +3407,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         Returns:
             The output schema of the Runnable.
         """
-        # This is correct, but pydantic typings/mypy don't think so.
-        return create_model(  # type: ignore[call-overload]
-            self.get_name("Output"),
-            **{k: (v.OutputType, None) for k, v in self.steps__.items()},
-        )
+        fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
+        return create_model(self.get_name("Output"), **fields)
 
     @property
     def config_specs(self) -> List[ConfigurableFieldSpec]:
@@ -4086,7 +4083,6 @@ class RunnableLambda(Runnable[Input, Output]):
             The input schema for this Runnable.
         """
         func = getattr(self, "func", None) or getattr(self, "afunc")
-
         if isinstance(func, itemgetter):
            # This is terrible, but afaict it's not possible to access _items
            # on itemgetter objects, so we have to parse the repr
@@ -4093,16 +4089,14 @@ class RunnableLambda(Runnable[Input, Output]):
             items = str(func).replace("operator.itemgetter(", "")[:-1].split(", ")
             if all(
                 item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
             ):
+                fields = {item[1:-1]: (Any, ...) for item in items}
                 # It's a dict, lol
-                return create_model(
-                    self.get_name("Input"),
-                    **{item[1:-1]: (Any, None) for item in items},  # type: ignore
-                )
+                return create_model(self.get_name("Input"), **fields)
             else:
                 return create_model(
                     self.get_name("Input"),
-                    __root__=(List[Any], None),
+                    __root__=List[Any],
                 )
 
         if self.InputType != Any:
@@ -4111,7 +4105,7 @@ class RunnableLambda(Runnable[Input, Output]):
         if dict_keys := get_function_first_arg_dict_keys(func):
             return create_model(
                 self.get_name("Input"),
-                **{key: (Any, None) for key in dict_keys},  # type: ignore
+                **{key: Any for key in dict_keys},  # type: ignore
             )
 
         return super().get_input_schema(config)
@@ -4664,13 +4658,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
         self, config: Optional[RunnableConfig] = None
     ) -> Type[BaseModel]:
         schema = self.bound.get_output_schema(config)
-        return create_model(
-            self.get_name("Output"),
-            __root__=(
-                List[schema],  # type: ignore
-                None,
-            ),
-        )
+        return create_model(self.get_name("Output"), __root__=List[schema])
 
     @property
     def config_specs(self) -> List[ConfigurableFieldSpec]:
diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py
index 446718377fe..db02df673b6 100644
--- a/libs/core/langchain_core/runnables/utils.py
+++ b/libs/core/langchain_core/runnables/utils.py
@@ -9,9 +9,6 @@ import textwrap
 from functools import lru_cache
 from inspect import signature
 from itertools import groupby
-from pydantic import BaseModel, ConfigDict, RootModel
-from pydantic import create_model as _create_model_base
-from pydantic.json_schema import DEFAULT_REF_TEMPLATE
 from typing import (
     Any,
     AsyncIterable,
@@ -32,6 +29,10 @@ from typing import (
     TypeVar,
     Union,
 )
+
+from pydantic import BaseModel, ConfigDict, RootModel
+from pydantic import create_model as _create_model_base
+from pydantic.json_schema import DEFAULT_REF_TEMPLATE
 from typing_extensions import TypeGuard
 
 from langchain_core.runnables.schema import StreamEvent
@@ -697,7 +698,7 @@ _SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
 NO_DEFAULT = object()
 
 
-def create_base_class(name: str, type_, default_=NO_DEFAULT) -> Type:
+def create_base_class(name: str, type_, default_=NO_DEFAULT) -> Type[BaseModel]:
     if default_ is NO_DEFAULT:
 
         class FixedNameRootModel(RootModel):
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py
index b288cffcd7c..9ba87fa5b37 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable.py
@@ -38,7 +38,6 @@ from langchain_core.language_models import (
 from langchain_core.load import dumpd, dumps
 from langchain_core.load.load import loads
 from langchain_core.messages import (
-    AIMessage,
     AIMessageChunk,
     HumanMessage,
     SystemMessage,
@@ -620,12 +619,30 @@ def test_with_types_with_type_generics() -> None:
     )
 
 
+def test_schema_with_itemgetter() -> None:
+    """Test runnable with itemgetter."""
+    foo = RunnableLambda(itemgetter("hello"))
+    assert foo.input_schema.schema() == {
+        "properties": {"hello": {"title": "Hello"}},
+        "required": ["hello"],
+        "title": "RunnableLambdaInput",
+        "type": "object",
+    }
+    prompt = ChatPromptTemplate.from_template("what is {language}?")
+    chain: Runnable = {"language": itemgetter("language")} | prompt
+    assert chain.input_schema.schema() == {
+        "properties": {"language": {"title": "Language"}},
+        "required": ["language"],
+        "title": "RunnableParallel<language>Input",
+        "type": "object",
+    }
+
+
 def test_schema_complex_seq() -> None:
     prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
     prompt2 = ChatPromptTemplate.from_template(
         "what country is the city {city} in? respond in {language}"
     )
-
     model = FakeListChatModel(responses=[""])
 
     chain1: Runnable = RunnableSequence(
@@ -648,8 +665,8 @@ def test_schema_complex_seq() -> None:
             "person": {"title": "Person", "type": "string"},
             "language": {"title": "Language"},
         },
+        "required": ["person", "language"],
     }
-
     assert chain2.output_schema.schema() == {
         "title": "StrOutputParserOutput",
         "type": "string",
@@ -3072,6 +3089,40 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
     assert len(map_run.child_runs) == 3
 
 
+def test_schemas_2() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+
+    chain: Runnable = prompt | {
+        "llm": llm,
+        "passthrough": RunnablePassthrough(),
+    }
+    chain_pick_one = chain.pick("llm")
+
+    assert chain_pick_one.output_schema.schema() == {
+        "title": "RunnableSequenceOutput",
+        "type": "string",
+    }
+
+
+def test_foo() -> None:
+    """Test create_model with a parametrized RootModel as a field type."""
+    from pydantic import RootModel, create_model
+
+    class Foo(RootModel):
+        pass
+
+    meow = Foo[str]
+    model = create_model("meow", **{"llm": (meow, ...)})
+    assert model(llm="hello").llm.root == "hello"
+
+
 def test_map_stream() -> None:
     prompt = (
         SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -3164,7 +3215,7 @@ def test_map_stream() -> None:
 
     assert streamed_chunks[0] in [
         {"llm": "i"},
-        {"chat": AIMessageChunk(content="i")},
+        {"chat": _AnyIdAIMessageChunk(content="i")},
     ]
 
     assert len(streamed_chunks) == len(llm_res) + len(chat_res)
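
A note on the `create_model` calls in this patch: pydantic v2 rejects `__root__` as a field name, so the bare `__root__=<type>` form, and likewise the bare-annotation form (`field_name: field_type`), only work through the project-level `create_model` wrapper in `runnables/utils.py` (hence the `_create_model_base` and `RootModel` imports it gains above, and the `issubclass(..., RootModel)` checks in `_seq_input_schema` / `_seq_output_schema`). A minimal standalone sketch of such a wrapper, assuming it routes `__root__` to a parametrized `RootModel` and normalizes bare annotations to required fields (illustration only, not the project's actual implementation):

    from typing import Any, Type

    from pydantic import BaseModel, RootModel
    from pydantic import create_model as _create_model_base


    def create_model(name: str, /, **field_definitions: Any) -> Type[BaseModel]:
        # pydantic v2 has no __root__ field, so route a bare root type to a
        # parametrized RootModel subclass instead.
        if "__root__" in field_definitions:
            model = RootModel[field_definitions.pop("__root__")]
            model.__name__ = name
            return model
        # Normalize bare annotations (field=SomeType) into required fields.
        fields = {
            k: v if isinstance(v, tuple) else (v, ...)
            for k, v in field_definitions.items()
        }
        return _create_model_base(name, **fields)


    # Root-model schemas stay detectable, which is what the sequence-schema
    # helpers rely on:
    assert issubclass(create_model("Wrapped", __root__=int), RootModel)
    assert not issubclass(create_model("Plain", x=int), RootModel)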
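
Relatedly, the shift from `(field_type, None)` defaults to bare types and `(v.OutputType, ...)` changes which fields the generated JSON schema marks as required; that is what the new `"required": ["hello"]` and `"required": ["person", "language"]` assertions in the tests exercise. A quick illustration in plain pydantic v2 (model and field names here are arbitrary):

    from typing import Any

    from pydantic import create_model

    # (Any, None) -> optional with a None default; (Any, ...) -> required.
    WithDefault = create_model("WithDefault", person=(Any, None))
    Required = create_model("Required", person=(Any, ...))

    assert "required" not in WithDefault.model_json_schema()
    assert Required.model_json_schema()["required"] == ["person"]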