mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
Fixes for opengpts release (#13960)
This commit is contained in:
parent
947daaf833
commit
970fe23feb
@ -34,6 +34,7 @@ from langchain_core.messages import (
|
|||||||
BaseMessage,
|
BaseMessage,
|
||||||
BaseMessageChunk,
|
BaseMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
|
message_chunk_to_message,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import (
|
from langchain_core.outputs import (
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
@ -63,7 +64,14 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
|
|||||||
else:
|
else:
|
||||||
generation += chunk
|
generation += chunk
|
||||||
assert generation is not None
|
assert generation is not None
|
||||||
return ChatResult(generations=[generation])
|
return ChatResult(
|
||||||
|
generations=[
|
||||||
|
ChatGeneration(
|
||||||
|
message=message_chunk_to_message(generation.message),
|
||||||
|
generation_info=generation.generation_info,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def agenerate_from_stream(
|
async def agenerate_from_stream(
|
||||||
@ -76,7 +84,14 @@ async def agenerate_from_stream(
|
|||||||
else:
|
else:
|
||||||
generation += chunk
|
generation += chunk
|
||||||
assert generation is not None
|
assert generation is not None
|
||||||
return ChatResult(generations=[generation])
|
return ChatResult(
|
||||||
|
generations=[
|
||||||
|
ChatGeneration(
|
||||||
|
message=message_chunk_to_message(generation.message),
|
||||||
|
generation_info=generation.generation_info,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||||
|
@ -98,6 +98,15 @@ def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]:
|
|||||||
return [_message_from_dict(m) for m in messages]
|
return [_message_from_dict(m) for m in messages]
|
||||||
|
|
||||||
|
|
||||||
|
def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
|
||||||
|
if not isinstance(chunk, BaseMessageChunk):
|
||||||
|
return chunk
|
||||||
|
# chunk classes always have the equivalent non-chunk class as their first parent
|
||||||
|
return chunk.__class__.__mro__[1](
|
||||||
|
**{k: v for k, v in chunk.__dict__.items() if k != "type"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AIMessage",
|
"AIMessage",
|
||||||
"AIMessageChunk",
|
"AIMessageChunk",
|
||||||
@ -115,6 +124,7 @@ __all__ = [
|
|||||||
"ToolMessage",
|
"ToolMessage",
|
||||||
"ToolMessageChunk",
|
"ToolMessageChunk",
|
||||||
"get_buffer_string",
|
"get_buffer_string",
|
||||||
|
"message_chunk_to_message",
|
||||||
"messages_from_dict",
|
"messages_from_dict",
|
||||||
"messages_to_dict",
|
"messages_to_dict",
|
||||||
"message_to_dict",
|
"message_to_dict",
|
||||||
|
@ -12,6 +12,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@ -65,28 +66,32 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
def get_input_schema(
|
def get_input_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
return self._prepare(config).get_input_schema(config)
|
runnable, config = self._prepare(config)
|
||||||
|
return runnable.get_input_schema(config)
|
||||||
|
|
||||||
def get_output_schema(
|
def get_output_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
return self._prepare(config).get_output_schema(config)
|
runnable, config = self._prepare(config)
|
||||||
|
return runnable.get_output_schema(config)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _prepare(
|
def _prepare(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Runnable[Input, Output]:
|
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
return self._prepare(config).invoke(input, config, **kwargs)
|
runnable, config = self._prepare(config)
|
||||||
|
return runnable.invoke(input, config, **kwargs)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
return await self._prepare(config).ainvoke(input, config, **kwargs)
|
runnable, config = self._prepare(config)
|
||||||
|
return await runnable.ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
def batch(
|
def batch(
|
||||||
self,
|
self,
|
||||||
@ -99,21 +104,22 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
prepared = [self._prepare(c) for c in configs]
|
prepared = [self._prepare(c) for c in configs]
|
||||||
|
|
||||||
if all(p is self.default for p in prepared):
|
if all(p is self.default for p, _ in prepared):
|
||||||
return self.default.batch(
|
return self.default.batch(
|
||||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
inputs,
|
||||||
|
[c for _, c in prepared],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not inputs:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
configs = get_config_list(config, len(inputs))
|
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
bound: Runnable[Input, Output],
|
prepared: Tuple[Runnable[Input, Output], RunnableConfig],
|
||||||
input: Input,
|
input: Input,
|
||||||
config: RunnableConfig,
|
|
||||||
) -> Union[Output, Exception]:
|
) -> Union[Output, Exception]:
|
||||||
|
bound, config = prepared
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
return bound.invoke(input, config, **kwargs)
|
return bound.invoke(input, config, **kwargs)
|
||||||
@ -124,12 +130,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
# If there's only one input, don't bother with the executor
|
# If there's only one input, don't bother with the executor
|
||||||
if len(inputs) == 1:
|
if len(inputs) == 1:
|
||||||
return cast(List[Output], [invoke(prepared[0], inputs[0], configs[0])])
|
return cast(List[Output], [invoke(prepared[0], inputs[0])])
|
||||||
|
|
||||||
with get_executor_for_config(configs[0]) as executor:
|
with get_executor_for_config(configs[0]) as executor:
|
||||||
return cast(
|
return cast(List[Output], list(executor.map(invoke, prepared, inputs)))
|
||||||
List[Output], list(executor.map(invoke, prepared, inputs, configs))
|
|
||||||
)
|
|
||||||
|
|
||||||
async def abatch(
|
async def abatch(
|
||||||
self,
|
self,
|
||||||
@ -142,21 +146,22 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
prepared = [self._prepare(c) for c in configs]
|
prepared = [self._prepare(c) for c in configs]
|
||||||
|
|
||||||
if all(p is self.default for p in prepared):
|
if all(p is self.default for p, _ in prepared):
|
||||||
return await self.default.abatch(
|
return await self.default.abatch(
|
||||||
inputs, config, return_exceptions=return_exceptions, **kwargs
|
inputs,
|
||||||
|
[c for _, c in prepared],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not inputs:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
configs = get_config_list(config, len(inputs))
|
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
bound: Runnable[Input, Output],
|
prepared: Tuple[Runnable[Input, Output], RunnableConfig],
|
||||||
input: Input,
|
input: Input,
|
||||||
config: RunnableConfig,
|
|
||||||
) -> Union[Output, Exception]:
|
) -> Union[Output, Exception]:
|
||||||
|
bound, config = prepared
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
try:
|
try:
|
||||||
return await bound.ainvoke(input, config, **kwargs)
|
return await bound.ainvoke(input, config, **kwargs)
|
||||||
@ -165,7 +170,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
else:
|
else:
|
||||||
return await bound.ainvoke(input, config, **kwargs)
|
return await bound.ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
coros = map(ainvoke, prepared, inputs, configs)
|
coros = map(ainvoke, prepared, inputs)
|
||||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
@ -174,7 +179,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
return self._prepare(config).stream(input, config, **kwargs)
|
runnable, config = self._prepare(config)
|
||||||
|
return runnable.stream(input, config, **kwargs)
|
||||||
|
|
||||||
async def astream(
|
async def astream(
|
||||||
self,
|
self,
|
||||||
@ -182,7 +188,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for chunk in self._prepare(config).astream(input, config, **kwargs):
|
runnable, config = self._prepare(config)
|
||||||
|
async for chunk in runnable.astream(input, config, **kwargs):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
@ -191,7 +198,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
return self._prepare(config).transform(input, config, **kwargs)
|
runnable, config = self._prepare(config)
|
||||||
|
return runnable.transform(input, config, **kwargs)
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
@ -199,7 +207,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for chunk in self._prepare(config).atransform(input, config, **kwargs):
|
runnable, config = self._prepare(config)
|
||||||
|
async for chunk in runnable.atransform(input, config, **kwargs):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
@ -238,7 +247,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
|
|
||||||
def _prepare(
|
def _prepare(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Runnable[Input, Output]:
|
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
||||||
configurable_fields = {
|
configurable_fields = {
|
||||||
@ -266,9 +275,12 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if configurable:
|
if configurable:
|
||||||
return self.default.__class__(**{**self.default.__dict__, **configurable})
|
return (
|
||||||
|
self.default.__class__(**{**self.default.__dict__, **configurable}),
|
||||||
|
config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return self.default
|
return (self.default, config)
|
||||||
|
|
||||||
|
|
||||||
# Before Python 3.11 native StrEnum is not available
|
# Before Python 3.11 native StrEnum is not available
|
||||||
@ -363,21 +375,39 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|||||||
|
|
||||||
def _prepare(
|
def _prepare(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Runnable[Input, Output]:
|
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
which = config.get("configurable", {}).get(self.which.id, self.default_key)
|
which = config.get("configurable", {}).get(self.which.id, self.default_key)
|
||||||
|
# remap configurable keys for the chosen alternative
|
||||||
|
if self.prefix_keys:
|
||||||
|
config = cast(
|
||||||
|
RunnableConfig,
|
||||||
|
{
|
||||||
|
**config,
|
||||||
|
"configurable": {
|
||||||
|
_strremoveprefix(k, f"{self.which.id}=={which}/"): v
|
||||||
|
for k, v in config.get("configurable", {}).items()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# return the chosen alternative
|
||||||
if which == self.default_key:
|
if which == self.default_key:
|
||||||
return self.default
|
return (self.default, config)
|
||||||
elif which in self.alternatives:
|
elif which in self.alternatives:
|
||||||
alt = self.alternatives[which]
|
alt = self.alternatives[which]
|
||||||
if isinstance(alt, Runnable):
|
if isinstance(alt, Runnable):
|
||||||
return alt
|
return (alt, config)
|
||||||
else:
|
else:
|
||||||
return alt()
|
return (alt(), config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown alternative: {which}")
|
raise ValueError(f"Unknown alternative: {which}")
|
||||||
|
|
||||||
|
|
||||||
|
def _strremoveprefix(s: str, prefix: str) -> str:
|
||||||
|
"""str.removeprefix() is only available in Python 3.9+."""
|
||||||
|
return s.replace(prefix, "", 1) if s.startswith(prefix) else s
|
||||||
|
|
||||||
|
|
||||||
def prefix_config_spec(
|
def prefix_config_spec(
|
||||||
spec: ConfigurableFieldSpec, prefix: str
|
spec: ConfigurableFieldSpec, prefix: str
|
||||||
) -> ConfigurableFieldSpec:
|
) -> ConfigurableFieldSpec:
|
||||||
|
@ -121,6 +121,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
]
|
]
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
def __repr_args__(self) -> Any:
|
||||||
|
# Without this repr(self) raises a RecursionError
|
||||||
|
# See https://github.com/pydantic/pydantic/issues/7327
|
||||||
|
return []
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
func: Optional[
|
func: Optional[
|
||||||
@ -175,7 +180,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
|
Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
) -> RunnableAssign:
|
) -> "RunnableAssign":
|
||||||
"""Merge the Dict input with the output produced by the mapping argument.
|
"""Merge the Dict input with the output produced by the mapping argument.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -17,6 +17,7 @@ EXPECTED_ALL = [
|
|||||||
"ToolMessage",
|
"ToolMessage",
|
||||||
"ToolMessageChunk",
|
"ToolMessageChunk",
|
||||||
"get_buffer_string",
|
"get_buffer_string",
|
||||||
|
"message_chunk_to_message",
|
||||||
"messages_from_dict",
|
"messages_from_dict",
|
||||||
"messages_to_dict",
|
"messages_to_dict",
|
||||||
"message_to_dict",
|
"message_to_dict",
|
||||||
|
@ -14,6 +14,7 @@ from langchain_core.messages import (
|
|||||||
SystemMessage,
|
SystemMessage,
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
|
message_chunk_to_message,
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
@ -184,3 +185,18 @@ def test_multiple_msg() -> None:
|
|||||||
sys_msg,
|
sys_msg,
|
||||||
]
|
]
|
||||||
assert messages_from_dict(messages_to_dict(msgs)) == msgs
|
assert messages_from_dict(messages_to_dict(msgs)) == msgs
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_chunk_to_message() -> None:
|
||||||
|
assert message_chunk_to_message(
|
||||||
|
AIMessageChunk(content="I am", additional_kwargs={"foo": "bar"})
|
||||||
|
) == AIMessage(content="I am", additional_kwargs={"foo": "bar"})
|
||||||
|
assert message_chunk_to_message(HumanMessageChunk(content="I am")) == HumanMessage(
|
||||||
|
content="I am"
|
||||||
|
)
|
||||||
|
assert message_chunk_to_message(
|
||||||
|
ChatMessageChunk(role="User", content="I am")
|
||||||
|
) == ChatMessage(role="User", content="I am")
|
||||||
|
assert message_chunk_to_message(
|
||||||
|
FunctionMessageChunk(name="hello", content="I am")
|
||||||
|
) == FunctionMessage(name="hello", content="I am")
|
||||||
|
@ -418,7 +418,11 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
)
|
)
|
||||||
return generate_from_stream(stream_iter)
|
return generate_from_stream(stream_iter)
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
params = {**params, **kwargs}
|
params = {
|
||||||
|
**params,
|
||||||
|
**({"stream": stream} if stream is not None else {}),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
response = self.completion_with_retry(
|
response = self.completion_with_retry(
|
||||||
messages=message_dicts, run_manager=run_manager, **params
|
messages=message_dicts, run_manager=run_manager, **params
|
||||||
)
|
)
|
||||||
@ -502,7 +506,11 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
return await agenerate_from_stream(stream_iter)
|
return await agenerate_from_stream(stream_iter)
|
||||||
|
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
params = {**params, **kwargs}
|
params = {
|
||||||
|
**params,
|
||||||
|
**({"stream": stream} if stream is not None else {}),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
response = await acompletion_with_retry(
|
response = await acompletion_with_retry(
|
||||||
self, messages=message_dicts, run_manager=run_manager, **params
|
self, messages=message_dicts, run_manager=run_manager, **params
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user