mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
core[patch]: Automatic upgrade to AddableDict in transform and atransform (#18743)
Automatic upgrade to transform and atransform Closes: https://github.com/langchain-ai/langchain/issues/18741 https://github.com/langchain-ai/langgraph/issues/136 https://github.com/langchain-ai/langserve/issues/504
This commit is contained in:
parent
fee6f983ef
commit
6caceb5473
@ -1050,12 +1050,19 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
for chunk in input:
|
for chunk in input:
|
||||||
if not got_first_val:
|
if not got_first_val:
|
||||||
final = chunk
|
final = _adapt_first_streaming_chunk(chunk) # type: ignore
|
||||||
got_first_val = True
|
got_first_val = True
|
||||||
else:
|
else:
|
||||||
# Make a best effort to gather, for any type that supports `+`
|
# Make a best effort to gather, for any type that supports `+`
|
||||||
# This method should throw an error if gathering fails.
|
# This method should throw an error if gathering fails.
|
||||||
final = final + chunk # type: ignore[operator]
|
try:
|
||||||
|
final = final + chunk # type: ignore[operator]
|
||||||
|
except TypeError:
|
||||||
|
raise TypeError(
|
||||||
|
f"Failed while trying to add together "
|
||||||
|
f"type {type(final)} and {type(chunk)}."
|
||||||
|
f"These types should be addable for transform to work."
|
||||||
|
)
|
||||||
|
|
||||||
if got_first_val:
|
if got_first_val:
|
||||||
yield from self.stream(final, config, **kwargs)
|
yield from self.stream(final, config, **kwargs)
|
||||||
@ -1076,12 +1083,19 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
async for chunk in input:
|
async for chunk in input:
|
||||||
if not got_first_val:
|
if not got_first_val:
|
||||||
final = chunk
|
final = _adapt_first_streaming_chunk(chunk) # type: ignore
|
||||||
got_first_val = True
|
got_first_val = True
|
||||||
else:
|
else:
|
||||||
# Make a best effort to gather, for any type that supports `+`
|
# Make a best effort to gather, for any type that supports `+`
|
||||||
# This method should throw an error if gathering fails.
|
# This method should throw an error if gathering fails.
|
||||||
final = final + chunk # type: ignore[operator]
|
try:
|
||||||
|
final = final + chunk # type: ignore[operator]
|
||||||
|
except TypeError:
|
||||||
|
raise TypeError(
|
||||||
|
f"Failed while trying to add together "
|
||||||
|
f"type {type(final)} and {type(chunk)}."
|
||||||
|
f"These types should be addable for atransform to work."
|
||||||
|
)
|
||||||
|
|
||||||
if got_first_val:
|
if got_first_val:
|
||||||
async for output in self.astream(final, config, **kwargs):
|
async for output in self.astream(final, config, **kwargs):
|
||||||
@ -3560,7 +3574,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
final: Optional[Input] = None
|
final: Optional[Input] = None
|
||||||
for ichunk in input:
|
for ichunk in input:
|
||||||
if final is None:
|
if final is None:
|
||||||
final = ichunk
|
final = _adapt_first_streaming_chunk(ichunk) # type: ignore
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
final = final + ichunk # type: ignore[operator]
|
final = final + ichunk # type: ignore[operator]
|
||||||
@ -3644,7 +3658,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
final: Optional[Input] = None
|
final: Optional[Input] = None
|
||||||
async for ichunk in input:
|
async for ichunk in input:
|
||||||
if final is None:
|
if final is None:
|
||||||
final = ichunk
|
final = _adapt_first_streaming_chunk(ichunk)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
final = final + ichunk # type: ignore[operator]
|
final = final + ichunk # type: ignore[operator]
|
||||||
@ -4445,3 +4459,11 @@ def chain(
|
|||||||
yield chunk
|
yield chunk
|
||||||
"""
|
"""
|
||||||
return RunnableLambda(func)
|
return RunnableLambda(func)
|
||||||
|
|
||||||
|
|
||||||
|
def _adapt_first_streaming_chunk(chunk: Any) -> Any:
|
||||||
|
"""This might transform the first chunk of a stream into an AddableDict."""
|
||||||
|
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
|
||||||
|
return AddableDict(chunk)
|
||||||
|
else:
|
||||||
|
return chunk
|
||||||
|
@ -165,7 +165,7 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
streaming.
|
streaming.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages: Iterator[AIMessage]
|
messages: Iterator[Union[AIMessage, str]]
|
||||||
"""Get an iterator over messages.
|
"""Get an iterator over messages.
|
||||||
|
|
||||||
This can be expanded to accept other types like Callables / dicts / strings
|
This can be expanded to accept other types like Callables / dicts / strings
|
||||||
@ -187,7 +187,11 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
"""Top Level call"""
|
"""Top Level call"""
|
||||||
message = next(self.messages)
|
message = next(self.messages)
|
||||||
generation = ChatGeneration(message=message)
|
if isinstance(message, str):
|
||||||
|
message_ = AIMessage(content=message)
|
||||||
|
else:
|
||||||
|
message_ = message
|
||||||
|
generation = ChatGeneration(message=message_)
|
||||||
return ChatResult(generations=[generation])
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
|
@ -70,6 +70,7 @@ from langchain_core.runnables import (
|
|||||||
chain,
|
chain,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.base import RunnableSerializable
|
from langchain_core.runnables.base import RunnableSerializable
|
||||||
|
from langchain_core.runnables.utils import Input, Output
|
||||||
from langchain_core.tools import BaseTool, tool
|
from langchain_core.tools import BaseTool, tool
|
||||||
from langchain_core.tracers import (
|
from langchain_core.tracers import (
|
||||||
BaseTracer,
|
BaseTracer,
|
||||||
@ -5183,3 +5184,70 @@ async def test_astream_log_deep_copies() -> None:
|
|||||||
"name": "add_one",
|
"name": "add_one",
|
||||||
"type": "chain",
|
"type": "chain",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_of_runnable_lambda_with_dicts() -> None:
|
||||||
|
"""Test transform of runnable lamdbda."""
|
||||||
|
runnable = RunnableLambda(lambda x: x)
|
||||||
|
chunks = iter(
|
||||||
|
[
|
||||||
|
{"foo": "a"},
|
||||||
|
{"foo": "n"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert list(runnable.transform(chunks)) == [{"foo": "an"}]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_atransform_of_runnable_lambda_with_dicts() -> None:
|
||||||
|
async def identity(x: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
"""Return x."""
|
||||||
|
return x
|
||||||
|
|
||||||
|
runnable = RunnableLambda[Dict[str, str], Dict[str, str]](identity)
|
||||||
|
|
||||||
|
async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
|
||||||
|
yield {"foo": "a"}
|
||||||
|
yield {"foo": "n"}
|
||||||
|
|
||||||
|
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
|
||||||
|
assert chunks == [{"foo": "an"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_transform_with_dicts() -> None:
|
||||||
|
"""Test that default transform works with dicts."""
|
||||||
|
|
||||||
|
class CustomRunnable(RunnableSerializable[Input, Output]):
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Output:
|
||||||
|
return cast(Output, input) # type: ignore
|
||||||
|
|
||||||
|
runnable = CustomRunnable[Dict[str, str], Dict[str, str]]()
|
||||||
|
chunks = iter(
|
||||||
|
[
|
||||||
|
{"foo": "a"},
|
||||||
|
{"foo": "n"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert list(runnable.transform(chunks)) == [{"foo": "an"}]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_defualt_atransform_with_dicts() -> None:
|
||||||
|
"""Test that default transform works with dicts."""
|
||||||
|
|
||||||
|
class CustomRunnable(RunnableSerializable[Input, Output]):
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Output:
|
||||||
|
return cast(Output, input)
|
||||||
|
|
||||||
|
runnable = CustomRunnable[Dict[str, str], Dict[str, str]]()
|
||||||
|
|
||||||
|
async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
|
||||||
|
yield {"foo": "a"}
|
||||||
|
yield {"foo": "n"}
|
||||||
|
|
||||||
|
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
|
||||||
|
|
||||||
|
assert chunks == [{"foo": "an"}]
|
||||||
|
Loading…
Reference in New Issue
Block a user