This commit is contained in:
Bagatur
2023-08-09 14:13:06 -07:00
parent 50b13ab938
commit eb0134fbb3
4 changed files with 69 additions and 19 deletions

View File

@@ -103,12 +103,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> BaseMessageChunk: ) -> BaseMessageChunk:
_config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
return cast( return cast(
BaseMessageChunk, BaseMessageChunk,
cast( cast(
ChatGeneration, ChatGeneration,
self.generate_prompt( self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs [self._convert_input(input)], stop=stop, **_config, **kwargs
).generations[0][0], ).generations[0][0],
).message, ).message,
) )
@@ -127,8 +129,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
None, partial(self.invoke, input, config, stop=stop, **kwargs) None, partial(self.invoke, input, config, stop=stop, **kwargs)
) )
_config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
llm_result = await self.agenerate_prompt( llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs [self._convert_input(input)], stop=stop, **_config, **kwargs
) )
return cast( return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message

View File

@@ -219,13 +219,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
return ( _config: Dict[str, Any] = dict(config or {})
self.generate_prompt( _config.pop("_locals", None)
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs result = self.generate_prompt(
) [self._convert_input(input)], stop=stop, **_config, **kwargs
.generations[0][0]
.text
) )
return result.generations[0][0].text
async def ainvoke( async def ainvoke(
self, self,
@@ -241,8 +240,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
None, partial(self.invoke, input, config, stop=stop, **kwargs) None, partial(self.invoke, input, config, stop=stop, **kwargs)
) )
_config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
llm_result = await self.agenerate_prompt( llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs [self._convert_input(input)], stop=stop, **_config, **kwargs
) )
return llm_result.generations[0][0].text return llm_result.generations[0][0].text

View File

@@ -107,7 +107,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
def invoke( def invoke(
self, input: str, config: Optional[RunnableConfig] = None self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]: ) -> List[Document]:
return self.get_relevant_documents(input, **(config or {})) _config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
return self.get_relevant_documents(input, **_config)
async def ainvoke( async def ainvoke(
self, input: str, config: Optional[RunnableConfig] = None self, input: str, config: Optional[RunnableConfig] = None
@@ -116,7 +118,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
# If the retriever doesn't implement async, use default implementation # If the retriever doesn't implement async, use default implementation
return await super().ainvoke(input, config) return await super().ainvoke(input, config)
return await self.aget_relevant_documents(input, **(config or {})) _config: Dict[str, Any] = dict(config or {})
_config.pop("_locals", None)
return await self.aget_relevant_documents(input, **_config)
@abstractmethod @abstractmethod
def _get_relevant_documents( def _get_relevant_documents(

View File

@@ -674,6 +674,46 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
raise first_error raise first_error
class PutLocalVar(Serializable, Runnable[Input, Input]):
    """Runnable that stores its input in the config's ``_locals`` dict.

    Passes the input through unchanged; as a side effect it writes the input
    (or selected keys of a mapping input) into ``config["_locals"]`` so a
    downstream ``GetLocalVar`` in the same sequence can read it back.
    """

    # Either a single local-variable name, or a mapping of
    # {input_key: local_var_name} applied when the input is a Mapping.
    key: Union[str, Dict[str, str]]

    def __init__(self, key: Union[str, Dict[str, str]], **kwargs: Any) -> None:
        # Convenience: allow positional PutLocalVar("foo") instead of key="foo".
        super().__init__(key=key, **kwargs)

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
        if config is None:
            raise ValueError(
                "PutLocalVar should only be used in a RunnableSequence, and should "
                "therefore always receive a non-null config."
            )
        if isinstance(self.key, str):
            # Single key: stash the whole input under that name.
            config["_locals"][self.key] = input
        else:
            if not isinstance(input, Mapping):
                raise ValueError(
                    "PutLocalVar was initialized with a Dict key, so it expects a "
                    f"Mapping input, but received input of type {type(input)}."
                )
            for get_key, put_key in self.key.items():
                config["_locals"][put_key] = input[get_key]
        # Identity transform so tracing/callbacks still record this step.
        return self._call_with_config(lambda x: x, input, config)
class GetLocalVar(Serializable, Runnable[str, Any]):
    """Runnable that reads a variable from the config's ``_locals`` dict.

    Returns the value previously stored under ``key`` (e.g. by an upstream
    ``PutLocalVar``). If ``passthrough_key`` is set, instead returns a dict
    containing both the stored value (under ``key``) and the current input
    (under ``passthrough_key``).
    """

    # Name of the local variable to read from config["_locals"].
    key: str
    # Optional output key under which the runnable's own input is forwarded.
    passthrough_key: Optional[str] = None

    def __init__(self, key: str, **kwargs: Any) -> None:
        # Convenience: allow positional GetLocalVar("foo") instead of key="foo".
        super().__init__(key=key, **kwargs)

    def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
        if config is None:
            # Fixed copy-paste bug: message previously said "PutLocalVar".
            raise ValueError(
                "GetLocalVar should only be used in a RunnableSequence, and should "
                "therefore always receive a non-null config."
            )
        if self.passthrough_key is not None:
            return {
                self.key: config["_locals"][self.key],
                self.passthrough_key: input,
            }
        return config["_locals"][self.key]
class RunnableSequence(Serializable, Runnable[Input, Output]): class RunnableSequence(Serializable, Runnable[Input, Output]):
""" """
A sequence of runnables, where the output of each is the input of the next. A sequence of runnables, where the output of each is the input of the next.
@@ -749,11 +789,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try: try:
callbacks = run_manager.get_child() callbacks = run_manager.get_child()
for step in self.steps: for step in self.steps:
input = step.invoke( # mark each step as child run
input, step_config = _patch_config(config, callbacks)
# mark each step as a child run input = step.invoke(input, step_config)
_patch_config(config, callbacks),
)
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
@@ -1401,11 +1439,14 @@ class RouterRunnable(
def _patch_config( def _patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None config: RunnableConfig,
callback_manager: BaseCallbackManager,
_locals: Optional[Dict[str, Any]] = None,
) -> RunnableConfig: ) -> RunnableConfig:
config = deepcopy(config) config = config.copy()
config["callbacks"] = callback_manager config["callbacks"] = callback_manager
config["_locals"] = _locals or {} if _locals is not None:
config["_locals"] = _locals
return config return config