mirror of https://github.com/hwchase17/langchain.git
synced 2025-09-19 17:36:00 +00:00
rfc
@@ -103,12 +103,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> BaseMessageChunk:
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+                    [self._convert_input(input)], stop=stop, **_config, **kwargs
                 ).generations[0][0],
             ).message,
         )
@@ -127,8 +129,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )

+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+            [self._convert_input(input)], stop=stop, **_config, **kwargs
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
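Both hunks apply the same pattern: shallow-copy the incoming config, pop the internal "_locals" entry, and splat only the remaining keys into generate_prompt / agenerate_prompt, so per-sequence bookkeeping never reaches the model call. A minimal standalone sketch of that pattern (the names below are illustrative only, not part of the diff):

from typing import Any, Callable, Dict, Optional

INTERNAL_KEYS = ("_locals",)  # hypothetical list of internal-only config keys

def call_with_sanitized_config(
    fn: Callable[..., Any], config: Optional[Dict[str, Any]], **kwargs: Any
) -> Any:
    # dict(config or {}) handles config=None and avoids mutating the caller's dict
    _config: Dict[str, Any] = dict(config or {})
    for key in INTERNAL_KEYS:
        _config.pop(key, None)  # drop bookkeeping entries that fn does not accept
    return fn(**_config, **kwargs)

# e.g. call_with_sanitized_config(print, {"_locals": {"x": 1}, "sep": ", "}, end="\n")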
@@ -219,13 +219,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> str:
-        return (
-            self.generate_prompt(
-                [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
-            )
-            .generations[0][0]
-            .text
-        )
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
+        result = self.generate_prompt(
+            [self._convert_input(input)], stop=stop, **_config, **kwargs
+        )
+        return result.generations[0][0].text

     async def ainvoke(
         self,
@@ -241,8 +240,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )

+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+            [self._convert_input(input)], stop=stop, **_config, **kwargs
         )
         return llm_result.generations[0][0].text

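The BaseLLM changes mirror the chat-model ones. The untouched context lines also show the async fallback both ainvoke implementations rely on: when no native async path exists, the synchronous invoke is handed to the default executor. A hedged, self-contained sketch of that fallback pattern (the names below are illustrative, not the library's API):

import asyncio
from functools import partial
from typing import Any

def sync_invoke(x: int, factor: int = 2) -> int:
    # stand-in for a blocking, synchronous invoke
    return x * factor

async def ainvoke(x: int, **kwargs: Any) -> int:
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; partial binds the arguments up front
    return await loop.run_in_executor(None, partial(sync_invoke, x, **kwargs))

# asyncio.run(ainvoke(21, factor=2)) -> 42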
@@ -107,7 +107,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
     def invoke(
         self, input: str, config: Optional[RunnableConfig] = None
     ) -> List[Document]:
-        return self.get_relevant_documents(input, **(config or {}))
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
+        return self.get_relevant_documents(input, **_config)

     async def ainvoke(
         self, input: str, config: Optional[RunnableConfig] = None
@@ -116,7 +118,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             # If the retriever doesn't implement async, use default implementation
             return await super().ainvoke(input, config)

-        return await self.aget_relevant_documents(input, **(config or {}))
+        _config: Dict[str, Any] = dict(config or {})
+        _config.pop("_locals", None)
+        return await self.aget_relevant_documents(input, **_config)

     @abstractmethod
     def _get_relevant_documents(
@@ -674,6 +674,46 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         raise first_error


+class PutLocalVar(Serializable, Runnable[Input, Input]):
+    key: Union[str, Dict[str, str]]
+
+    def __init__(self, key: str, **kwargs: Any) -> None:
+        super().__init__(key=key, **kwargs)
+
+    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
+        if config is None:
+            raise ValueError(
+                "PutLocalVar should only be used in a RunnableSequence, and should "
+                "therefore always receive a non-null config."
+            )
+        if isinstance(self.key, str):
+            config["_locals"][self.key] = input
+        else:
+            if not isinstance(input, Mapping):
+                raise ValueError
+            for get_key, put_key in self.key.items():
+                config["_locals"][put_key] = input[get_key]
+        return self._call_with_config(lambda x: x, input, config)
+
+
+class GetLocalVar(Serializable, Runnable[str, Any]):
+    key: str
+    passthrough_key: Optional[str] = None
+
+    def __init__(self, key: str, **kwargs: Any) -> None:
+        super().__init__(key=key, **kwargs)
+
+    def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
+        if config is None:
+            raise ValueError(
+                "PutLocalVar should only be used in a RunnableSequence, and should "
+                "therefore always receive a non-null config."
+            )
+        if self.passthrough_key is not None:
+            return {self.key: config["_locals"][self.key], self.passthrough_key: input}
+        return config["_locals"][self.key]
+
+
 class RunnableSequence(Serializable, Runnable[Input, Output]):
     """
     A sequence of runnables, where the output of each is the input of the next.
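PutLocalVar writes its input (or selected keys of a mapping input) into config["_locals"] and passes the input through unchanged; GetLocalVar reads a value back out, optionally bundling it with its own input under passthrough_key. Presumably the intended composition is something like PutLocalVar("question") | prompt | llm | GetLocalVar("question", passthrough_key="answer"). A standalone sketch of that data flow in plain Python (stand-in functions, no LangChain imports):

from typing import Any, Dict, Optional

def put_local(key: str, value: Any, config: Dict[str, Any]) -> Any:
    config["_locals"][key] = value  # stash the value for later steps
    return value                    # pass the input through unchanged

def get_local(
    key: str, current: Any, config: Dict[str, Any], passthrough_key: Optional[str] = None
) -> Any:
    if passthrough_key is not None:
        return {key: config["_locals"][key], passthrough_key: current}
    return config["_locals"][key]

config: Dict[str, Any] = {"_locals": {}}
question = put_local("question", "what is a runnable?", config)
answer = f"an answer to: {question}"  # stand-in for the prompt | llm steps
result = get_local("question", answer, config, passthrough_key="answer")
# result == {"question": "what is a runnable?", "answer": "an answer to: what is a runnable?"}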
@@ -749,11 +789,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         try:
             callbacks = run_manager.get_child()
             for step in self.steps:
-                input = step.invoke(
-                    input,
-                    # mark each step as a child run
-                    _patch_config(config, callbacks),
-                )
+                # mark each step as child run
+                step_config = _patch_config(config, callbacks)
+                input = step.invoke(input, step_config)
         # finish the root run
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_chain_error(e)
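This hunk only names the patched per-step config (step_config) instead of building it inline in the call; behavior is unchanged: every step still runs with a config whose callbacks come from the parent run's child manager. A stand-in sketch of that loop (patch_config and the step signature below are illustrative, not the library's API):

from typing import Any, Callable, Dict, List

def patch_config(config: Dict[str, Any], callbacks: Any) -> Dict[str, Any]:
    patched = config.copy()           # shallow copy, as in _patch_config below
    patched["callbacks"] = callbacks  # route the step's events to the child manager
    return patched

def run_sequence(
    steps: List[Callable[[Any, Dict[str, Any]], Any]],
    input: Any,
    config: Dict[str, Any],
    callbacks: Any,
) -> Any:
    for step in steps:
        # mark each step as a child run
        step_config = patch_config(config, callbacks)
        input = step(input, step_config)
    return input

# run_sequence([lambda x, c: x + 1, lambda x, c: x * 2], 1, {"_locals": {}}, callbacks=None) -> 4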
@@ -1401,11 +1439,14 @@ class RouterRunnable(


 def _patch_config(
-    config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None
+    config: RunnableConfig,
+    callback_manager: BaseCallbackManager,
+    _locals: Optional[Dict[str, Any]] = None,
 ) -> RunnableConfig:
-    config = deepcopy(config)
+    config = config.copy()
     config["callbacks"] = callback_manager
-    config["_locals"] = _locals or {}
+    if _locals is not None:
+        config["_locals"] = _locals
     return config

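Two changes can be read from this hunk: the signature is reflowed onto one parameter per line, and deepcopy(config) becomes config.copy() with the "_locals" assignment made conditional. With a shallow copy, nested objects such as an existing "_locals" dict are shared by reference between the original config and each patched copy, which is presumably what lets a value stored by PutLocalVar earlier in a sequence be visible to a later GetLocalVar; deepcopy would have given every step an independent snapshot. A small sketch of that difference:

from copy import deepcopy
from typing import Any, Dict

config: Dict[str, Any] = {"_locals": {}}

shallow = config.copy()
shallow["_locals"]["x"] = 1
assert config["_locals"] == {"x": 1}  # shared dict: the write is visible through both

deep = deepcopy(config)
deep["_locals"]["y"] = 2
assert "y" not in config["_locals"]   # independent snapshot: the write stays local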