mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
Add async support to SelfQueryRetriever (#10175)
### Description SelfQueryRetriever is missing async support, so I am adding it. I also removed deprecated predict_and_parse method usage here, and added some tests. ### Issue N/A ### Tag maintainer Not yet ### Twitter handle N/A
This commit is contained in:
parent
35297ca0d3
commit
fd9da60aea
@ -160,4 +160,6 @@ def load_query_constructor_chain(
|
|||||||
allowed_operators=allowed_operators,
|
allowed_operators=allowed_operators,
|
||||||
enable_limit=enable_limit,
|
enable_limit=enable_limit,
|
||||||
)
|
)
|
||||||
return LLMChain(llm=llm, prompt=prompt, **kwargs)
|
return LLMChain(
|
||||||
|
llm=llm, prompt=prompt, output_parser=prompt.output_parser, **kwargs
|
||||||
|
)
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
"""Retriever that generates and executes structured queries over its own data source."""
|
"""Retriever that generates and executes structured queries over its own data source."""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, cast
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Type, cast
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
CallbackManagerForRetrieverRun,
|
||||||
|
)
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
from langchain.chains.query_constructor.base import load_query_constructor_chain
|
from langchain.chains.query_constructor.base import load_query_constructor_chain
|
||||||
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
|
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
|
||||||
@ -42,6 +45,8 @@ from langchain.vectorstores import (
|
|||||||
Weaviate,
|
Weaviate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
|
def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
|
||||||
"""Get the translator class corresponding to the vector store class."""
|
"""Get the translator class corresponding to the vector store class."""
|
||||||
@ -108,6 +113,49 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
def _get_structured_query(
|
||||||
|
self, inputs: Dict[str, Any], run_manager: CallbackManagerForRetrieverRun
|
||||||
|
) -> StructuredQuery:
|
||||||
|
structured_query = cast(
|
||||||
|
StructuredQuery,
|
||||||
|
self.llm_chain.predict(callbacks=run_manager.get_child(), **inputs),
|
||||||
|
)
|
||||||
|
return structured_query
|
||||||
|
|
||||||
|
async def _aget_structured_query(
|
||||||
|
self, inputs: Dict[str, Any], run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
|
) -> StructuredQuery:
|
||||||
|
structured_query = cast(
|
||||||
|
StructuredQuery,
|
||||||
|
await self.llm_chain.apredict(callbacks=run_manager.get_child(), **inputs),
|
||||||
|
)
|
||||||
|
return structured_query
|
||||||
|
|
||||||
|
def _prepare_query(
|
||||||
|
self, query: str, structured_query: StructuredQuery
|
||||||
|
) -> Tuple[str, Dict[str, Any]]:
|
||||||
|
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
|
||||||
|
structured_query
|
||||||
|
)
|
||||||
|
if structured_query.limit is not None:
|
||||||
|
new_kwargs["k"] = structured_query.limit
|
||||||
|
if self.use_original_query:
|
||||||
|
new_query = query
|
||||||
|
search_kwargs = {**self.search_kwargs, **new_kwargs}
|
||||||
|
return new_query, search_kwargs
|
||||||
|
|
||||||
|
def _get_docs_with_query(
|
||||||
|
self, query: str, search_kwargs: Dict[str, Any]
|
||||||
|
) -> List[Document]:
|
||||||
|
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
|
||||||
|
return docs
|
||||||
|
|
||||||
|
async def _aget_docs_with_query(
|
||||||
|
self, query: str, search_kwargs: Dict[str, Any]
|
||||||
|
) -> List[Document]:
|
||||||
|
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
|
||||||
|
return docs
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
@ -120,25 +168,30 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
List of relevant documents
|
List of relevant documents
|
||||||
"""
|
"""
|
||||||
inputs = self.llm_chain.prep_inputs({"query": query})
|
inputs = self.llm_chain.prep_inputs({"query": query})
|
||||||
structured_query = cast(
|
structured_query = self._get_structured_query(inputs, run_manager)
|
||||||
StructuredQuery,
|
|
||||||
self.llm_chain.predict_and_parse(
|
|
||||||
callbacks=run_manager.get_child(), **inputs
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(structured_query)
|
logger.info(f"Generated Query: {structured_query}")
|
||||||
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
|
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||||
structured_query
|
docs = self._get_docs_with_query(new_query, search_kwargs)
|
||||||
)
|
return docs
|
||||||
if structured_query.limit is not None:
|
|
||||||
new_kwargs["k"] = structured_query.limit
|
|
||||||
|
|
||||||
if self.use_original_query:
|
async def _aget_relevant_documents(
|
||||||
new_query = query
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Get documents relevant for a query.
|
||||||
|
|
||||||
search_kwargs = {**self.search_kwargs, **new_kwargs}
|
Args:
|
||||||
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
|
query: string to find relevant documents for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of relevant documents
|
||||||
|
"""
|
||||||
|
inputs = self.llm_chain.prep_inputs({"query": query})
|
||||||
|
structured_query = await self._aget_structured_query(inputs, run_manager)
|
||||||
|
if self.verbose:
|
||||||
|
logger.info(f"Generated Query: {structured_query}")
|
||||||
|
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||||
|
docs = await self._aget_docs_with_query(new_query, search_kwargs)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -0,0 +1,142 @@
|
|||||||
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
|
CallbackManagerForRetrieverRun,
|
||||||
|
)
|
||||||
|
from langchain.chains.query_constructor.ir import (
|
||||||
|
Comparator,
|
||||||
|
Comparison,
|
||||||
|
Operation,
|
||||||
|
Operator,
|
||||||
|
StructuredQuery,
|
||||||
|
Visitor,
|
||||||
|
)
|
||||||
|
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||||
|
from langchain.retrievers import SelfQueryRetriever
|
||||||
|
from langchain.schema import Document
|
||||||
|
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTranslator(Visitor):
|
||||||
|
allowed_comparators = (
|
||||||
|
Comparator.EQ,
|
||||||
|
Comparator.NE,
|
||||||
|
Comparator.LT,
|
||||||
|
Comparator.LTE,
|
||||||
|
Comparator.GT,
|
||||||
|
Comparator.GTE,
|
||||||
|
Comparator.CONTAIN,
|
||||||
|
Comparator.LIKE,
|
||||||
|
)
|
||||||
|
allowed_operators = (Operator.AND, Operator.OR, Operator.NOT)
|
||||||
|
|
||||||
|
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
||||||
|
self._validate_func(func)
|
||||||
|
return f"${func.value}"
|
||||||
|
|
||||||
|
def visit_operation(self, operation: Operation) -> Dict:
|
||||||
|
args = [arg.accept(self) for arg in operation.arguments]
|
||||||
|
return {self._format_func(operation.operator): args}
|
||||||
|
|
||||||
|
def visit_comparison(self, comparison: Comparison) -> Dict:
|
||||||
|
return {
|
||||||
|
comparison.attribute: {
|
||||||
|
self._format_func(comparison.comparator): comparison.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def visit_structured_query(
|
||||||
|
self, structured_query: StructuredQuery
|
||||||
|
) -> Tuple[str, dict]:
|
||||||
|
if structured_query.filter is None:
|
||||||
|
kwargs = {}
|
||||||
|
else:
|
||||||
|
kwargs = {"filter": structured_query.filter.accept(self)}
|
||||||
|
return structured_query.query, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||||
|
def similarity_search(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
|
res = self.store.get(query)
|
||||||
|
if res is None:
|
||||||
|
return []
|
||||||
|
return [res]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def fake_llm() -> FakeLLM:
|
||||||
|
return FakeLLM(
|
||||||
|
queries={
|
||||||
|
"1": """```json
|
||||||
|
{
|
||||||
|
"query": "test",
|
||||||
|
"filter": null
|
||||||
|
}
|
||||||
|
```""",
|
||||||
|
"bar": "baz",
|
||||||
|
},
|
||||||
|
sequential_responses=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def fake_vectorstore() -> InMemoryVectorstoreWithSearch:
|
||||||
|
vectorstore = InMemoryVectorstoreWithSearch()
|
||||||
|
vectorstore.add_documents(
|
||||||
|
[
|
||||||
|
Document(
|
||||||
|
page_content="test",
|
||||||
|
metadata={
|
||||||
|
"foo": "bar",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=["test"],
|
||||||
|
)
|
||||||
|
return vectorstore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def fake_self_query_retriever(
|
||||||
|
fake_llm: FakeLLM, fake_vectorstore: InMemoryVectorstoreWithSearch
|
||||||
|
) -> SelfQueryRetriever:
|
||||||
|
return SelfQueryRetriever.from_llm(
|
||||||
|
llm=fake_llm,
|
||||||
|
vectorstore=fake_vectorstore,
|
||||||
|
document_contents="test",
|
||||||
|
metadata_field_info=[
|
||||||
|
AttributeInfo(
|
||||||
|
name="foo",
|
||||||
|
type="string",
|
||||||
|
description="test",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
structured_query_translator=FakeTranslator(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test__get_relevant_documents(fake_self_query_retriever: SelfQueryRetriever) -> None:
|
||||||
|
relevant_documents = fake_self_query_retriever._get_relevant_documents(
|
||||||
|
"foo",
|
||||||
|
run_manager=CallbackManagerForRetrieverRun.get_noop_manager(),
|
||||||
|
)
|
||||||
|
assert len(relevant_documents) == 1
|
||||||
|
assert relevant_documents[0].metadata["foo"] == "bar"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test__aget_relevant_documents(
|
||||||
|
fake_self_query_retriever: SelfQueryRetriever,
|
||||||
|
) -> None:
|
||||||
|
relevant_documents = await fake_self_query_retriever._aget_relevant_documents(
|
||||||
|
"foo",
|
||||||
|
run_manager=AsyncCallbackManagerForRetrieverRun.get_noop_manager(),
|
||||||
|
)
|
||||||
|
assert len(relevant_documents) == 1
|
||||||
|
assert relevant_documents[0].metadata["foo"] == "bar"
|
Loading…
Reference in New Issue
Block a user