This commit is contained in:
Bagatur 2023-08-21 18:01:49 -07:00
parent a9bf409a09
commit 4e7e6bfe0a
4 changed files with 56 additions and 34 deletions

View File

@@ -63,10 +63,13 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
         **kwargs: Any,
     ) -> Dict[str, Any]:
         config = config or {}
-        config_kwargs: Dict = {
-            k: config.get(k) for k in ("callbacks", "tags", "metadata")
-        }
-        return self(input, **config_kwargs, **kwargs)
+        return self(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )

     async def ainvoke(
         self,
@@ -79,11 +82,15 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
             return await asyncio.get_running_loop().run_in_executor(
                 None, partial(self.invoke, input, config, **kwargs)
             )
         config = config or {}
-        config_kwargs: Dict = {
-            k: config.get(k) for k in ("callbacks", "tags", "metadata")
-        }
-        return await self.acall(input, **config_kwargs, **kwargs)
+        return await self.acall(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )

     memory: Optional[BaseMemory] = None
     """Optional memory object. Defaults to None.

View File

@@ -105,15 +105,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         **kwargs: Any,
     ) -> BaseMessageChunk:
         config = config or {}
-        config_kwargs: Dict = {
-            k: config.get(k) for k in ("callbacks", "tags", "metadata")
-        }
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
+                    [self._convert_input(input)],
+                    stop=stop,
+                    callbacks=config.get("callbacks"),
+                    tags=config.get("tags"),
+                    metadata=config.get("metadata"),
+                    **kwargs,
                 ).generations[0][0],
             ).message,
         )
@@ -133,11 +135,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             )
         config = config or {}
-        config_kwargs: Dict = {
-            k: config.get(k) for k in ("callbacks", "tags", "metadata")
-        }
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
+            [self._convert_input(input)],
+            stop=stop,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message

View File

@@ -220,13 +220,18 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         **kwargs: Any,
     ) -> str:
         config = config or {}
-        config_kwargs: Dict = {
-            k: config.get(k) for k in ("callbacks", "tags", "metadata")
-        }
-        result = self.generate_prompt(
-            [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
+        return (
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                **kwargs,
+            )
+            .generations[0][0]
+            .text
         )
-        return result.generations[0][0].text

     async def ainvoke(
         self,
@@ -243,11 +248,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
         config = config or {}
-        config_kwargs: Dict = {
-            k: config.get(k) for k in ("callbacks", "tags", "metadata")
-        }
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
+            [self._convert_input(input)],
+            stop=stop,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
         )
         return llm_result.generations[0][0].text

View File

@@ -108,10 +108,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
         self, input: str, config: Optional[RunnableConfig] = None
     ) -> List[Document]:
         config = config or {}
-        config_kwargs: Dict = {
-            k: config.get(k) for k in ("callbacks", "tags", "metadata")
-        }
-        return self.get_relevant_documents(input, **config_kwargs)
+        return self.get_relevant_documents(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+        )

     async def ainvoke(
         self,
@@ -124,10 +126,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             return await super().ainvoke(input, config)
         config = config or {}
-        config_kwargs: Dict = {
-            k: config.get(k) for k in ("callbacks", "tags", "metadata")
-        }
-        return await self.aget_relevant_documents(input, **config_kwargs)
+        return await self.aget_relevant_documents(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+        )

     @abstractmethod
     def _get_relevant_documents(