Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-13 13:36:15 +00:00
core[patch]: Automatic upgrade to AddableDict in transform and atransform (#18743)
Automatic upgrade of dict chunks to AddableDict in transform and atransform.

Closes:
https://github.com/langchain-ai/langchain/issues/18741
https://github.com/langchain-ai/langgraph/issues/136
https://github.com/langchain-ai/langserve/issues/504
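For context, here is a minimal sketch of the behavior this patch enables, mirroring the new tests added below (it assumes a langchain-core build that includes this change):

```python
from langchain_core.runnables import RunnableLambda

# A pass-through lambda; transform() must accumulate the streamed input
# chunks before invoking it, which requires summing the chunks together.
runnable = RunnableLambda(lambda x: x)

# Two partial dict chunks arrive on the input stream. Plain dicts do not
# support "+", so with this patch they are upgraded to AddableDict before
# being summed; keys are merged and string values are concatenated.
chunks = iter([{"foo": "a"}, {"foo": "n"}])

print(list(runnable.transform(chunks)))  # -> [{'foo': 'an'}]
```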
@@ -165,7 +165,7 @@ class GenericFakeChatModel(BaseChatModel):
       streaming.
     """

-    messages: Iterator[AIMessage]
+    messages: Iterator[Union[AIMessage, str]]
     """Get an iterator over messages.

     This can be expanded to accept other types like Callables / dicts / strings
@@ -187,7 +187,11 @@ class GenericFakeChatModel(BaseChatModel):
     ) -> ChatResult:
         """Top Level call"""
         message = next(self.messages)
-        generation = ChatGeneration(message=message)
+        if isinstance(message, str):
+            message_ = AIMessage(content=message)
+        else:
+            message_ = message
+        generation = ChatGeneration(message=message_)
         return ChatResult(generations=[generation])

     def _stream(
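The change above lets the fake chat model's `messages` iterator yield plain strings as well as `AIMessage` objects; strings are wrapped into an `AIMessage` before the `ChatGeneration` is built. A hedged usage sketch (the import path of `GenericFakeChatModel` is an assumption here and may differ between versions):

```python
# Assumption: GenericFakeChatModel is importable from the public package; in
# some versions it lives in the test utilities instead.
from langchain_core.language_models import GenericFakeChatModel

# Plain strings are now accepted; each one is wrapped into an AIMessage
# inside _generate before the ChatGeneration is constructed.
model = GenericFakeChatModel(messages=iter(["hello", "goodbye"]))

print(model.invoke("hi").content)        # -> "hello"
print(model.invoke("hi again").content)  # -> "goodbye"
```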
@@ -70,6 +70,7 @@ from langchain_core.runnables import (
     chain,
 )
+from langchain_core.runnables.base import RunnableSerializable
 from langchain_core.runnables.utils import Input, Output
 from langchain_core.tools import BaseTool, tool
 from langchain_core.tracers import (
     BaseTracer,
@@ -5183,3 +5184,70 @@ async def test_astream_log_deep_copies() -> None:
         "name": "add_one",
         "type": "chain",
     }
+
+
+def test_transform_of_runnable_lambda_with_dicts() -> None:
+    """Test transform of runnable lambda."""
+    runnable = RunnableLambda(lambda x: x)
+    chunks = iter(
+        [
+            {"foo": "a"},
+            {"foo": "n"},
+        ]
+    )
+    assert list(runnable.transform(chunks)) == [{"foo": "an"}]
+
+
+async def test_atransform_of_runnable_lambda_with_dicts() -> None:
+    async def identity(x: Dict[str, str]) -> Dict[str, str]:
+        """Return x."""
+        return x
+
+    runnable = RunnableLambda[Dict[str, str], Dict[str, str]](identity)
+
+    async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
+        yield {"foo": "a"}
+        yield {"foo": "n"}
+
+    chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
+    assert chunks == [{"foo": "an"}]
+
+
+def test_default_transform_with_dicts() -> None:
+    """Test that default transform works with dicts."""
+
+    class CustomRunnable(RunnableSerializable[Input, Output]):
+        def invoke(
+            self, input: Input, config: Optional[RunnableConfig] = None
+        ) -> Output:
+            return cast(Output, input)  # type: ignore
+
+    runnable = CustomRunnable[Dict[str, str], Dict[str, str]]()
+    chunks = iter(
+        [
+            {"foo": "a"},
+            {"foo": "n"},
+        ]
+    )
+
+    assert list(runnable.transform(chunks)) == [{"foo": "an"}]
+
+
+async def test_default_atransform_with_dicts() -> None:
+    """Test that default atransform works with dicts."""
+
+    class CustomRunnable(RunnableSerializable[Input, Output]):
+        def invoke(
+            self, input: Input, config: Optional[RunnableConfig] = None
+        ) -> Output:
+            return cast(Output, input)
+
+    runnable = CustomRunnable[Dict[str, str], Dict[str, str]]()
+
+    async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
+        yield {"foo": "a"}
+        yield {"foo": "n"}
+
+    chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
+
+    assert chunks == [{"foo": "an"}]
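The "upgrade to AddableDict" named in the title is what makes the accumulated `{"foo": "an"}` in the tests above possible: `AddableDict` defines `+` to merge two chunks key by key, adding the values so string fragments concatenate. A small illustration, assuming `AddableDict` is importable from `langchain_core.runnables.utils`:

```python
from langchain_core.runnables.utils import AddableDict

# Keys present in both chunks have their values added together, so the
# string fragments "a" and "n" concatenate to "an".
left = AddableDict({"foo": "a"})
right = AddableDict({"foo": "n"})

print(left + right)  # -> {'foo': 'an'}
```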