mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-09-23 20:17:24 +00:00
feat: Qdrant support (#1228)
* feat: Qdrant support * Update private_gpt/components/vector_store/vector_store_component.py
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from chromadb.config import Settings as ChromaSettings
|
||||
from injector import inject, singleton
|
||||
from llama_index import VectorStoreIndex
|
||||
from llama_index.indices.vector_store import VectorIndexRetriever
|
||||
@@ -10,6 +11,9 @@ from llama_index.vector_stores.types import VectorStore
|
||||
from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
|
||||
from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
||||
from private_gpt.paths import local_data_path
|
||||
from private_gpt.settings.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
@@ -36,22 +40,58 @@ class VectorStoreComponent:
|
||||
vector_store: VectorStore
|
||||
|
||||
@inject
|
||||
def __init__(self) -> None:
|
||||
chroma_settings = Settings(anonymized_telemetry=False)
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=str((local_data_path / "chroma_db").absolute()),
|
||||
settings=chroma_settings,
|
||||
)
|
||||
chroma_collection = chroma_client.get_or_create_collection(
|
||||
"make_this_parameterizable_per_api_call"
|
||||
) # TODO
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
match settings.vectorstore.database:
|
||||
case "chroma":
|
||||
chroma_settings = ChromaSettings(anonymized_telemetry=False)
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=str((local_data_path / "chroma_db").absolute()),
|
||||
settings=chroma_settings,
|
||||
)
|
||||
chroma_collection = chroma_client.get_or_create_collection(
|
||||
"make_this_parameterizable_per_api_call"
|
||||
) # TODO
|
||||
|
||||
self.vector_store = typing.cast(
|
||||
VectorStore,
|
||||
BatchedChromaVectorStore(
|
||||
chroma_client=chroma_client, chroma_collection=chroma_collection
|
||||
),
|
||||
)
|
||||
self.vector_store = typing.cast(
|
||||
VectorStore,
|
||||
BatchedChromaVectorStore(
|
||||
chroma_client=chroma_client, chroma_collection=chroma_collection
|
||||
),
|
||||
)
|
||||
|
||||
case "qdrant":
|
||||
try:
|
||||
from llama_index.vector_stores.qdrant import QdrantVectorStore
|
||||
from qdrant_client import QdrantClient # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"'qdrant_client' is not installed."
|
||||
"To use PrivateGPT with Qdrant, install the 'qdrant' extra."
|
||||
"`poetry install --extras qdrant`"
|
||||
) from e
|
||||
if settings.qdrant is None:
|
||||
logger.info(
|
||||
"Qdrant config not found. Using default settings."
|
||||
"Trying to connect to Qdrant at localhost:6333."
|
||||
)
|
||||
client = QdrantClient()
|
||||
else:
|
||||
client = QdrantClient(
|
||||
**settings.qdrant.model_dump(exclude_none=True)
|
||||
)
|
||||
self.vector_store = typing.cast(
|
||||
VectorStore,
|
||||
QdrantVectorStore(
|
||||
client=client,
|
||||
collection_name="make_this_parameterizable_per_api_call",
|
||||
), # TODO
|
||||
)
|
||||
case _:
|
||||
# Should be unreachable
|
||||
# The settings validator should have caught this
|
||||
raise ValueError(
|
||||
f"Vectorstore database {settings.vectorstore.database} not supported"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_retriever(
|
||||
|
Reference in New Issue
Block a user