langchain/libs/community/langchain_community/retrievers/llama_index.py
Erick Friis c2a3021bb0
multiple: pydantic 2 compatibility, v0.3 (#26443)
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com>
Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com>
Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com>
Co-authored-by: ZhangShenao <15201440436@163.com>
Co-authored-by: Friso H. Kingma <fhkingma@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
Co-authored-by: Morgante Pell <morgantep@google.com>
2024-09-13 14:38:45 -07:00

87 lines
3.1 KiB
Python

from typing import Any, Dict, List, cast
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
class LlamaIndexRetriever(BaseRetriever):
    """`LlamaIndex` retriever.

    It is used for the question-answering with sources over
    an LlamaIndex data structure.
    """

    index: Any = None
    """LlamaIndex index to query."""
    query_kwargs: Dict = Field(default_factory=dict)
    """Keyword arguments to pass to the query method."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: Query string passed to ``index.query`` together with
                ``query_kwargs``.
            run_manager: Callback manager for this retriever run. It is not
                used by this implementation.

        Returns:
            One ``Document`` per source node of the query response. The
            page content is the node's ``get_content()`` and the metadata
            is the node's ``metadata`` (an empty dict when that is falsy).

        Raises:
            ImportError: If the required ``llama_index`` modules cannot be
                imported.
        """
        # Imported lazily so that llama-index is only required when the
        # retriever is actually used.
        try:
            from llama_index.core.base.response.schema import Response
            from llama_index.core.indices.base import BaseGPTIndex
        except ImportError as err:
            # Chain the original error so the real import failure stays
            # visible in the traceback.
            raise ImportError(
                "You need to install `pip install llama-index` to use this retriever."
            ) from err

        # `cast` only informs the type checker; no runtime validation happens.
        index = cast(BaseGPTIndex, self.index)
        response = cast(Response, index.query(query, **self.query_kwargs))

        # Turn every source node of the response into a LangChain Document.
        return [
            Document(
                page_content=source_node.get_content(),
                metadata=source_node.metadata or {},
            )
            for source_node in response.source_nodes
        ]
class LlamaIndexGraphRetriever(BaseRetriever):
    """`LlamaIndex` graph data structure retriever.

    It is used for question-answering with sources over an LlamaIndex
    graph data structure.
    """

    graph: Any = None
    """LlamaIndex graph to query."""
    query_configs: List[Dict] = Field(default_factory=list)
    """List of query configs to pass to the query method."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: Query string passed to ``graph.query``.
            run_manager: Callback manager for this retriever run. It is not
                used by this implementation.

        Returns:
            One ``Document`` per source node of the query response. The
            page content is the node's ``get_content()`` and the metadata
            is the node's ``metadata`` (an empty dict when that is falsy).

        Raises:
            ImportError: If the required ``llama_index`` modules cannot be
                imported.
        """
        # Imported lazily so that llama-index is only required when the
        # retriever is actually used.
        try:
            from llama_index.core.base.response.schema import Response
            from llama_index.core.composability.base import (
                QUERY_CONFIG_TYPE,
                ComposableGraph,
            )
        except ImportError as err:
            # Chain the original error so the real import failure stays
            # visible in the traceback.
            raise ImportError(
                "You need to install `pip install llama-index` to use this retriever."
            ) from err

        # `cast` only informs the type checker; no runtime validation happens.
        graph = cast(ComposableGraph, self.graph)

        # For now, every query config is sent with response_mode="no_text"
        # (presumably so only source nodes are produced -- confirm against
        # the llama-index docs). The override is applied to per-call copies
        # so the dicts stored in `self.query_configs` are not modified as a
        # side effect of retrieval.
        query_configs = cast(
            List[QUERY_CONFIG_TYPE],
            [
                {**query_config, "response_mode": "no_text"}
                for query_config in self.query_configs
            ],
        )
        response = cast(Response, graph.query(query, query_configs=query_configs))

        # Turn every source node of the response into a LangChain Document.
        return [
            Document(
                page_content=source_node.get_content(),
                metadata=source_node.metadata or {},
            )
            for source_node in response.source_nodes
        ]