Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-21 14:43:07 +00:00)

Commit: STASHED CHANGES
@@ -337,7 +337,7 @@ class Runnable(Generic[Input, Output], ABC):
 
         return create_model(
             self.get_name("Output"),
-            __root__=(root_type, None),
+            __root__=root_type,
         )
 
     @property
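Note: the `__root__=(root_type, None)` tuple is pydantic-v1 style; the stash switches to passing the bare type, which the repo's patched `create_model` can route to a v2 `RootModel`. A minimal sketch of the underlying v2 behavior, in plain pydantic with an illustrative class name (not LangChain code):

    from typing import List

    from pydantic import RootModel


    class RunnableOutput(RootModel[List[str]]):
        """Schema whose JSON form is a bare array, titled after the class."""


    print(RunnableOutput.model_json_schema())
    # {'items': {'type': 'string'}, 'title': 'RunnableOutput', 'type': 'array'}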
@@ -383,7 +383,7 @@ class Runnable(Generic[Input, Output], ABC):
             self.get_name("Config"),
             **({"configurable": (configurable, None)} if configurable else {}),
             **{
-                field_name: (field_type, None)
+                field_name: field_type
                 for field_name, field_type in RunnableConfig.__annotations__.items()
                 if field_name in [i for i in include if i != "configurable"]
             },
@@ -577,7 +577,7 @@ class Runnable(Generic[Input, Output], ABC):
         """
         from langchain_core.runnables.passthrough import RunnableAssign
 
-        return self | RunnableAssign(RunnableParallel(kwargs))
+        return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs))
 
     """ --- Public API --- """
 
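Note: subscripting `RunnableParallel` before calling it pins its type arguments for type checkers and for runtime introspection. A toy sketch of the mechanism with a stand-in class (illustrative only, not LangChain's):

    from typing import Any, Dict, Generic, TypeVar, get_args

    T = TypeVar("T")


    class Parallel(Generic[T]):
        def __init__(self, steps: T) -> None:
            self.steps = steps


    alias = Parallel[Dict[str, Any]]  # a typing alias; calling it instantiates Parallel
    p = alias({"llm": object()})
    print(get_args(alias))  # (typing.Dict[str, typing.Any],)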
@@ -2395,7 +2395,7 @@ def _seq_input_schema(
         return first.get_input_schema(config)
     elif isinstance(first, RunnableAssign):
         next_input_schema = _seq_input_schema(steps[1:], config)
-        if not next_input_schema.__custom_root_type__:
+        if next_input_schema:
             # it's a dict as expected
             return create_model(  # type: ignore[call-overload]
                 "RunnableSequenceInput",
@@ -2422,7 +2422,7 @@ def _seq_output_schema(
     elif isinstance(last, RunnableAssign):
         mapper_output_schema = last.mapper.get_output_schema(config)
         prev_output_schema = _seq_output_schema(steps[:-1], config)
-        if not prev_output_schema.__custom_root_type__:
+        if prev_output_schema:
             # it's a dict as expected
             return create_model(  # type: ignore[call-overload]
                 "RunnableSequenceOutput",
@@ -2439,7 +2439,7 @@ def _seq_output_schema(
             )
     elif isinstance(last, RunnablePick):
         prev_output_schema = _seq_output_schema(steps[:-1], config)
-        if not prev_output_schema.__custom_root_type__:
+        if prev_output_schema:
             # it's a dict as expected
             if isinstance(last.keys, list):
                 return create_model(  # type: ignore[call-overload]
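Note: all three hunks above drop `__custom_root_type__`, a pydantic-v1 attribute that v2 models no longer carry; the stash currently substitutes a bare truthiness test. One possible v2-style replacement (an assumption for illustration, not what the stash settles on) is an `issubclass` check against `RootModel`:

    from pydantic import BaseModel, RootModel


    class Dictish(BaseModel):
        x: int


    class Rootish(RootModel[list]):
        pass


    def has_object_schema(model: type) -> bool:
        # True when the model renders as a JSON object, not a bare root type.
        return issubclass(model, BaseModel) and not issubclass(model, RootModel)


    print(has_object_schema(Dictish))  # True
    print(has_object_schema(Rootish))  # False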
@@ -3407,11 +3407,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         Returns:
             The output schema of the Runnable.
         """
-        # This is correct, but pydantic typings/mypy don't think so.
-        return create_model(  # type: ignore[call-overload]
-            self.get_name("Output"),
-            **{k: (v.OutputType, None) for k, v in self.steps__.items()},
-        )
+        fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()}
+        return create_model(self.get_name("Output"), **fields)
 
     @property
     def config_specs(self) -> List[ConfigurableFieldSpec]:
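Note: switching the field tuples from `(v.OutputType, None)` to `(v.OutputType, ...)` changes the generated schema: `...` (Ellipsis) marks a field required, while a `None` default leaves it optional. A quick illustration with plain pydantic `create_model` (model names illustrative):

    from pydantic import create_model

    Optionalish = create_model("Out", person=(str, None))
    Required = create_model("Out", person=(str, ...))

    print(Optionalish.model_json_schema().get("required"))  # None
    print(Required.model_json_schema().get("required"))     # ['person']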
@@ -4086,7 +4083,6 @@ class RunnableLambda(Runnable[Input, Output]):
             The input schema for this Runnable.
         """
         func = getattr(self, "func", None) or getattr(self, "afunc")
-
         if isinstance(func, itemgetter):
             # This is terrible, but afaict it's not possible to access _items
             # on itemgetter objects, so we have to parse the repr
@@ -4094,15 +4090,13 @@ class RunnableLambda(Runnable[Input, Output]):
             if all(
                 item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
             ):
+                fields = {item[1:-1]: (Any, ...) for item in items}
                 # It's a dict, lol
-                return create_model(
-                    self.get_name("Input"),
-                    **{item[1:-1]: (Any, None) for item in items},  # type: ignore
-                )
+                return create_model(self.get_name("Input"), **fields)
             else:
                 return create_model(
                     self.get_name("Input"),
-                    __root__=(List[Any], None),
+                    __root__=List[Any],
                 )
 
         if self.InputType != Any:
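Note: the `items` filtered above come from parsing `repr(func)`, since `operator.itemgetter` exposes no public attribute holding its keys. A self-contained sketch of that trick, matching the checks in the diff (variable names illustrative):

    from operator import itemgetter

    func = itemgetter("hello", "bye")
    body = repr(func).replace("operator.itemgetter(", "").replace(")", "")
    items = body.split(", ")  # ["'hello'", "'bye'"]
    if all(item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items):
        keys = [item[1:-1] for item in items]  # strip the quotes
    print(keys)  # ['hello', 'bye']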
@@ -4111,7 +4105,7 @@ class RunnableLambda(Runnable[Input, Output]):
         if dict_keys := get_function_first_arg_dict_keys(func):
             return create_model(
                 self.get_name("Input"),
-                **{key: (Any, None) for key in dict_keys},  # type: ignore
+                **{key: Any for key in dict_keys},  # type: ignore
             )
 
         return super().get_input_schema(config)
@@ -4664,13 +4658,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
         self, config: Optional[RunnableConfig] = None
     ) -> Type[BaseModel]:
         schema = self.bound.get_output_schema(config)
-        return create_model(
-            self.get_name("Output"),
-            __root__=(
-                List[schema],  # type: ignore
-                None,
-            ),
-        )
+        return create_model(self.get_name("Output"), __root__=List[schema])
 
     @property
     def config_specs(self) -> List[ConfigurableFieldSpec]:
 
@@ -9,9 +9,6 @@ import textwrap
 from functools import lru_cache
 from inspect import signature
 from itertools import groupby
-from pydantic import BaseModel, ConfigDict, RootModel
-from pydantic import create_model as _create_model_base
-from pydantic.json_schema import DEFAULT_REF_TEMPLATE
 from typing import (
     Any,
     AsyncIterable,
@@ -32,6 +29,10 @@ from typing import (
     TypeVar,
     Union,
 )
+
+from pydantic import BaseModel, ConfigDict, RootModel
+from pydantic import create_model as _create_model_base
+from pydantic.json_schema import DEFAULT_REF_TEMPLATE
 from typing_extensions import TypeGuard
 
 from langchain_core.runnables.schema import StreamEvent
@@ -697,7 +698,7 @@ _SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True)
 NO_DEFAULT = object()
 
 
-def create_base_class(name: str, type_, default_=NO_DEFAULT) -> Type:
+def create_base_class(name: str, type_, default_=NO_DEFAULT) -> Type[BaseModel]:
     if default_ is NO_DEFAULT:
 
         class FixedNameRootModel(RootModel):
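Note: `create_base_class` appears to wrap a type in a `RootModel` subclass whose class name, and therefore schema title, is fixed to `name`. A hedged sketch of what such a helper could look like in plain pydantic v2; the stash's actual body is not shown here, so this is an assumption:

    from typing import Any, Type

    from pydantic import BaseModel, RootModel, create_model

    NO_DEFAULT = object()


    def create_base_class(name: str, type_: Any, default_: Any = NO_DEFAULT) -> Type[BaseModel]:
        # Assumption: build a RootModel[type_] subclass named `name`, optionally
        # giving the root field a default so the model can be built with no args.
        if default_ is NO_DEFAULT:
            return create_model(name, __base__=RootModel[type_])
        return create_model(name, __base__=RootModel[type_], root=(type_, default_))


    IntRoot = create_base_class("IntRoot", int, 0)
    print(IntRoot().root)                        # 0
    print(IntRoot.model_json_schema()["title"])  # IntRoot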
@@ -38,7 +38,6 @@ from langchain_core.language_models import (
 from langchain_core.load import dumpd, dumps
 from langchain_core.load.load import loads
 from langchain_core.messages import (
     AIMessage,
-    AIMessageChunk,
     HumanMessage,
     SystemMessage,
@@ -620,12 +619,25 @@ def test_with_types_with_type_generics() -> None:
     )
 
 
+def test_schema_with_itemgetter() -> None:
+    """Test runnable with itemgetter."""
+    foo = RunnableLambda(itemgetter("hello"))
+    assert foo.input_schema.schema() == {
+        "properties": {"hello": {"title": "Hello"}},
+        "required": ["hello"],
+        "title": "RunnableLambdaInput",
+        "type": "object",
+    }
+    prompt = ChatPromptTemplate.from_template("what is {language}?")
+    chain: Runnable = {"language": itemgetter("language")} | prompt
+    assert chain.input_schema.schema() == {}
+
+
 def test_schema_complex_seq() -> None:
     prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
     prompt2 = ChatPromptTemplate.from_template(
         "what country is the city {city} in? respond in {language}"
     )
 
     model = FakeListChatModel(responses=[""])
 
     chain1: Runnable = RunnableSequence(
@@ -648,8 +660,8 @@ def test_schema_complex_seq() -> None:
             "person": {"title": "Person", "type": "string"},
             "language": {"title": "Language"},
         },
+        "required": ["person", "language"],
     }
 
     assert chain2.output_schema.schema() == {
         "title": "StrOutputParserOutput",
         "type": "string",
@@ -3072,6 +3084,43 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
     assert len(map_run.child_runs) == 3
 
 
+def test_schemas_2():
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+
+    chat_res = "i'm a chatbot"
+    # sleep to better simulate a real stream
+
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+
+    chain: Runnable = prompt | {
+        "llm": llm,
+        "passthrough": RunnablePassthrough(),
+    }
+    chain_pick_one = chain.pick("llm")
+
+    assert chain_pick_one.output_schema.schema() == {
+        "title": "RunnableSequenceOutput",
+        "type": "string",
+    }
+
+
+def test_foo():
+    """Test create model."""
+    from pydantic import RootModel, create_model
+
+    class Foo(RootModel):
+        pass
+
+    meow = Foo[str]
+    model = create_model("meow", **{"llm": (meow, ...)})
+    pass
+
+
 def test_map_stream() -> None:
     prompt = (
         SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -3164,7 +3213,7 @@ def test_map_stream() -> None:
 
     assert streamed_chunks[0] in [
         {"llm": "i"},
-        {"chat": AIMessageChunk(content="i")},
+        {"chat": _AnyIdAIMessageChunk(content="i")},
     ]
     assert len(streamed_chunks) == len(llm_res) + len(chat_res)