wip transform

Bagatur 2024-08-29 11:38:18 -07:00
parent ad009f9c69
commit fc46822bb7
3 changed files with 95 additions and 94 deletions

View File

@@ -28,7 +28,7 @@ from typing import (
     Type,
     Union,
     cast,
-    overload,
+    overload, Iterator,
 )
 from langchain_core.messages.ai import AIMessage, AIMessageChunk
@@ -41,6 +41,7 @@ from langchain_core.messages.system import SystemMessage, SystemMessageChunk
 from langchain_core.messages.tool import ToolMessage, ToolMessageChunk
 from langchain_core.messages.tool import (
     tool_call as create_tool_call,
+    tool_call_chunk as create_tool_call_chunk
 )
 
 if TYPE_CHECKING:
@@ -334,84 +335,77 @@ def convert_to_messages(
     return [_convert_to_message(m, copy=copy) for m in messages]
 
 
-def _runnable_support(*args: Callable, supports_single: bool = False) -> Callable:
-    if supports_single:
-
-        def runnable_support(func: Callable) -> Callable:
-            @overload
-            def wrapped(
-                messages: Literal[None] = None, **kwargs: Any
-            ) -> Runnable[
-                Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
-                Union[BaseMessage, List[BaseMessage]],
-            ]: ...
-
-            @overload
-            def wrapped(
-                messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any
-            ) -> List[BaseMessage]: ...
-
-            @overload
-            def wrapped(
-                messages: MessageLikeRepresentation, **kwargs: Any
-            ) -> BaseMessage: ...
-
-            def wrapped(
-                messages: Union[
-                    MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None
-                ] = None,
-                **kwargs: Any,
-            ) -> Union[
-                BaseMessage,
-                List[BaseMessage],
-                Runnable[
-                    Union[
-                        MessageLikeRepresentation, Sequence[MessageLikeRepresentation]
-                    ],
-                    Union[BaseMessage, List[BaseMessage]],
-                ],
-            ]:
-                from langchain_core.runnables.base import RunnableLambda
-
-                if messages is not None:
-                    return func(messages, **kwargs)
-                else:
-                    return RunnableLambda(partial(func, **kwargs), name=func.__name__)
-
-            wrapped.__doc__ = func.__doc__
-            return wrapped
-
-    else:
-
-        def runnable_support(func: Callable) -> Callable:
-            @overload
-            def wrapped(
-                messages: Literal[None] = None, **kwargs: Any
-            ) -> Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]]: ...
-
-            @overload
-            def wrapped(
-                messages: Sequence[MessageLikeRepresentation], **kwargs: Any
-            ) -> List[BaseMessage]: ...
-
-            def wrapped(
-                messages: Union[Sequence[MessageLikeRepresentation], None] = None,
-                **kwargs: Any,
-            ) -> Union[
-                Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]],
-                List[BaseMessage],
-            ]:
-                from langchain_core.runnables.base import RunnableLambda
-
-                if messages is not None:
-                    return func(messages, **kwargs)
-                else:
-                    return RunnableLambda(partial(func, **kwargs), name=func.__name__)
-
-            wrapped.__doc__ = func.__doc__
-            return wrapped
-
-    return runnable_support(*args) if args else cast(Callable, runnable_support)
+def _runnable_generator(func: Callable) -> Callable:
+    @overload
+    def wrapped(
+        messages: Literal[None] = None, **kwargs: Any
+    ) -> Runnable[
+        Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
+        Union[BaseMessage, List[BaseMessage]],
+    ]: ...
+
+    @overload
+    def wrapped(
+        messages: Sequence[Union[BaseMessage, Dict, Tuple]], **kwargs: Any
+    ) -> List[BaseMessage]: ...
+
+    @overload
+    def wrapped(messages: MessageLikeRepresentation, **kwargs: Any) -> BaseMessage: ...
+
+    def wrapped(
+        messages: Union[
+            MessageLikeRepresentation, Sequence[MessageLikeRepresentation], None
+        ] = None,
+        **kwargs: Any,
+    ) -> Union[
+        BaseMessage,
+        List[BaseMessage],
+        Runnable[
+            Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]],
+            Union[BaseMessage, List[BaseMessage]],
+        ],
+    ]:
+        from langchain_core.runnables.base import RunnableGenerator
+
+        if messages is not None:
+            return func(messages, **kwargs)
+        else:
+
+            def transform(input_: Iterator, **kwargs: Any) -> Iterator:
+                for x in input_:
+                    yield func(x, **kwargs)
+
+            return RunnableGenerator(partial(transform, **kwargs), name=func.__name__)
+
+    wrapped.__doc__ = func.__doc__
+    return wrapped
+
+
+def _runnable_support(func: Callable) -> Callable:
+    @overload
+    def wrapped(
+        messages: Literal[None] = None, **kwargs: Any
+    ) -> Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]]: ...
+
+    @overload
+    def wrapped(
+        messages: Sequence[MessageLikeRepresentation], **kwargs: Any
+    ) -> List[BaseMessage]: ...
+
+    def wrapped(
+        messages: Union[Sequence[MessageLikeRepresentation], None] = None,
+        **kwargs: Any,
+    ) -> Union[
+        Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]],
+        List[BaseMessage],
+    ]:
+        from langchain_core.runnables.base import RunnableLambda
+
+        if messages is not None:
+            return func(messages, **kwargs)
+        else:
+            return RunnableLambda(partial(func, **kwargs), name=func.__name__)
+
+    wrapped.__doc__ = func.__doc__
+    return wrapped
 
 
 @_runnable_support
@@ -909,13 +903,13 @@ def trim_messages(
     )
 
 
-@_runnable_support(supports_single=True)
+@_runnable_generator
 def format_messages_as(
-    messages: Union[MessageLikeRepresentation, Iterable[MessageLikeRepresentation]],
+    messages: Union[MessageLikeRepresentation, Iterator[MessageLikeRepresentation]],
     *,
     format: Literal["openai", "anthropic"],
     text: Literal["string", "block"],
-) -> Union[BaseMessage, List[BaseMessage]]:
+) -> Union[BaseMessage, Iterator[BaseMessage]]:
     """Convert message contents into a standard format.
 
     .. versionadded:: 0.2.36
@@ -1123,7 +1117,6 @@ def _format_contents_as_openai(
                     )
                 # Bedrock converse
                 elif image := block.get("image"):
-                    raise ValueError("1064")
                     if missing := [
                         k for k in ("source", "format") if k not in image
                     ]:
@@ -1153,26 +1146,30 @@ def _format_contents_as_openai(
                         f"content block:\n\n{block}"
                     )
                 elif block.get("type") == "tool_use":
-                    if missing := [
-                        k for k in ("id", "name", "input") if k not in block
-                    ]:
-                        raise ValueError(
-                            f"Unrecognized content block at "
-                            f"messages[{i}].content[{j}] has 'type': 'tool_use', "
-                            f"but is missing expected key(s) "
-                            f"{missing}. Full content block:\n\n{block}"
-                        )
-                    if not any(
-                        tool_call["id"] == block["id"]
-                        for tool_call in cast(AIMessage, message).tool_calls
-                    ):
-                        cast(AIMessage, message).tool_calls.append(
-                            create_tool_call(
-                                name=block["name"],
-                                id=block["id"],
-                                args=block["input"],
-                            )
-                        )
+                    if not isinstance(message, BaseMessageChunk):
+                        if missing := [
+                            k for k in ("id", "name", "input") if k not in block
+                        ]:
+                            raise ValueError(
+                                f"Unrecognized content block at "
+                                f"messages[{i}].content[{j}] has 'type': 'tool_use', "
+                                f"but is missing expected key(s) "
+                                f"{missing}. Full content block:\n\n{block}"
+                            )
+                        if not any(
+                            tool_call["id"] == block["id"]
+                            for tool_call in cast(AIMessage, message).tool_calls
+                        ):
+                            cast(AIMessage, message).tool_calls.append(
+                                create_tool_call(
+                                    name=block["name"],
+                                    id=block["id"],
+                                    args=block["input"],
+                                )
+                            )
+                    else:
+                        if not message.tool_call_chunks:
+                            message.tool_call_chunks = [create_tool_call_chunk(id=block.get("id"), index=block.get("index"), args=block.get("partial_json"), name=block.get("name"))]
                 elif block.get("type") == "tool_result":
                     if missing := [
                         k for k in ("content", "tool_use_id") if k not in block
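For orientation, a minimal self-contained sketch (plain Python, no langchain imports; the names runnable_generator and shout are made up for illustration) of the call pattern the new _runnable_generator decorator gives format_messages_as: pass a value and it converts it immediately; pass only keyword arguments and it returns a lazy transform over a stream, which in the real code is wrapped in a RunnableGenerator.

from functools import partial
from typing import Any, Callable, Iterator, Optional


def runnable_generator(func: Callable) -> Callable:
    # Same shape as _runnable_generator above, minus the overloads and the
    # RunnableGenerator wrapper.
    def wrapped(value: Optional[Any] = None, **kwargs: Any) -> Any:
        if value is not None:
            return func(value, **kwargs)  # eager: one input, one output

        def transform(input_: Iterator, **kwargs: Any) -> Iterator:
            for x in input_:  # lazy: convert items as they stream in
                yield func(x, **kwargs)

        return partial(transform, **kwargs)  # real code returns RunnableGenerator(...)

    return wrapped


@runnable_generator
def shout(text: str, *, suffix: str = "!") -> str:
    return text.upper() + suffix


print(shout("hi"))                                # eager -> "HI!"
print(list(shout(suffix="?")(iter(["a", "b"]))))  # lazy  -> ["A?", "B?"]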

View File

@@ -3882,6 +3882,7 @@ class RunnableGenerator(Runnable[Input, Output]):
         atransform: Optional[
             Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
         ] = None,
+        name: Optional[str] = None,
     ) -> None:
         """Initialize a RunnableGenerator.
@@ -3909,7 +3910,10 @@ class RunnableGenerator(Runnable[Input, Output]):
             )
 
         try:
-            self.name = func_for_name.__name__
+            if name:
+                self.name = name
+            else:
+                self.name = func_for_name.__name__
         except AttributeError:
             pass
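A short usage sketch of the new constructor argument (assumes langchain_core with this change applied; the upper function and the "uppercase_stream" name are made-up examples): an explicit name overrides the one otherwise derived from the transform function's __name__.

from typing import Iterator

from langchain_core.runnables import RunnableGenerator


def upper(chunks: Iterator[str]) -> Iterator[str]:
    for chunk in chunks:
        yield chunk.upper()


runnable = RunnableGenerator(upper, name="uppercase_stream")
print(runnable.name)                 # "uppercase_stream" instead of "upper"
print(list(runnable.stream("abc")))  # a single input streams through as ["ABC"]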

View File

@@ -1210,12 +1210,12 @@ def _make_message_chunk_from_anthropic_event(
             content_block = event.delta.model_dump()
             content_block["index"] = event.index
             content_block["type"] = "tool_use"
-            tool_call_chunk = {
+            tool_call_chunk = create_tool_call_chunk(**{
                 "index": event.index,
                 "id": None,
                 "name": None,
                 "args": event.delta.partial_json,
-            }
+            })
             message_chunk = AIMessageChunk(
                 content=[content_block],
                 tool_call_chunks=[tool_call_chunk],  # type: ignore
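For reference, a small sketch of what the changed lines now build (import path taken from the first file in this diff; the literal args strings are made up): create_tool_call_chunk is the tool_call_chunk factory from langchain_core.messages.tool, and AIMessageChunk carries the resulting ToolCallChunk dicts in tool_call_chunks so partial tool calls can be merged as streamed chunks are added together.

from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk

first = AIMessageChunk(
    content=[],
    tool_call_chunks=[
        create_tool_call_chunk(index=0, id=None, name=None, args='{"city": "Par')
    ],
)
rest = AIMessageChunk(
    content=[],
    tool_call_chunks=[
        create_tool_call_chunk(index=0, id=None, name=None, args='is"}')
    ],
)
# Chunks with the same index should have their partial args concatenated on merge.
print((first + rest).tool_call_chunks)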