STASHED CHANGES

This commit is contained in:
Eugene Yurtsev
2024-07-30 11:16:36 -04:00
parent 3a76e0f2ae
commit 0482a810bc
3 changed files with 71 additions and 33 deletions

View File

@@ -337,7 +337,7 @@ class Runnable(Generic[Input, Output], ABC):
return create_model(
self.get_name("Output"),
__root__=(root_type, None),
__root__=root_type,
)
@property
@@ -383,7 +383,7 @@ class Runnable(Generic[Input, Output], ABC):
self.get_name("Config"),
**({"configurable": (configurable, None)} if configurable else {}),
**{
field_name: (field_type, None)
field_name: field_type
for field_name, field_type in RunnableConfig.__annotations__.items()
if field_name in [i for i in include if i != "configurable"]
},
@@ -577,7 +577,7 @@ class Runnable(Generic[Input, Output], ABC):
"""
from langchain_core.runnables.passthrough import RunnableAssign
return self | RunnableAssign(RunnableParallel(kwargs))
return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
""" --- Public API --- """
@@ -2395,7 +2395,7 @@ def _seq_input_schema(
return first.get_input_schema(config)
elif isinstance(first, RunnableAssign):
next_input_schema = _seq_input_schema(steps[1:], config)
if not next_input_schema.__custom_root_type__:
if next_input_schema:
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceInput",
@@ -2422,7 +2422,7 @@ def _seq_output_schema(
elif isinstance(last, RunnableAssign):
mapper_output_schema = last.mapper.get_output_schema(config)
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not prev_output_schema.__custom_root_type__:
if prev_output_schema:
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
@@ -2439,7 +2439,7 @@ def _seq_output_schema(
)
elif isinstance(last, RunnablePick):
prev_output_schema = _seq_output_schema(steps[:-1], config)
if not prev_output_schema.__custom_root_type__:
if prev_output_schema:
# it's a dict as expected
if isinstance(last.keys, list):
return create_model( # type: ignore[call-overload]
@@ -3407,11 +3407,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
Returns:
The output schema of the Runnable.
"""
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
self.get_name("Output"),
**{k: (v.OutputType, None) for k, v in self.steps__.items()},
)
fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
return create_model(self.get_name("Output"), **fields)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
@@ -4086,7 +4083,6 @@ class RunnableLambda(Runnable[Input, Output]):
The input schema for this Runnable.
"""
func = getattr(self, "func", None) or getattr(self, "afunc")
if isinstance(func, itemgetter):
# This is terrible, but afaict it's not possible to access _items
# on itemgetter objects, so we have to parse the repr
@@ -4094,15 +4090,13 @@ class RunnableLambda(Runnable[Input, Output]):
if all(
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
):
fields = {item[1:-1]: (Any, ...) for item in items}
# It's a dict, lol
return create_model(
self.get_name("Input"),
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
)
return create_model(self.get_name("Input"), **fields)
else:
return create_model(
self.get_name("Input"),
__root__=(List[Any], None),
__root__=List[Any],
)
if self.InputType != Any:
@@ -4111,7 +4105,7 @@ class RunnableLambda(Runnable[Input, Output]):
if dict_keys := get_function_first_arg_dict_keys(func):
return create_model(
self.get_name("Input"),
**{key: (Any, None) for key in dict_keys}, # type: ignore
**{key: Any for key in dict_keys}, # type: ignore
)
return super().get_input_schema(config)
@@ -4664,13 +4658,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
schema = self.bound.get_output_schema(config)
return create_model(
self.get_name("Output"),
__root__=(
List[schema], # type: ignore
None,
),
)
return create_model(self.get_name("Output"), __root__=List[schema])
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:

View File

@@ -9,9 +9,6 @@ import textwrap
from functools import lru_cache
from inspect import signature
from itertools import groupby
from pydantic import BaseModel, ConfigDict, RootModel
from pydantic import create_model as _create_model_base
from pydantic.json_schema import DEFAULT_REF_TEMPLATE
from typing import (
Any,
AsyncIterable,
@@ -32,6 +29,10 @@ from typing import (
TypeVar,
Union,
)
from pydantic import BaseModel, ConfigDict, RootModel
from pydantic import create_model as _create_model_base
from pydantic.json_schema import DEFAULT_REF_TEMPLATE
from typing_extensions import TypeGuard
from langchain_core.runnables.schema import StreamEvent
@@ -697,7 +698,7 @@ _SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
NO_DEFAULT = object()
def create_base_class(name: str, type_, default_=NO_DEFAULT) -> Type:
def create_base_class(name: str, type_, default_=NO_DEFAULT) -> Type[BaseModel]:
if default_ is NO_DEFAULT:
class FixedNameRootModel(RootModel):

View File

@@ -38,7 +38,6 @@ from langchain_core.language_models import (
from langchain_core.load import dumpd, dumps
from langchain_core.load.load import loads
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
@@ -620,12 +619,25 @@ def test_with_types_with_type_generics() -> None:
)
def test_schema_with_itemgetter() -> None:
"""Test runnable with itemgetter."""
foo = RunnableLambda(itemgetter("hello"))
assert foo.input_schema.schema() == {
"properties": {"hello": {"title": "Hello"}},
"required": ["hello"],
"title": "RunnableLambdaInput",
"type": "object",
}
prompt = ChatPromptTemplate.from_template("what is {language}?")
chain: Runnable = {"language": itemgetter("language")} | prompt
assert chain.input_schema.schema() == {}
def test_schema_complex_seq() -> None:
prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
prompt2 = ChatPromptTemplate.from_template(
"what country is the city {city} in? respond in {language}"
)
model = FakeListChatModel(responses=[""])
chain1: Runnable = RunnableSequence(
@@ -648,8 +660,8 @@ def test_schema_complex_seq() -> None:
"person": {"title": "Person", "type": "string"},
"language": {"title": "Language"},
},
"required": ["person", "language"],
}
assert chain2.output_schema.schema() == {
"title": "StrOutputParserOutput",
"type": "string",
@@ -3072,6 +3084,43 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
assert len(map_run.child_runs) == 3
def test_schemas_2():
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat_res = "i'm a chatbot"
# sleep to better simulate a real stream
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
chain: Runnable = prompt | {
"llm": llm,
"passthrough": RunnablePassthrough(),
}
chain_pick_one = chain.pick("llm")
assert chain_pick_one.output_schema.schema() == {
"title": "RunnableSequenceOutput",
"type": "string",
}
def test_foo():
"""Test create model."""
from pydantic import RootModel, create_model
class Foo(RootModel):
pass
meow = Foo[str]
model = create_model("meow", **{"llm": (meow, ...)})
pass
def test_map_stream() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -3164,7 +3213,7 @@ def test_map_stream() -> None:
assert streamed_chunks[0] in [
{"llm": "i"},
{"chat": AIMessageChunk(content="i")},
{"chat": _AnyIdAIMessageChunk(content="i")},
]
assert len(streamed_chunks) == len(llm_res) + len(chat_res)