From a7274f006e9885bbf56fd50d8f66ebb4dcf9807e Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 26 Mar 2024 21:57:13 +0100 Subject: [PATCH] langchain[patch]: Add async methods to VectorstoreIndexCreator (#19582) --- .../langchain/indexes/vectorstore.py | 52 ++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/indexes/vectorstore.py b/libs/langchain/langchain/indexes/vectorstore.py index 5889fcb311b..b70cf33f0f7 100644 --- a/libs/langchain/langchain/indexes/vectorstore.py +++ b/libs/langchain/langchain/indexes/vectorstore.py @@ -43,7 +43,22 @@ class VectorStoreIndexWrapper(BaseModel): chain = RetrievalQA.from_chain_type( llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs ) - return chain.run(question) + return chain.invoke({chain.input_key: question})[chain.output_key] + + async def aquery( + self, + question: str, + llm: Optional[BaseLanguageModel] = None, + retriever_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> str: + """Query the vectorstore.""" + llm = llm or OpenAI(temperature=0) + retriever_kwargs = retriever_kwargs or {} + chain = RetrievalQA.from_chain_type( + llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs + ) + return (await chain.ainvoke({chain.input_key: question}))[chain.output_key] def query_with_sources( self, @@ -58,7 +73,22 @@ class VectorStoreIndexWrapper(BaseModel): chain = RetrievalQAWithSourcesChain.from_chain_type( llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs ) - return chain({chain.question_key: question}) + return chain.invoke({chain.question_key: question}) + + async def aquery_with_sources( + self, + question: str, + llm: Optional[BaseLanguageModel] = None, + retriever_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> dict: + """Query the vectorstore and get back sources.""" + llm = llm or OpenAI(temperature=0) + retriever_kwargs = retriever_kwargs or {} + chain = RetrievalQAWithSourcesChain.from_chain_type( + llm, retriever=self.vectorstore.as_retriever(**retriever_kwargs), **kwargs + ) + return await chain.ainvoke({chain.question_key: question}) class VectorstoreIndexCreator(BaseModel): @@ -82,6 +112,14 @@ class VectorstoreIndexCreator(BaseModel): docs.extend(loader.load()) return self.from_documents(docs) + async def afrom_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper: + """Create a vectorstore index from loaders.""" + docs = [] + for loader in loaders: + async for doc in loader.alazy_load(): + docs.append(doc) + return await self.afrom_documents(docs) + def from_documents(self, documents: List[Document]) -> VectorStoreIndexWrapper: """Create a vectorstore index from documents.""" sub_docs = self.text_splitter.split_documents(documents) @@ -89,3 +127,13 @@ class VectorstoreIndexCreator(BaseModel): sub_docs, self.embedding, **self.vectorstore_kwargs ) return VectorStoreIndexWrapper(vectorstore=vectorstore) + + async def afrom_documents( + self, documents: List[Document] + ) -> VectorStoreIndexWrapper: + """Create a vectorstore index from documents.""" + sub_docs = self.text_splitter.split_documents(documents) + vectorstore = await self.vectorstore_cls.afrom_documents( + sub_docs, self.embedding, **self.vectorstore_kwargs + ) + return VectorStoreIndexWrapper(vectorstore=vectorstore)