Fixes for opengpts release (#13960)

This commit is contained in:
Nuno Campos 2023-11-28 21:49:43 +00:00 committed by GitHub
parent 947daaf833
commit 970fe23feb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 123 additions and 38 deletions

View File

@ -34,6 +34,7 @@ from langchain_core.messages import (
BaseMessage, BaseMessage,
BaseMessageChunk, BaseMessageChunk,
HumanMessage, HumanMessage,
message_chunk_to_message,
) )
from langchain_core.outputs import ( from langchain_core.outputs import (
ChatGeneration, ChatGeneration,
@ -63,7 +64,14 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
else: else:
generation += chunk generation += chunk
assert generation is not None assert generation is not None
return ChatResult(generations=[generation]) return ChatResult(
generations=[
ChatGeneration(
message=message_chunk_to_message(generation.message),
generation_info=generation.generation_info,
)
]
)
async def agenerate_from_stream( async def agenerate_from_stream(
@ -76,7 +84,14 @@ async def agenerate_from_stream(
else: else:
generation += chunk generation += chunk
assert generation is not None assert generation is not None
return ChatResult(generations=[generation]) return ChatResult(
generations=[
ChatGeneration(
message=message_chunk_to_message(generation.message),
generation_info=generation.generation_info,
)
]
)
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

View File

@ -98,6 +98,15 @@ def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
return [_message_from_dict(m) for m in messages] return [_message_from_dict(m) for m in messages]
def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
    """Convert a message chunk to its equivalent non-chunk message.

    Non-chunk inputs are passed through unchanged. For chunk instances, the
    corresponding base message class is looked up via the MRO.
    """
    if not isinstance(chunk, BaseMessageChunk):
        return chunk
    # chunk classes always have the equivalent non-chunk class as their first parent
    message_cls = chunk.__class__.__mro__[1]
    # drop "type": the chunk and base classes declare different "type" values
    field_values = {key: value for key, value in chunk.__dict__.items() if key != "type"}
    return message_cls(**field_values)
__all__ = [ __all__ = [
"AIMessage", "AIMessage",
"AIMessageChunk", "AIMessageChunk",
@ -115,6 +124,7 @@ __all__ = [
"ToolMessage", "ToolMessage",
"ToolMessageChunk", "ToolMessageChunk",
"get_buffer_string", "get_buffer_string",
"message_chunk_to_message",
"messages_from_dict", "messages_from_dict",
"messages_to_dict", "messages_to_dict",
"message_to_dict", "message_to_dict",

View File

@ -12,6 +12,7 @@ from typing import (
List, List,
Optional, Optional,
Sequence, Sequence,
Tuple,
Type, Type,
Union, Union,
cast, cast,
@ -65,28 +66,32 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> Type[BaseModel]:
return self._prepare(config).get_input_schema(config) runnable, config = self._prepare(config)
return runnable.get_input_schema(config)
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> Type[BaseModel]:
return self._prepare(config).get_output_schema(config) runnable, config = self._prepare(config)
return runnable.get_output_schema(config)
@abstractmethod @abstractmethod
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]: ) -> Tuple[Runnable[Input, Output], RunnableConfig]:
... ...
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
return self._prepare(config).invoke(input, config, **kwargs) runnable, config = self._prepare(config)
return runnable.invoke(input, config, **kwargs)
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
return await self._prepare(config).ainvoke(input, config, **kwargs) runnable, config = self._prepare(config)
return await runnable.ainvoke(input, config, **kwargs)
def batch( def batch(
self, self,
@ -99,21 +104,22 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs] prepared = [self._prepare(c) for c in configs]
if all(p is self.default for p in prepared): if all(p is self.default for p, _ in prepared):
return self.default.batch( return self.default.batch(
inputs, config, return_exceptions=return_exceptions, **kwargs inputs,
[c for _, c in prepared],
return_exceptions=return_exceptions,
**kwargs,
) )
if not inputs: if not inputs:
return [] return []
configs = get_config_list(config, len(inputs))
def invoke( def invoke(
bound: Runnable[Input, Output], prepared: Tuple[Runnable[Input, Output], RunnableConfig],
input: Input, input: Input,
config: RunnableConfig,
) -> Union[Output, Exception]: ) -> Union[Output, Exception]:
bound, config = prepared
if return_exceptions: if return_exceptions:
try: try:
return bound.invoke(input, config, **kwargs) return bound.invoke(input, config, **kwargs)
@ -124,12 +130,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
# If there's only one input, don't bother with the executor # If there's only one input, don't bother with the executor
if len(inputs) == 1: if len(inputs) == 1:
return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])]) return cast(List[Output], [invoke(prepared[0], inputs[0])])
with get_executor_for_config(configs[0]) as executor: with get_executor_for_config(configs[0]) as executor:
return cast( return cast(List[Output], list(executor.map(invoke, prepared, inputs)))
List[Output], list(executor.map(invoke, prepared, inputs, configs))
)
async def abatch( async def abatch(
self, self,
@ -142,21 +146,22 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs] prepared = [self._prepare(c) for c in configs]
if all(p is self.default for p in prepared): if all(p is self.default for p, _ in prepared):
return await self.default.abatch( return await self.default.abatch(
inputs, config, return_exceptions=return_exceptions, **kwargs inputs,
[c for _, c in prepared],
return_exceptions=return_exceptions,
**kwargs,
) )
if not inputs: if not inputs:
return [] return []
configs = get_config_list(config, len(inputs))
async def ainvoke( async def ainvoke(
bound: Runnable[Input, Output], prepared: Tuple[Runnable[Input, Output], RunnableConfig],
input: Input, input: Input,
config: RunnableConfig,
) -> Union[Output, Exception]: ) -> Union[Output, Exception]:
bound, config = prepared
if return_exceptions: if return_exceptions:
try: try:
return await bound.ainvoke(input, config, **kwargs) return await bound.ainvoke(input, config, **kwargs)
@ -165,7 +170,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
else: else:
return await bound.ainvoke(input, config, **kwargs) return await bound.ainvoke(input, config, **kwargs)
coros = map(ainvoke, prepared, inputs, configs) coros = map(ainvoke, prepared, inputs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
def stream( def stream(
@ -174,7 +179,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
return self._prepare(config).stream(input, config, **kwargs) runnable, config = self._prepare(config)
return runnable.stream(input, config, **kwargs)
async def astream( async def astream(
self, self,
@ -182,7 +188,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for chunk in self._prepare(config).astream(input, config, **kwargs): runnable, config = self._prepare(config)
async for chunk in runnable.astream(input, config, **kwargs):
yield chunk yield chunk
def transform( def transform(
@ -191,7 +198,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
return self._prepare(config).transform(input, config, **kwargs) runnable, config = self._prepare(config)
return runnable.transform(input, config, **kwargs)
async def atransform( async def atransform(
self, self,
@ -199,7 +207,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for chunk in self._prepare(config).atransform(input, config, **kwargs): runnable, config = self._prepare(config)
async for chunk in runnable.atransform(input, config, **kwargs):
yield chunk yield chunk
@ -238,7 +247,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]: ) -> Tuple[Runnable[Input, Output], RunnableConfig]:
config = config or {} config = config or {}
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
configurable_fields = { configurable_fields = {
@ -266,9 +275,12 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
} }
if configurable: if configurable:
return self.default.__class__(**{**self.default.__dict__, **configurable}) return (
self.default.__class__(**{**self.default.__dict__, **configurable}),
config,
)
else: else:
return self.default return (self.default, config)
# Before Python 3.11 native StrEnum is not available # Before Python 3.11 native StrEnum is not available
@ -363,21 +375,39 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
def _prepare( def _prepare(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]: ) -> Tuple[Runnable[Input, Output], RunnableConfig]:
config = config or {} config = config or {}
which = config.get("configurable", {}).get(self.which.id, self.default_key) which = config.get("configurable", {}).get(self.which.id, self.default_key)
# remap configurable keys for the chosen alternative
if self.prefix_keys:
config = cast(
RunnableConfig,
{
**config,
"configurable": {
_strremoveprefix(k, f"{self.which.id}=={which}/"): v
for k, v in config.get("configurable", {}).items()
},
},
)
# return the chosen alternative
if which == self.default_key: if which == self.default_key:
return self.default return (self.default, config)
elif which in self.alternatives: elif which in self.alternatives:
alt = self.alternatives[which] alt = self.alternatives[which]
if isinstance(alt, Runnable): if isinstance(alt, Runnable):
return alt return (alt, config)
else: else:
return alt() return (alt(), config)
else: else:
raise ValueError(f"Unknown alternative: {which}") raise ValueError(f"Unknown alternative: {which}")
def _strremoveprefix(s: str, prefix: str) -> str:
"""str.removeprefix() is only available in Python 3.9+."""
return s.replace(prefix, "", 1) if s.startswith(prefix) else s
def prefix_config_spec( def prefix_config_spec(
spec: ConfigurableFieldSpec, prefix: str spec: ConfigurableFieldSpec, prefix: str
) -> ConfigurableFieldSpec: ) -> ConfigurableFieldSpec:

View File

@ -121,6 +121,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
] ]
] = None ] = None
def __repr_args__(self) -> Any:
# Without this repr(self) raises a RecursionError
# See https://github.com/pydantic/pydantic/issues/7327
return []
def __init__( def __init__(
self, self,
func: Optional[ func: Optional[
@ -175,7 +180,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]], Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
], ],
], ],
) -> RunnableAssign: ) -> "RunnableAssign":
"""Merge the Dict input with the output produced by the mapping argument. """Merge the Dict input with the output produced by the mapping argument.
Args: Args:

View File

@ -17,6 +17,7 @@ EXPECTED_ALL = [
"ToolMessage", "ToolMessage",
"ToolMessageChunk", "ToolMessageChunk",
"get_buffer_string", "get_buffer_string",
"message_chunk_to_message",
"messages_from_dict", "messages_from_dict",
"messages_to_dict", "messages_to_dict",
"message_to_dict", "message_to_dict",

View File

@ -14,6 +14,7 @@ from langchain_core.messages import (
SystemMessage, SystemMessage,
ToolMessage, ToolMessage,
get_buffer_string, get_buffer_string,
message_chunk_to_message,
messages_from_dict, messages_from_dict,
messages_to_dict, messages_to_dict,
) )
@ -184,3 +185,18 @@ def test_multiple_msg() -> None:
sys_msg, sys_msg,
] ]
assert messages_from_dict(messages_to_dict(msgs)) == msgs assert messages_from_dict(messages_to_dict(msgs)) == msgs
def test_message_chunk_to_message() -> None:
    """Each chunk type converts to its non-chunk equivalent with fields intact."""
    cases = [
        (
            AIMessageChunk(content="I am", additional_kwargs={"foo": "bar"}),
            AIMessage(content="I am", additional_kwargs={"foo": "bar"}),
        ),
        (HumanMessageChunk(content="I am"), HumanMessage(content="I am")),
        (
            ChatMessageChunk(role="User", content="I am"),
            ChatMessage(role="User", content="I am"),
        ),
        (
            FunctionMessageChunk(name="hello", content="I am"),
            FunctionMessage(name="hello", content="I am"),
        ),
    ]
    for chunk, expected in cases:
        assert message_chunk_to_message(chunk) == expected

View File

@ -418,7 +418,11 @@ class ChatOpenAI(BaseChatModel):
) )
return generate_from_stream(stream_iter) return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs} params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = self.completion_with_retry( response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params messages=message_dicts, run_manager=run_manager, **params
) )
@ -502,7 +506,11 @@ class ChatOpenAI(BaseChatModel):
return await agenerate_from_stream(stream_iter) return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs} params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = await acompletion_with_retry( response = await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params self, messages=message_dicts, run_manager=run_manager, **params
) )