From fc46822bb764d08df42fee33550a1cfcea061589 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 29 Aug 2024 11:38:18 -0700 Subject: [PATCH] wip transform --- libs/core/langchain_core/messages/utils.py | 179 +++++++++--------- libs/core/langchain_core/runnables/base.py | 6 +- .../langchain_anthropic/chat_models.py | 4 +- 3 files changed, 95 insertions(+), 94 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 3ad3d31699a..c7fb622bd25 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -28,7 +28,7 @@ from typing import ( Type, Union, cast, - overload, + overload, Iterator, ) from langchain_core.messages.ai import AIMessage, AIMessageChunk @@ -41,6 +41,7 @@ from langchain_core.messages.system import SystemMessage, SystemMessageChunk from langchain_core.messages.tool import ToolMessage, ToolMessageChunk from langchain_core.messages.tool import ( tool_call as create_tool_call, + tool_call_chunk as create_tool_call_chunk ) if TYPE_CHECKING: @@ -334,84 +335,77 @@ def convert_to_messages( return [_convert_to_message(m, copy=copy) for m in messages] -def _runnable_support(*args: Callable, supports_single: bool = False) -> Callable: - if supports_single: +def _runnable_generator(func: Callable) -> Callable: + @overload + def wrapped( + messages: Literal[None] = None, **kwargs: Any + ) -> Runnable[ + Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], + Union[BaseMessage, List[BaseMessage]], + ]: ... - def runnable_support(func: Callable) -> Callable: - @overload - def wrapped( - messages: Literal[None] = None, **kwargs: Any - ) -> Runnable[ - Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], - Union[BaseMessage, List[BaseMessage]], - ]: ... + @overload + def wrapped( + messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any + ) -> List[BaseMessage]: ... - @overload - def wrapped( - messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any - ) -> List[BaseMessage]: ... + @overload + def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ... - @overload - def wrapped( - messages: MessageLikeRepresentation, **kwargs: Any - ) -> BaseMessage: ... + def wrapped( + messages: Union[ + MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None + ] = None, + **kwargs: Any, + ) -> Union[ + BaseMessage, + List[BaseMessage], + Runnable[ + Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], + Union[BaseMessage, List[BaseMessage]], + ], + ]: + from langchain_core.runnables.base import RunnableGenerator - def wrapped( - messages: Union[ - MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None - ] = None, - **kwargs: Any, - ) -> Union[ - BaseMessage, - List[BaseMessage], - Runnable[ - Union[ - MessageLikeRepresentation, Sequence[MessageLikeRepresentation] - ], - Union[BaseMessage, List[BaseMessage]], - ], - ]: - from langchain_core.runnables.base import RunnableLambda + if messages is not None: + return func(messages, **kwargs) + else: + def transform(input_: Iterator, **kwargs: Any) -> Iterator: + for x in input_: + yield func(x, **kwargs) + return RunnableGenerator(partial(transform, **kwargs), name=func.__name__) - if messages is not None: - return func(messages, **kwargs) - else: - return RunnableLambda(partial(func, **kwargs), name=func.__name__) + wrapped.__doc__ = func.__doc__ + return wrapped - wrapped.__doc__ = func.__doc__ - return wrapped - else: +def _runnable_support(func: Callable) -> Callable: + @overload + def wrapped( + messages: Literal[None] = None, **kwargs: Any + ) -> Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]]: ... - def runnable_support(func: Callable) -> Callable: - @overload - def wrapped( - messages: Literal[None] = None, **kwargs: Any - ) -> Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]]: ... + @overload + def wrapped( + messages: Sequence[MessageLikeRepresentation], **kwargs: Any + ) -> List[BaseMessage]: ... - @overload - def wrapped( - messages: Sequence[MessageLikeRepresentation], **kwargs: Any - ) -> List[BaseMessage]: ... + def wrapped( + messages: Union[Sequence[MessageLikeRepresentation], None] = None, + **kwargs: Any, + ) -> Union[ + Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]], + List[BaseMessage], + ]: + from langchain_core.runnables.base import RunnableLambda - def wrapped( - messages: Union[Sequence[MessageLikeRepresentation], None] = None, - **kwargs: Any, - ) -> Union[ - Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]], - List[BaseMessage], - ]: - from langchain_core.runnables.base import RunnableLambda + if messages is not None: + return func(messages, **kwargs) + else: + return RunnableLambda(partial(func, **kwargs), name=func.__name__) - if messages is not None: - return func(messages, **kwargs) - else: - return RunnableLambda(partial(func, **kwargs), name=func.__name__) - - wrapped.__doc__ = func.__doc__ - return wrapped - - return runnable_support(*args) if args else cast(Callable, runnable_support) + wrapped.__doc__ = func.__doc__ + return wrapped @_runnable_support @@ -909,13 +903,13 @@ def trim_messages( ) -@_runnable_support(supports_single=True) +@_runnable_generator def format_messages_as( - messages: Union[MessageLikeRepresentation, Iterable[MessageLikeRepresentation]], + messages: Union[MessageLikeRepresentation, Iterator[MessageLikeRepresentation]], *, format: Literal["openai", "anthropic"], text: Literal["string", "block"], -) -> Union[BaseMessage, List[BaseMessage]]: +) -> Union[BaseMessage, Iterator[BaseMessage]]: """Convert message contents into a standard format. .. versionadded:: 0.2.36 @@ -1123,7 +1117,6 @@ def _format_contents_as_openai( ) # Bedrock converse elif image := block.get("image"): - raise ValueError("1064") if missing := [ k for k in ("source", "format") if k not in image ]: @@ -1153,26 +1146,30 @@ def _format_contents_as_openai( f"content block:\n\n{block}" ) elif block.get("type") == "tool_use": - if missing := [ - k for k in ("id", "name", "input") if k not in block - ]: - raise ValueError( - f"Unrecognized content block at " - f"messages[{i}].content[{j}] has 'type': 'tool_use', " - f"but is missing expected key(s) " - f"{missing}. Full content block:\n\n{block}" - ) - if not any( - tool_call["id"] == block["id"] - for tool_call in cast(AIMessage, message).tool_calls - ): - cast(AIMessage, message).tool_calls.append( - create_tool_call( - name=block["name"], - id=block["id"], - args=block["input"], + if not isinstance(message, BaseMessageChunk): + if missing := [ + k for k in ("id", "name", "input") if k not in block + ]: + raise ValueError( + f"Unrecognized content block at " + f"messages[{i}].content[{j}] has 'type': 'tool_use', " + f"but is missing expected key(s) " + f"{missing}. Full content block:\n\n{block}" ) - ) + if not any( + tool_call["id"] == block["id"] + for tool_call in cast(AIMessage, message).tool_calls + ): + cast(AIMessage, message).tool_calls.append( + create_tool_call( + name=block["name"], + id=block["id"], + args=block["input"], + ) + ) + else: + if not message.tool_call_chunks: + message.tool_call_chunks = [create_tool_call_chunk(id=block.get("id"), index=block.get("index"), args=block.get("partial_json"), name=block.get("name"))] elif block.get("type") == "tool_result": if missing := [ k for k in ("content", "tool_use_id") if k not in block diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 89cc2822690..42cdda0911e 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -3882,6 +3882,7 @@ class RunnableGenerator(Runnable[Input, Output]): atransform: Optional[ Callable[[AsyncIterator[Input]], AsyncIterator[Output]] ] = None, + name: Optional[str] = None, ) -> None: """Initialize a RunnableGenerator. @@ -3909,7 +3910,10 @@ class RunnableGenerator(Runnable[Input, Output]): ) try: - self.name = func_for_name.__name__ + if name: + self.name = name + else: + self.name = func_for_name.__name__ except AttributeError: pass diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index ff931efb7c4..68bfca2eefa 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -1210,12 +1210,12 @@ def _make_message_chunk_from_anthropic_event( content_block = event.delta.model_dump() content_block["index"] = event.index content_block["type"] = "tool_use" - tool_call_chunk = { + tool_call_chunk = create_tool_call_chunk(**{ "index": event.index, "id": None, "name": None, "args": event.delta.partial_json, - } + }) message_chunk = AIMessageChunk( content=[content_block], tool_call_chunks=[tool_call_chunk], # type: ignore