mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
revert
This commit is contained in:
parent
a9bf409a09
commit
4e7e6bfe0a
@ -63,10 +63,13 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
config = config or {}
|
||||
config_kwargs: Dict = {
|
||||
k: config.get(k) for k in ("callbacks", "tags", "metadata")
|
||||
}
|
||||
return self(input, **config_kwargs, **kwargs)
|
||||
return self(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
@ -79,11 +82,15 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.invoke, input, config, **kwargs)
|
||||
)
|
||||
|
||||
config = config or {}
|
||||
config_kwargs: Dict = {
|
||||
k: config.get(k) for k in ("callbacks", "tags", "metadata")
|
||||
}
|
||||
return await self.acall(input, **config_kwargs, **kwargs)
|
||||
return await self.acall(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
"""Optional memory object. Defaults to None.
|
||||
|
@ -105,15 +105,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
**kwargs: Any,
|
||||
) -> BaseMessageChunk:
|
||||
config = config or {}
|
||||
config_kwargs: Dict = {
|
||||
k: config.get(k) for k in ("callbacks", "tags", "metadata")
|
||||
}
|
||||
return cast(
|
||||
BaseMessageChunk,
|
||||
cast(
|
||||
ChatGeneration,
|
||||
self.generate_prompt(
|
||||
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
**kwargs,
|
||||
).generations[0][0],
|
||||
).message,
|
||||
)
|
||||
@ -133,11 +135,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
)
|
||||
|
||||
config = config or {}
|
||||
config_kwargs: Dict = {
|
||||
k: config.get(k) for k in ("callbacks", "tags", "metadata")
|
||||
}
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
**kwargs,
|
||||
)
|
||||
return cast(
|
||||
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
|
||||
|
@ -220,13 +220,18 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
config = config or {}
|
||||
config_kwargs: Dict = {
|
||||
k: config.get(k) for k in ("callbacks", "tags", "metadata")
|
||||
}
|
||||
result = self.generate_prompt(
|
||||
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
|
||||
return (
|
||||
self.generate_prompt(
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
**kwargs,
|
||||
)
|
||||
.generations[0][0]
|
||||
.text
|
||||
)
|
||||
return result.generations[0][0].text
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
@ -243,11 +248,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
)
|
||||
|
||||
config = config or {}
|
||||
config_kwargs: Dict = {
|
||||
k: config.get(k) for k in ("callbacks", "tags", "metadata")
|
||||
}
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
**kwargs,
|
||||
)
|
||||
return llm_result.generations[0][0].text
|
||||
|
||||
|
@ -108,10 +108,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
self, input: str, config: Optional[RunnableConfig] = None
|
||||
) -> List[Document]:
|
||||
config = config or {}
|
||||
config_kwargs: Dict = {
|
||||
k: config.get(k) for k in ("callbacks", "tags", "metadata")
|
||||
}
|
||||
return self.get_relevant_documents(input, **config_kwargs)
|
||||
return self.get_relevant_documents(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
@ -124,10 +126,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
return await super().ainvoke(input, config)
|
||||
|
||||
config = config or {}
|
||||
config_kwargs: Dict = {
|
||||
k: config.get(k) for k in ("callbacks", "tags", "metadata")
|
||||
}
|
||||
return await self.aget_relevant_documents(input, **config_kwargs)
|
||||
return await self.aget_relevant_documents(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _get_relevant_documents(
|
||||
|
Loading…
Reference in New Issue
Block a user