mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 20:41:52 +00:00
fix
This commit is contained in:
parent
c447e9a854
commit
6b0a849f59
@ -7,10 +7,13 @@ from langchain.schema.runnable.base import (
|
|||||||
RunnableWithFallbacks,
|
RunnableWithFallbacks,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
|
from langchain.schema.runnable.locals import GetLocalVar, PutLocalVar
|
||||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"GetLocalVar",
|
||||||
|
"PutLocalVar",
|
||||||
"RouterInput",
|
"RouterInput",
|
||||||
"RouterRunnable",
|
"RouterRunnable",
|
||||||
"Runnable",
|
"Runnable",
|
||||||
|
@ -238,7 +238,10 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return (
|
return (
|
||||||
config
|
config
|
||||||
if isinstance(config, list)
|
if isinstance(config, list)
|
||||||
else [deepcopy(config) if config is not None else {} for _ in range(length)]
|
else [
|
||||||
|
deepcopy(config) if config is not None else _empty_config()
|
||||||
|
for _ in range(length)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _call_with_config(
|
def _call_with_config(
|
||||||
@ -750,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Output:
|
) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
@ -896,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = _get_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
@ -963,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = config or _empty_config()
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
@ -3,8 +3,9 @@ from __future__ import annotations
|
|||||||
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
|
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
|
from langchain.schema.runnable.base import Input, Output, Runnable
|
||||||
from langchain.schema.runnable.base import Input, Output
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
|
|
||||||
|
|
||||||
class PutLocalVar(RunnablePassthrough):
|
class PutLocalVar(RunnablePassthrough):
|
||||||
@ -27,7 +28,12 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
)
|
)
|
||||||
if isinstance(self.key, str):
|
if isinstance(self.key, str):
|
||||||
config["_locals"][self.key] = input
|
config["_locals"][self.key] = input
|
||||||
elif isinstance(input, Mapping):
|
elif isinstance(self.key, Mapping):
|
||||||
|
if not isinstance(input, Mapping):
|
||||||
|
raise TypeError(
|
||||||
|
f"Received key of type Mapping but input of type {type(input)}. "
|
||||||
|
f"input is expected to be of type Mapping when key is Mapping."
|
||||||
|
)
|
||||||
for input_key, put_key in self.key.items():
|
for input_key, put_key in self.key.items():
|
||||||
config["_locals"][put_key] = input[input_key]
|
config["_locals"][put_key] = input[input_key]
|
||||||
else:
|
else:
|
||||||
@ -44,17 +50,23 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
"PutLocalVar should only be used in a RunnableSequence, and should "
|
"PutLocalVar should only be used in a RunnableSequence, and should "
|
||||||
"therefore always receive a non-null config."
|
"therefore always receive a non-null config."
|
||||||
)
|
)
|
||||||
|
print(config)
|
||||||
if isinstance(self.key, str):
|
if isinstance(self.key, str):
|
||||||
if self.key not in config["_locals"]:
|
if self.key not in config["_locals"]:
|
||||||
config["_locals"][self.key] = input
|
config["_locals"][self.key] = input
|
||||||
else:
|
else:
|
||||||
config["_locals"][self.key] += input
|
config["_locals"][self.key] += input
|
||||||
elif isinstance(input, Mapping):
|
elif isinstance(self.key, Mapping):
|
||||||
|
if not isinstance(input, Mapping):
|
||||||
|
raise TypeError(
|
||||||
|
f"Received key of type Mapping but input of type {type(input)}. "
|
||||||
|
f"input is expected to be of type Mapping when key is Mapping."
|
||||||
|
)
|
||||||
for input_key, put_key in self.key.items():
|
for input_key, put_key in self.key.items():
|
||||||
if put_key not in config["_locals"]:
|
if put_key not in config["_locals"]:
|
||||||
config["_locals"][put_key] = input
|
config["_locals"][put_key] = input[input_key]
|
||||||
else:
|
else:
|
||||||
config["_locals"][put_key] += input
|
config["_locals"][put_key] += input[input_key]
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"`key` should be a string or Mapping[str, str], received type "
|
f"`key` should be a string or Mapping[str, str], received type "
|
||||||
@ -75,14 +87,14 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
self, input: Iterator[Input], config: RunnableConfig | None = None
|
self, input: Iterator[Input], config: RunnableConfig | None = None
|
||||||
) -> Iterator[Input]:
|
) -> Iterator[Input]:
|
||||||
for chunk in super().transform(input, config=config):
|
for chunk in super().transform(input, config=config):
|
||||||
self._concat_put(input, config=config)
|
self._concat_put(chunk, config=config)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
|
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
|
||||||
) -> AsyncIterator[Input]:
|
) -> AsyncIterator[Input]:
|
||||||
async for chunk in super().atransform(input, config=config):
|
async for chunk in super().atransform(input, config=config):
|
||||||
self._concat_put(input, config=config)
|
self._concat_put(chunk, config=config)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user