mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 12:58:59 +00:00
Fixes for opengpts release (#13960)
This commit is contained in:
parent
947daaf833
commit
970fe23feb
@ -34,6 +34,7 @@ from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
message_chunk_to_message,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
@ -63,7 +64,14 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation])
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(
|
||||
message=message_chunk_to_message(generation.message),
|
||||
generation_info=generation.generation_info,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def agenerate_from_stream(
|
||||
@ -76,7 +84,14 @@ async def agenerate_from_stream(
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation])
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(
|
||||
message=message_chunk_to_message(generation.message),
|
||||
generation_info=generation.generation_info,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
@ -98,6 +98,15 @@ def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
|
||||
return [_message_from_dict(m) for m in messages]
|
||||
|
||||
|
||||
def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
|
||||
if not isinstance(chunk, BaseMessageChunk):
|
||||
return chunk
|
||||
# chunk classes always have the equivalent non-chunk class as their first parent
|
||||
return chunk.__class__.__mro__[1](
|
||||
**{k: v for k, v in chunk.__dict__.items() if k != "type"}
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
@ -115,6 +124,7 @@ __all__ = [
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"get_buffer_string",
|
||||
"message_chunk_to_message",
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
"message_to_dict",
|
||||
|
@ -12,6 +12,7 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
@ -65,28 +66,32 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self._prepare(config).get_input_schema(config)
|
||||
runnable, config = self._prepare(config)
|
||||
return runnable.get_input_schema(config)
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self._prepare(config).get_output_schema(config)
|
||||
runnable, config = self._prepare(config)
|
||||
return runnable.get_output_schema(config)
|
||||
|
||||
@abstractmethod
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Runnable[Input, Output]:
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
...
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
return self._prepare(config).invoke(input, config, **kwargs)
|
||||
runnable, config = self._prepare(config)
|
||||
return runnable.invoke(input, config, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
return await self._prepare(config).ainvoke(input, config, **kwargs)
|
||||
runnable, config = self._prepare(config)
|
||||
return await runnable.ainvoke(input, config, **kwargs)
|
||||
|
||||
def batch(
|
||||
self,
|
||||
@ -99,21 +104,22 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
configs = get_config_list(config, len(inputs))
|
||||
prepared = [self._prepare(c) for c in configs]
|
||||
|
||||
if all(p is self.default for p in prepared):
|
||||
if all(p is self.default for p, _ in prepared):
|
||||
return self.default.batch(
|
||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
inputs,
|
||||
[c for _, c in prepared],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
def invoke(
|
||||
bound: Runnable[Input, Output],
|
||||
prepared: Tuple[Runnable[Input, Output], RunnableConfig],
|
||||
input: Input,
|
||||
config: RunnableConfig,
|
||||
) -> Union[Output, Exception]:
|
||||
bound, config = prepared
|
||||
if return_exceptions:
|
||||
try:
|
||||
return bound.invoke(input, config, **kwargs)
|
||||
@ -124,12 +130,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
if len(inputs) == 1:
|
||||
return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])])
|
||||
return cast(List[Output], [invoke(prepared[0], inputs[0])])
|
||||
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
return cast(
|
||||
List[Output], list(executor.map(invoke, prepared, inputs, configs))
|
||||
)
|
||||
return cast(List[Output], list(executor.map(invoke, prepared, inputs)))
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
@ -142,21 +146,22 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
configs = get_config_list(config, len(inputs))
|
||||
prepared = [self._prepare(c) for c in configs]
|
||||
|
||||
if all(p is self.default for p in prepared):
|
||||
if all(p is self.default for p, _ in prepared):
|
||||
return await self.default.abatch(
|
||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
inputs,
|
||||
[c for _, c in prepared],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
async def ainvoke(
|
||||
bound: Runnable[Input, Output],
|
||||
prepared: Tuple[Runnable[Input, Output], RunnableConfig],
|
||||
input: Input,
|
||||
config: RunnableConfig,
|
||||
) -> Union[Output, Exception]:
|
||||
bound, config = prepared
|
||||
if return_exceptions:
|
||||
try:
|
||||
return await bound.ainvoke(input, config, **kwargs)
|
||||
@ -165,7 +170,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
else:
|
||||
return await bound.ainvoke(input, config, **kwargs)
|
||||
|
||||
coros = map(ainvoke, prepared, inputs, configs)
|
||||
coros = map(ainvoke, prepared, inputs)
|
||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||
|
||||
def stream(
|
||||
@ -174,7 +179,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
return self._prepare(config).stream(input, config, **kwargs)
|
||||
runnable, config = self._prepare(config)
|
||||
return runnable.stream(input, config, **kwargs)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
@ -182,7 +188,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
async for chunk in self._prepare(config).astream(input, config, **kwargs):
|
||||
runnable, config = self._prepare(config)
|
||||
async for chunk in runnable.astream(input, config, **kwargs):
|
||||
yield chunk
|
||||
|
||||
def transform(
|
||||
@ -191,7 +198,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
return self._prepare(config).transform(input, config, **kwargs)
|
||||
runnable, config = self._prepare(config)
|
||||
return runnable.transform(input, config, **kwargs)
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
@ -199,7 +207,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
async for chunk in self._prepare(config).atransform(input, config, **kwargs):
|
||||
runnable, config = self._prepare(config)
|
||||
async for chunk in runnable.atransform(input, config, **kwargs):
|
||||
yield chunk
|
||||
|
||||
|
||||
@ -238,7 +247,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Runnable[Input, Output]:
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
config = config or {}
|
||||
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
||||
configurable_fields = {
|
||||
@ -266,9 +275,12 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
||||
}
|
||||
|
||||
if configurable:
|
||||
return self.default.__class__(**{**self.default.__dict__, **configurable})
|
||||
return (
|
||||
self.default.__class__(**{**self.default.__dict__, **configurable}),
|
||||
config,
|
||||
)
|
||||
else:
|
||||
return self.default
|
||||
return (self.default, config)
|
||||
|
||||
|
||||
# Before Python 3.11 native StrEnum is not available
|
||||
@ -363,21 +375,39 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
|
||||
def _prepare(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Runnable[Input, Output]:
|
||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||
config = config or {}
|
||||
which = config.get("configurable", {}).get(self.which.id, self.default_key)
|
||||
# remap configurable keys for the chosen alternative
|
||||
if self.prefix_keys:
|
||||
config = cast(
|
||||
RunnableConfig,
|
||||
{
|
||||
**config,
|
||||
"configurable": {
|
||||
_strremoveprefix(k, f"{self.which.id}=={which}/"): v
|
||||
for k, v in config.get("configurable", {}).items()
|
||||
},
|
||||
},
|
||||
)
|
||||
# return the chosen alternative
|
||||
if which == self.default_key:
|
||||
return self.default
|
||||
return (self.default, config)
|
||||
elif which in self.alternatives:
|
||||
alt = self.alternatives[which]
|
||||
if isinstance(alt, Runnable):
|
||||
return alt
|
||||
return (alt, config)
|
||||
else:
|
||||
return alt()
|
||||
return (alt(), config)
|
||||
else:
|
||||
raise ValueError(f"Unknown alternative: {which}")
|
||||
|
||||
|
||||
def _strremoveprefix(s: str, prefix: str) -> str:
|
||||
"""str.removeprefix() is only available in Python 3.9+."""
|
||||
return s.replace(prefix, "", 1) if s.startswith(prefix) else s
|
||||
|
||||
|
||||
def prefix_config_spec(
|
||||
spec: ConfigurableFieldSpec, prefix: str
|
||||
) -> ConfigurableFieldSpec:
|
||||
|
@ -121,6 +121,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
]
|
||||
] = None
|
||||
|
||||
def __repr_args__(self) -> Any:
|
||||
# Without this repr(self) raises a RecursionError
|
||||
# See https://github.com/pydantic/pydantic/issues/7327
|
||||
return []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Optional[
|
||||
@ -175,7 +180,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
|
||||
],
|
||||
],
|
||||
) -> RunnableAssign:
|
||||
) -> "RunnableAssign":
|
||||
"""Merge the Dict input with the output produced by the mapping argument.
|
||||
|
||||
Args:
|
||||
|
@ -17,6 +17,7 @@ EXPECTED_ALL = [
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"get_buffer_string",
|
||||
"message_chunk_to_message",
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
"message_to_dict",
|
||||
|
@ -14,6 +14,7 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
get_buffer_string,
|
||||
message_chunk_to_message,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
@ -184,3 +185,18 @@ def test_multiple_msg() -> None:
|
||||
sys_msg,
|
||||
]
|
||||
assert messages_from_dict(messages_to_dict(msgs)) == msgs
|
||||
|
||||
|
||||
def test_message_chunk_to_message() -> None:
|
||||
assert message_chunk_to_message(
|
||||
AIMessageChunk(content="I am", additional_kwargs={"foo": "bar"})
|
||||
) == AIMessage(content="I am", additional_kwargs={"foo": "bar"})
|
||||
assert message_chunk_to_message(HumanMessageChunk(content="I am")) == HumanMessage(
|
||||
content="I am"
|
||||
)
|
||||
assert message_chunk_to_message(
|
||||
ChatMessageChunk(role="User", content="I am")
|
||||
) == ChatMessage(role="User", content="I am")
|
||||
assert message_chunk_to_message(
|
||||
FunctionMessageChunk(name="hello", content="I am")
|
||||
) == FunctionMessage(name="hello", content="I am")
|
||||
|
@ -418,7 +418,11 @@ class ChatOpenAI(BaseChatModel):
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
params = {
|
||||
**params,
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
response = self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
@ -502,7 +506,11 @@ class ChatOpenAI(BaseChatModel):
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
params = {
|
||||
**params,
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
response = await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user