mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-19 17:36:00 +00:00
rfc
This commit is contained in:
@@ -103,12 +103,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessageChunk:
|
) -> BaseMessageChunk:
|
||||||
|
_config: Dict[str, Any] = dict(config or {})
|
||||||
|
_config.pop("_locals", None)
|
||||||
return cast(
|
return cast(
|
||||||
BaseMessageChunk,
|
BaseMessageChunk,
|
||||||
cast(
|
cast(
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
self.generate_prompt(
|
self.generate_prompt(
|
||||||
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
|
[self._convert_input(input)], stop=stop, **_config, **kwargs
|
||||||
).generations[0][0],
|
).generations[0][0],
|
||||||
).message,
|
).message,
|
||||||
)
|
)
|
||||||
@@ -127,8 +129,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
None, partial(self.invoke, input, config, stop=stop, **kwargs)
|
None, partial(self.invoke, input, config, stop=stop, **kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_config: Dict[str, Any] = dict(config or {})
|
||||||
|
_config.pop("_locals", None)
|
||||||
llm_result = await self.agenerate_prompt(
|
llm_result = await self.agenerate_prompt(
|
||||||
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
|
[self._convert_input(input)], stop=stop, **_config, **kwargs
|
||||||
)
|
)
|
||||||
return cast(
|
return cast(
|
||||||
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
|
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
|
||||||
|
@@ -219,13 +219,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
return (
|
_config: Dict[str, Any] = dict(config or {})
|
||||||
self.generate_prompt(
|
_config.pop("_locals", None)
|
||||||
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
|
result = self.generate_prompt(
|
||||||
)
|
[self._convert_input(input)], stop=stop, **_config, **kwargs
|
||||||
.generations[0][0]
|
|
||||||
.text
|
|
||||||
)
|
)
|
||||||
|
return result.generations[0][0].text
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
@@ -241,8 +240,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
None, partial(self.invoke, input, config, stop=stop, **kwargs)
|
None, partial(self.invoke, input, config, stop=stop, **kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_config: Dict[str, Any] = dict(config or {})
|
||||||
|
_config.pop("_locals", None)
|
||||||
llm_result = await self.agenerate_prompt(
|
llm_result = await self.agenerate_prompt(
|
||||||
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
|
[self._convert_input(input)], stop=stop, **_config, **kwargs
|
||||||
)
|
)
|
||||||
return llm_result.generations[0][0].text
|
return llm_result.generations[0][0].text
|
||||||
|
|
||||||
|
@@ -107,7 +107,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
|||||||
def invoke(
|
def invoke(
|
||||||
self, input: str, config: Optional[RunnableConfig] = None
|
self, input: str, config: Optional[RunnableConfig] = None
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
return self.get_relevant_documents(input, **(config or {}))
|
_config: Dict[str, Any] = dict(config or {})
|
||||||
|
_config.pop("_locals", None)
|
||||||
|
return self.get_relevant_documents(input, **_config)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: str, config: Optional[RunnableConfig] = None
|
self, input: str, config: Optional[RunnableConfig] = None
|
||||||
@@ -116,7 +118,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
|||||||
# If the retriever doesn't implement async, use default implementation
|
# If the retriever doesn't implement async, use default implementation
|
||||||
return await super().ainvoke(input, config)
|
return await super().ainvoke(input, config)
|
||||||
|
|
||||||
return await self.aget_relevant_documents(input, **(config or {}))
|
_config: Dict[str, Any] = dict(config or {})
|
||||||
|
_config.pop("_locals", None)
|
||||||
|
return await self.aget_relevant_documents(input, **_config)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
|
@@ -674,6 +674,46 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
raise first_error
|
raise first_error
|
||||||
|
|
||||||
|
|
||||||
|
class PutLocalVar(Serializable, Runnable[Input, Input]):
|
||||||
|
key: Union[str, Dict[str, str]]
|
||||||
|
|
||||||
|
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||||
|
super().__init__(key=key, **kwargs)
|
||||||
|
|
||||||
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||||
|
if config is None:
|
||||||
|
raise ValueError(
|
||||||
|
"PutLocalVar should only be used in a RunnableSequence, and should "
|
||||||
|
"therefore always receive a non-null config."
|
||||||
|
)
|
||||||
|
if isinstance(self.key, str):
|
||||||
|
config["_locals"][self.key] = input
|
||||||
|
else:
|
||||||
|
if not isinstance(input, Mapping):
|
||||||
|
raise ValueError
|
||||||
|
for get_key, put_key in self.key.items():
|
||||||
|
config["_locals"][put_key] = input[get_key]
|
||||||
|
return self._call_with_config(lambda x: x, input, config)
|
||||||
|
|
||||||
|
|
||||||
|
class GetLocalVar(Serializable, Runnable[str, Any]):
|
||||||
|
key: str
|
||||||
|
passthrough_key: Optional[str] = None
|
||||||
|
|
||||||
|
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||||
|
super().__init__(key=key, **kwargs)
|
||||||
|
|
||||||
|
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
|
||||||
|
if config is None:
|
||||||
|
raise ValueError(
|
||||||
|
"PutLocalVar should only be used in a RunnableSequence, and should "
|
||||||
|
"therefore always receive a non-null config."
|
||||||
|
)
|
||||||
|
if self.passthrough_key is not None:
|
||||||
|
return {self.key: config["_locals"][self.key], self.passthrough_key: input}
|
||||||
|
return config["_locals"][self.key]
|
||||||
|
|
||||||
|
|
||||||
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A sequence of runnables, where the output of each is the input of the next.
|
A sequence of runnables, where the output of each is the input of the next.
|
||||||
@@ -749,11 +789,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
try:
|
try:
|
||||||
callbacks = run_manager.get_child()
|
callbacks = run_manager.get_child()
|
||||||
for step in self.steps:
|
for step in self.steps:
|
||||||
input = step.invoke(
|
# mark each step as child run
|
||||||
input,
|
step_config = _patch_config(config, callbacks)
|
||||||
# mark each step as a child run
|
input = step.invoke(input, step_config)
|
||||||
_patch_config(config, callbacks),
|
|
||||||
)
|
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
@@ -1401,11 +1439,14 @@ class RouterRunnable(
|
|||||||
|
|
||||||
|
|
||||||
def _patch_config(
|
def _patch_config(
|
||||||
config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None
|
config: RunnableConfig,
|
||||||
|
callback_manager: BaseCallbackManager,
|
||||||
|
_locals: Optional[Dict[str, Any]] = None,
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = deepcopy(config)
|
config = config.copy()
|
||||||
config["callbacks"] = callback_manager
|
config["callbacks"] = callback_manager
|
||||||
config["_locals"] = _locals or {}
|
if _locals is not None:
|
||||||
|
config["_locals"] = _locals
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user