From 970fe23feb217dec8237576e3b69f64508d715d9 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Tue, 28 Nov 2023 21:49:43 +0000
Subject: [PATCH] Fixes for opengpts release (#13960)

---
 .../language_models/chat_models.py            | 19 +++-
 libs/core/langchain_core/messages/__init__.py | 10 ++
 .../langchain_core/runnables/configurable.py  | 96 ++++++++++++-------
 .../langchain_core/runnables/passthrough.py   |  7 +-
 .../tests/unit_tests/messages/test_imports.py |  1 +
 libs/core/tests/unit_tests/test_messages.py   | 16 ++++
 .../langchain/langchain/chat_models/openai.py | 12 ++-
 7 files changed, 123 insertions(+), 38 deletions(-)

diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index abf7928d664..69bede4a12a 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -34,6 +34,7 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    message_chunk_to_message,
 )
 from langchain_core.outputs import (
     ChatGeneration,
@@ -63,7 +64,14 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
         else:
             generation += chunk
     assert generation is not None
-    return ChatResult(generations=[generation])
+    return ChatResult(
+        generations=[
+            ChatGeneration(
+                message=message_chunk_to_message(generation.message),
+                generation_info=generation.generation_info,
+            )
+        ]
+    )
 
 
 async def agenerate_from_stream(
@@ -76,7 +84,14 @@
         else:
             generation += chunk
     assert generation is not None
-    return ChatResult(generations=[generation])
+    return ChatResult(
+        generations=[
+            ChatGeneration(
+                message=message_chunk_to_message(generation.message),
+                generation_info=generation.generation_info,
+            )
+        ]
+    )
 
 
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
diff --git a/libs/core/langchain_core/messages/__init__.py b/libs/core/langchain_core/messages/__init__.py
index 7bcd9bfed5f..cc8bd4d21a9 100644
--- a/libs/core/langchain_core/messages/__init__.py
+++ b/libs/core/langchain_core/messages/__init__.py
@@ -98,6 +98,15 @@ def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
     return [_message_from_dict(m) for m in messages]
 
 
+def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
+    if not isinstance(chunk, BaseMessageChunk):
+        return chunk
+    # chunk classes always have the equivalent non-chunk class as their first parent
+    return chunk.__class__.__mro__[1](
+        **{k: v for k, v in chunk.__dict__.items() if k != "type"}
+    )
+
+
 __all__ = [
     "AIMessage",
     "AIMessageChunk",
@@ -115,6 +124,7 @@ __all__ = [
     "ToolMessage",
     "ToolMessageChunk",
     "get_buffer_string",
+    "message_chunk_to_message",
     "messages_from_dict",
     "messages_to_dict",
     "message_to_dict",
diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py
index 1c9756b7939..ef7f546bbc1 100644
--- a/libs/core/langchain_core/runnables/configurable.py
+++ b/libs/core/langchain_core/runnables/configurable.py
@@ -12,6 +12,7 @@ from typing import (
     List,
     Optional,
     Sequence,
+    Tuple,
     Type,
     Union,
     cast,
@@ -65,28 +66,32 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
     ) -> Type[BaseModel]:
-        return self._prepare(config).get_input_schema(config)
+        runnable, config = self._prepare(config)
+        return runnable.get_input_schema(config)
 
     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
     ) -> Type[BaseModel]:
-        return self._prepare(config).get_output_schema(config)
+        runnable, config = self._prepare(config)
+        return runnable.get_output_schema(config)
 
     @abstractmethod
     def _prepare(
         self, config: Optional[RunnableConfig] = None
-    ) -> Runnable[Input, Output]:
+    ) -> Tuple[Runnable[Input, Output], RunnableConfig]:
         ...
 
     def invoke(
         self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
     ) -> Output:
-        return self._prepare(config).invoke(input, config, **kwargs)
+        runnable, config = self._prepare(config)
+        return runnable.invoke(input, config, **kwargs)
 
     async def ainvoke(
         self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
     ) -> Output:
-        return await self._prepare(config).ainvoke(input, config, **kwargs)
+        runnable, config = self._prepare(config)
+        return await runnable.ainvoke(input, config, **kwargs)
 
     def batch(
         self,
@@ -99,21 +104,22 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
         configs = get_config_list(config, len(inputs))
         prepared = [self._prepare(c) for c in configs]
 
-        if all(p is self.default for p in prepared):
+        if all(p is self.default for p, _ in prepared):
             return self.default.batch(
-                inputs, config, return_exceptions=return_exceptions, **kwargs
+                inputs,
+                [c for _, c in prepared],
+                return_exceptions=return_exceptions,
+                **kwargs,
             )
 
         if not inputs:
             return []
 
-        configs = get_config_list(config, len(inputs))
-
         def invoke(
-            bound: Runnable[Input, Output],
+            prepared: Tuple[Runnable[Input, Output], RunnableConfig],
             input: Input,
-            config: RunnableConfig,
         ) -> Union[Output, Exception]:
+            bound, config = prepared
             if return_exceptions:
                 try:
                     return bound.invoke(input, config, **kwargs)
@@ -124,12 +130,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
 
         # If there's only one input, don't bother with the executor
         if len(inputs) == 1:
-            return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])])
+            return cast(List[Output], [invoke(prepared[0], inputs[0])])
 
         with get_executor_for_config(configs[0]) as executor:
-            return cast(
-                List[Output], list(executor.map(invoke, prepared, inputs, configs))
-            )
+            return cast(List[Output], list(executor.map(invoke, prepared, inputs)))
 
     async def abatch(
         self,
@@ -142,21 +146,22 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
         configs = get_config_list(config, len(inputs))
         prepared = [self._prepare(c) for c in configs]
 
-        if all(p is self.default for p in prepared):
+        if all(p is self.default for p, _ in prepared):
             return await self.default.abatch(
-                inputs, config, return_exceptions=return_exceptions, **kwargs
+                inputs,
+                [c for _, c in prepared],
+                return_exceptions=return_exceptions,
+                **kwargs,
             )
 
         if not inputs:
             return []
 
-        configs = get_config_list(config, len(inputs))
-
         async def ainvoke(
-            bound: Runnable[Input, Output],
+            prepared: Tuple[Runnable[Input, Output], RunnableConfig],
             input: Input,
-            config: RunnableConfig,
         ) -> Union[Output, Exception]:
+            bound, config = prepared
             if return_exceptions:
                 try:
                     return await bound.ainvoke(input, config, **kwargs)
@@ -165,7 +170,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
             else:
                 return await bound.ainvoke(input, config, **kwargs)
 
-        coros = map(ainvoke, prepared, inputs, configs)
+        coros = map(ainvoke, prepared, inputs)
         return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
 
     def stream(
@@ -174,7 +179,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> Iterator[Output]:
-        return self._prepare(config).stream(input, config, **kwargs)
+        runnable, config = self._prepare(config)
+        return runnable.stream(input, config, **kwargs)
 
     async def astream(
         self,
@@ -182,7 +188,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> AsyncIterator[Output]:
-        async for chunk in self._prepare(config).astream(input, config, **kwargs):
+        runnable, config = self._prepare(config)
+        async for chunk in runnable.astream(input, config, **kwargs):
             yield chunk
 
     def transform(
@@ -191,7 +198,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> Iterator[Output]:
-        return self._prepare(config).transform(input, config, **kwargs)
+        runnable, config = self._prepare(config)
+        return runnable.transform(input, config, **kwargs)
 
     async def atransform(
         self,
@@ -199,7 +207,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> AsyncIterator[Output]:
-        async for chunk in self._prepare(config).atransform(input, config, **kwargs):
+        runnable, config = self._prepare(config)
+        async for chunk in runnable.atransform(input, config, **kwargs):
             yield chunk
 
 
@@ -238,7 +247,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
 
     def _prepare(
         self, config: Optional[RunnableConfig] = None
-    ) -> Runnable[Input, Output]:
+    ) -> Tuple[Runnable[Input, Output], RunnableConfig]:
         config = config or {}
         specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
         configurable_fields = {
@@ -266,9 +275,12 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
         }
 
         if configurable:
-            return self.default.__class__(**{**self.default.__dict__, **configurable})
+            return (
+                self.default.__class__(**{**self.default.__dict__, **configurable}),
+                config,
+            )
         else:
-            return self.default
+            return (self.default, config)
 
 
 # Before Python 3.11 native StrEnum is not available
@@ -363,21 +375,39 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
 
     def _prepare(
         self, config: Optional[RunnableConfig] = None
-    ) -> Runnable[Input, Output]:
+    ) -> Tuple[Runnable[Input, Output], RunnableConfig]:
         config = config or {}
         which = config.get("configurable", {}).get(self.which.id, self.default_key)
+        # remap configurable keys for the chosen alternative
+        if self.prefix_keys:
+            config = cast(
+                RunnableConfig,
+                {
+                    **config,
+                    "configurable": {
+                        _strremoveprefix(k, f"{self.which.id}=={which}/"): v
+                        for k, v in config.get("configurable", {}).items()
+                    },
+                },
+            )
+        # return the chosen alternative
         if which == self.default_key:
-            return self.default
+            return (self.default, config)
         elif which in self.alternatives:
             alt = self.alternatives[which]
             if isinstance(alt, Runnable):
-                return alt
+                return (alt, config)
             else:
-                return alt()
+                return (alt(), config)
         else:
             raise ValueError(f"Unknown alternative: {which}")
 
 
+def _strremoveprefix(s: str, prefix: str) -> str:
+    """str.removeprefix() is only available in Python 3.9+."""
+    return s.replace(prefix, "", 1) if s.startswith(prefix) else s
+
+
 def prefix_config_spec(
     spec: ConfigurableFieldSpec, prefix: str
 ) -> ConfigurableFieldSpec:
diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py
index 4e04b37a0d3..f9eb3bdebc5 100644
--- a/libs/core/langchain_core/runnables/passthrough.py
+++ b/libs/core/langchain_core/runnables/passthrough.py
@@ -121,6 +121,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
         ]
     ] = None
 
+    def __repr_args__(self) -> Any:
+        # Without this repr(self) raises a RecursionError
+        # See https://github.com/pydantic/pydantic/issues/7327
+        return []
+
     def __init__(
         self,
         func: Optional[
@@ -175,7 +180,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
                 Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
             ],
         ],
-    ) -> RunnableAssign:
+    ) -> "RunnableAssign":
         """Merge the Dict input with the output produced by the mapping argument.
 
         Args:
diff --git a/libs/core/tests/unit_tests/messages/test_imports.py b/libs/core/tests/unit_tests/messages/test_imports.py
index 539c6bebfc7..dba0c840600 100644
--- a/libs/core/tests/unit_tests/messages/test_imports.py
+++ b/libs/core/tests/unit_tests/messages/test_imports.py
@@ -17,6 +17,7 @@ EXPECTED_ALL = [
     "ToolMessage",
    "ToolMessageChunk",
     "get_buffer_string",
+    "message_chunk_to_message",
     "messages_from_dict",
     "messages_to_dict",
     "message_to_dict",
diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py
index ef1cfa8fb14..cac2f5d2329 100644
--- a/libs/core/tests/unit_tests/test_messages.py
+++ b/libs/core/tests/unit_tests/test_messages.py
@@ -14,6 +14,7 @@ from langchain_core.messages import (
     SystemMessage,
     ToolMessage,
     get_buffer_string,
+    message_chunk_to_message,
     messages_from_dict,
     messages_to_dict,
 )
@@ -184,3 +185,18 @@ def test_multiple_msg() -> None:
         sys_msg,
     ]
     assert messages_from_dict(messages_to_dict(msgs)) == msgs
+
+
+def test_message_chunk_to_message() -> None:
+    assert message_chunk_to_message(
+        AIMessageChunk(content="I am", additional_kwargs={"foo": "bar"})
+    ) == AIMessage(content="I am", additional_kwargs={"foo": "bar"})
+    assert message_chunk_to_message(HumanMessageChunk(content="I am")) == HumanMessage(
+        content="I am"
+    )
+    assert message_chunk_to_message(
+        ChatMessageChunk(role="User", content="I am")
+    ) == ChatMessage(role="User", content="I am")
+    assert message_chunk_to_message(
+        FunctionMessageChunk(name="hello", content="I am")
+    ) == FunctionMessage(name="hello", content="I am")
diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py
index 46de2d70f91..74c24066cb8 100644
--- a/libs/langchain/langchain/chat_models/openai.py
+++ b/libs/langchain/langchain/chat_models/openai.py
@@ -418,7 +418,11 @@ class ChatOpenAI(BaseChatModel):
             )
             return generate_from_stream(stream_iter)
         message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs}
+        params = {
+            **params,
+            **({"stream": stream} if stream is not None else {}),
+            **kwargs,
+        }
         response = self.completion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         )
@@ -502,7 +506,11 @@ class ChatOpenAI(BaseChatModel):
             return await agenerate_from_stream(stream_iter)
 
         message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs}
+        params = {
+            **params,
+            **({"stream": stream} if stream is not None else {}),
+            **kwargs,
+        }
         response = await acompletion_with_retry(
             self, messages=message_dicts, run_manager=run_manager, **params
        )
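
Illustrative sketches of the behavior this patch introduces follow; they are
not part of the diff. First, generate_from_stream / agenerate_from_stream: the
ChatResult they build now wraps a ChatGeneration whose message is a concrete
message rather than a chunk. A minimal sketch, assuming only the hunks above:

    # Sketch: generate_from_stream now collapses the streamed chunks into a
    # ChatGeneration holding a concrete AIMessage, not an AIMessageChunk.
    from langchain_core.language_models.chat_models import generate_from_stream
    from langchain_core.messages import AIMessage, AIMessageChunk
    from langchain_core.outputs import ChatGenerationChunk

    chunks = iter(
        [
            ChatGenerationChunk(message=AIMessageChunk(content="Hi ")),
            ChatGenerationChunk(message=AIMessageChunk(content="there")),
        ]
    )
    result = generate_from_stream(chunks)
    assert type(result.generations[0].message) is AIMessage
    assert result.generations[0].message.content == "Hi there"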
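
The new message_chunk_to_message helper can also be called directly; this
sketch mirrors the unit tests added in test_messages.py:

    # Chunk classes subclass their non-chunk equivalents, so the helper
    # rebuilds the first non-chunk parent in the MRO from the chunk's fields.
    from langchain_core.messages import (
        AIMessage,
        AIMessageChunk,
        message_chunk_to_message,
    )

    aggregate = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world")
    message = message_chunk_to_message(aggregate)
    assert type(message) is AIMessage  # a concrete message, no longer a chunk
    assert message.content == "Hello, world"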
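
Finally, the key remapping that RunnableConfigurableAlternatives._prepare now
performs when prefix_keys is set, sketched standalone (the "llm"/"gpt4" ids
are made up for illustration):

    def _strremoveprefix(s: str, prefix: str) -> str:
        """str.removeprefix() is only available in Python 3.9+."""
        return s.replace(prefix, "", 1) if s.startswith(prefix) else s

    which_id, which = "llm", "gpt4"  # ConfigurableField id and chosen alternative
    configurable = {"llm==gpt4/temperature": 0.2, "other_key": 1}

    # Keys scoped to the chosen alternative lose their "<id>==<which>/" prefix,
    # so the selected runnable sees a plain "temperature"; other keys pass through.
    remapped = {
        _strremoveprefix(k, f"{which_id}=={which}/"): v
        for k, v in configurable.items()
    }
    assert remapped == {"temperature": 0.2, "other_key": 1}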