mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-28 04:51:29 +00:00
refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com> Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -1,18 +1,69 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import math
|
||||
from typing import Optional, Callable, List, Any
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
"""Vector store config."""
|
||||
|
||||
name: str = Field(
|
||||
default="dbgpt_collection",
|
||||
description="The name of vector store, if not set, will use the default name.",
|
||||
)
|
||||
user: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The user of vector store, if not set, will use the default user.",
|
||||
)
|
||||
password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password of vector store, if not set, will use the default password.",
|
||||
)
|
||||
embedding_fn: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="The embedding function of vector store, if not set, will use the default embedding function.",
|
||||
)
|
||||
|
||||
|
||||
class VectorStoreBase(ABC):
|
||||
"""base class for vector store database"""
|
||||
|
||||
@abstractmethod
|
||||
def load_document(self, documents) -> None:
|
||||
"""load document in vector database."""
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""load document in vector database.
|
||||
Args:
|
||||
- chunks: document chunks.
|
||||
Return:
|
||||
- ids: chunks ids.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def similar_search(self, text, topk) -> None:
|
||||
"""similar search in vector database."""
|
||||
def similar_search(self, text, topk) -> List[Chunk]:
|
||||
"""similar search in vector database.
|
||||
Args:
|
||||
- text: query text
|
||||
- topk: topk
|
||||
Return:
|
||||
- chunks: chunks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def similar_search_with_scores(
|
||||
self, text, topk, score_threshold: float
|
||||
) -> List[Chunk]:
|
||||
"""similar search in vector database with scores.
|
||||
Args:
|
||||
- text: query text
|
||||
- topk: topk
|
||||
- score_threshold: score_threshold: Optional, a floating point value between 0 to 1
|
||||
Return:
|
||||
- chunks: chunks.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -22,12 +73,17 @@ class VectorStoreBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids):
|
||||
"""delete vector by ids."""
|
||||
pass
|
||||
"""delete vector by ids.
|
||||
Args:
|
||||
- ids: vector ids
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_vector_name(self, vector_name):
|
||||
"""delete vector name."""
|
||||
"""delete vector name.
|
||||
Args:
|
||||
- vector_name: vector store name
|
||||
"""
|
||||
pass
|
||||
|
||||
def _normalization_vectors(self, vectors):
|
||||
|
@@ -1,30 +1,45 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
from chromadb.config import Settings
|
||||
from chromadb import PersistentClient
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
from pydantic import Field
|
||||
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
from dbgpt.configs.model_config import PILOT_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChromaVectorConfig(VectorStoreConfig):
|
||||
"""Chroma vector store config."""
|
||||
|
||||
persist_path: str = Field(
|
||||
default=os.getenv("CHROMA_PERSIST_PATH", None),
|
||||
description="The password of vector store, if not set, will use the default password.",
|
||||
)
|
||||
collection_metadata: dict = Field(
|
||||
default=None,
|
||||
description="the index metadata of vector store, if not set, will use the default metadata.",
|
||||
)
|
||||
|
||||
|
||||
class ChromaStore(VectorStoreBase):
|
||||
"""chroma database"""
|
||||
|
||||
def __init__(self, ctx: {}) -> None:
|
||||
def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
self.ctx = ctx
|
||||
chroma_path = ctx.get(
|
||||
"CHROMA_PERSIST_PATH",
|
||||
os.path.join(PILOT_PATH, "data"),
|
||||
chroma_vector_config = vector_store_config.dict()
|
||||
chroma_path = chroma_vector_config.get(
|
||||
"persist_path", os.path.join(PILOT_PATH, "data")
|
||||
)
|
||||
self.persist_dir = os.path.join(
|
||||
chroma_path, ctx["vector_store_name"] + ".vectordb"
|
||||
chroma_path, vector_store_config.name + ".vectordb"
|
||||
)
|
||||
self.embeddings = ctx.get("embeddings", None)
|
||||
self.embeddings = vector_store_config.embedding_fn
|
||||
chroma_settings = Settings(
|
||||
# chroma_db_impl="duckdb+parquet", => deprecated configuration of Chroma
|
||||
persist_directory=self.persist_dir,
|
||||
@@ -32,7 +47,9 @@ class ChromaStore(VectorStoreBase):
|
||||
)
|
||||
client = PersistentClient(path=self.persist_dir, settings=chroma_settings)
|
||||
|
||||
collection_metadata = {"hnsw:space": "cosine"}
|
||||
collection_metadata = chroma_vector_config.get("collection_metadata") or {
|
||||
"hnsw:space": "cosine"
|
||||
}
|
||||
self.vector_store_client = Chroma(
|
||||
persist_directory=self.persist_dir,
|
||||
embedding_function=self.embeddings,
|
||||
@@ -41,11 +58,15 @@ class ChromaStore(VectorStoreBase):
|
||||
collection_metadata=collection_metadata,
|
||||
)
|
||||
|
||||
def similar_search(self, text, topk, **kwargs: Any) -> None:
|
||||
def similar_search(self, text, topk, **kwargs: Any) -> List[Chunk]:
|
||||
logger.info("ChromaStore similar search")
|
||||
return self.vector_store_client.similarity_search(text, topk, **kwargs)
|
||||
lc_documents = self.vector_store_client.similarity_search(text, topk, **kwargs)
|
||||
return [
|
||||
Chunk(content=doc.page_content, metadata=doc.metadata)
|
||||
for doc in lc_documents
|
||||
]
|
||||
|
||||
def similar_search_with_scores(self, text, topk, score_threshold) -> None:
|
||||
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
|
||||
"""
|
||||
Chroma similar_search_with_score.
|
||||
Return docs and relevance scores in the range [0, 1].
|
||||
@@ -55,15 +76,19 @@ class ChromaStore(VectorStoreBase):
|
||||
score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
|
||||
"""
|
||||
logger.info("ChromaStore similar search")
|
||||
logger.info("ChromaStore similar search with scores")
|
||||
docs_and_scores = (
|
||||
self.vector_store_client.similarity_search_with_relevance_scores(
|
||||
query=text, k=topk, score_threshold=score_threshold
|
||||
)
|
||||
)
|
||||
return docs_and_scores
|
||||
return [
|
||||
Chunk(content=doc.page_content, metadata=doc.metadata, score=score)
|
||||
for doc, score in docs_and_scores
|
||||
]
|
||||
|
||||
def vector_name_exists(self):
|
||||
"""is vector store name exist."""
|
||||
logger.info(f"Check persist_dir: {self.persist_dir}")
|
||||
if not os.path.exists(self.persist_dir):
|
||||
return False
|
||||
@@ -72,11 +97,12 @@ class ChromaStore(VectorStoreBase):
|
||||
files = list(filter(lambda f: f != "chroma.sqlite3", files))
|
||||
return len(files) > 0
|
||||
|
||||
def load_document(self, documents):
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
logger.info("ChromaStore load document")
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
|
||||
texts = [chunk.content for chunk in chunks]
|
||||
metadatas = [chunk.metadata for chunk in chunks]
|
||||
ids = [chunk.chunk_id for chunk in chunks]
|
||||
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
||||
return ids
|
||||
|
||||
def delete_vector_name(self, vector_name):
|
||||
|
@@ -1,54 +1,94 @@
|
||||
import os
|
||||
from typing import Optional, List, Callable, Any
|
||||
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage import vector_store
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
connector = {}
|
||||
|
||||
|
||||
class VectorStoreConnector:
|
||||
|
||||
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
|
||||
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
|
||||
2.similar_search: similarity search from vector_store
|
||||
3.similar_search_with_scores: similarity search with similarity score from vector_store
|
||||
|
||||
code example:
|
||||
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
>>> vector_store_config = VectorStoreConfig
|
||||
>>> vector_store_connector = VectorStoreConnector(vector_store_type="Chroma")
|
||||
"""
|
||||
|
||||
def __init__(self, vector_store_type, ctx: {}) -> None:
|
||||
def __init__(
|
||||
self, vector_store_type: str, vector_store_config: VectorStoreConfig = None
|
||||
) -> None:
|
||||
"""initialize vector store connector.
|
||||
Args:
|
||||
- vector_store_type: vector store type Milvus, Chroma, Weaviate
|
||||
- ctx: vector store config params.
|
||||
"""
|
||||
self.ctx = ctx
|
||||
self._vector_store_config = vector_store_config
|
||||
self._register()
|
||||
|
||||
if self._match(vector_store_type):
|
||||
self.connector_class = connector.get(vector_store_type)
|
||||
else:
|
||||
raise Exception(f"Vector Type Not support. {0}", vector_store_type)
|
||||
raise Exception(f"Vector Store Type Not support. {0}", vector_store_type)
|
||||
|
||||
print(self.connector_class)
|
||||
self.client = self.connector_class(ctx)
|
||||
self.client = self.connector_class(vector_store_config)
|
||||
|
||||
def load_document(self, docs):
|
||||
"""load document in vector database."""
|
||||
return self.client.load_document(docs)
|
||||
@classmethod
|
||||
def from_default(
|
||||
cls,
|
||||
vector_store_type: str = None,
|
||||
embedding_fn: Optional[Any] = None,
|
||||
vector_store_config: Optional[VectorStoreConfig] = None,
|
||||
) -> "VectorStoreConnector":
|
||||
"""initialize default vector store connector."""
|
||||
vector_store_type = vector_store_type or os.getenv(
|
||||
"VECTOR_STORE_TYPE", "Chroma"
|
||||
)
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
|
||||
def similar_search(self, doc: str, topk: int):
|
||||
vector_store_config = vector_store_config or ChromaVectorConfig()
|
||||
vector_store_config.embedding_fn = embedding_fn
|
||||
return cls(vector_store_type, vector_store_config)
|
||||
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""load document in vector database.
|
||||
Args:
|
||||
- chunks: document chunks.
|
||||
Return chunk ids.
|
||||
"""
|
||||
return self.client.load_document(chunks)
|
||||
|
||||
def similar_search(self, doc: str, topk: int) -> List[Chunk]:
|
||||
"""similar search in vector database.
|
||||
Args:
|
||||
- doc: query text
|
||||
- topk: topk
|
||||
Return:
|
||||
- chunks: chunks.
|
||||
"""
|
||||
return self.client.similar_search(doc, topk)
|
||||
|
||||
def similar_search_with_scores(self, doc: str, topk: int, score_threshold: float):
|
||||
def similar_search_with_scores(
|
||||
self, doc: str, topk: int, score_threshold: float
|
||||
) -> List[Chunk]:
|
||||
"""
|
||||
similar_search_with_score in vector database..
|
||||
Return docs and relevance scores in the range [0, 1].
|
||||
Args:
|
||||
doc(str): query text
|
||||
topk(int): return docs nums. Defaults to 4.
|
||||
score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
- doc(str): query text
|
||||
- topk(int): return docs nums. Defaults to 4.
|
||||
- score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
|
||||
Return:
|
||||
- chunks: chunks.
|
||||
"""
|
||||
return self.client.similar_search_with_scores(doc, topk, score_threshold)
|
||||
|
||||
|
@@ -5,49 +5,111 @@ import logging
|
||||
import os
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
from dbgpt.rag.chunk import Chunk, Document
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
from dbgpt.util import string_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusVectorConfig(VectorStoreConfig):
|
||||
"""Milvus vector store config."""
|
||||
|
||||
uri: str = Field(
|
||||
default="localhost",
|
||||
description="The uri of milvus store, if not set, will use the default uri.",
|
||||
)
|
||||
port: str = Field(
|
||||
default="19530",
|
||||
description="The port of milvus store, if not set, will use the default port.",
|
||||
)
|
||||
|
||||
alias: str = Field(
|
||||
default="default",
|
||||
description="The alias of milvus store, if not set, will use the default alias.",
|
||||
)
|
||||
user: str = Field(
|
||||
default=None,
|
||||
description="The user of milvus store, if not set, will use the default user.",
|
||||
)
|
||||
password: str = Field(
|
||||
default=None,
|
||||
description="The password of milvus store, if not set, will use the default password.",
|
||||
)
|
||||
primary_field: str = Field(
|
||||
default="pk_id",
|
||||
description="The primary field of milvus store, if not set, will use the default primary field.",
|
||||
)
|
||||
text_field: str = Field(
|
||||
default="content",
|
||||
description="The text field of milvus store, if not set, will use the default text field.",
|
||||
)
|
||||
embedding_field: str = Field(
|
||||
default="vector",
|
||||
description="The embedding field of milvus store, if not set, will use the default embedding field.",
|
||||
)
|
||||
metadata_field: str = Field(
|
||||
default="metadata",
|
||||
description="The metadata field of milvus store, if not set, will use the default metadata field.",
|
||||
)
|
||||
secure: str = Field(
|
||||
default="",
|
||||
description="The secure of milvus store, if not set, will use the default secure.",
|
||||
)
|
||||
|
||||
|
||||
class MilvusStore(VectorStoreBase):
|
||||
"""Milvus database"""
|
||||
|
||||
def __init__(self, ctx: {}) -> None:
|
||||
"""MilvusStore init."""
|
||||
def __init__(self, vector_store_config: MilvusVectorConfig) -> None:
|
||||
"""MilvusStore init.
|
||||
Args:
|
||||
vector_store_config (MilvusVectorConfig): MilvusStore config.
|
||||
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
|
||||
"""
|
||||
from pymilvus import connections
|
||||
|
||||
"""init a milvus storage connection.
|
||||
|
||||
Args:
|
||||
ctx ({}): MilvusStore global config.
|
||||
"""
|
||||
# self.configure(cfg)
|
||||
|
||||
connect_kwargs = {}
|
||||
self.uri = ctx.get("MILVUS_URL", os.getenv("MILVUS_URL"))
|
||||
self.port = ctx.get("MILVUS_PORT", os.getenv("MILVUS_PORT"))
|
||||
self.username = ctx.get("MILVUS_USERNAME", os.getenv("MILVUS_USERNAME"))
|
||||
self.password = ctx.get("MILVUS_PASSWORD", os.getenv("MILVUS_PASSWORD"))
|
||||
self.secure = ctx.get("MILVUS_SECURE", os.getenv("MILVUS_SECURE"))
|
||||
self.collection_name = ctx.get("vector_store_name", None)
|
||||
self.embedding = ctx.get("embeddings", None)
|
||||
milvus_vector_config = vector_store_config.dict()
|
||||
self.uri = milvus_vector_config.get("uri") or os.getenv(
|
||||
"MILVUS_URL", "localhost"
|
||||
)
|
||||
self.port = milvus_vector_config.get("post") or os.getenv(
|
||||
"MILVUS_PORT", "19530"
|
||||
)
|
||||
self.username = milvus_vector_config.get("user") or os.getenv("MILVUS_USER")
|
||||
self.password = milvus_vector_config.get("password") or os.getenv(
|
||||
"MILVUS_PASSWORD"
|
||||
)
|
||||
self.secure = milvus_vector_config.get("secure") or os.getenv("MILVUS_SECURE")
|
||||
|
||||
self.collection_name = (
|
||||
milvus_vector_config.get("name") or vector_store_config.name
|
||||
)
|
||||
if string_utils.is_all_chinese(self.collection_name):
|
||||
bytes_str = self.collection_name.encode("utf-8")
|
||||
hex_str = bytes_str.hex()
|
||||
self.collection_name = hex_str
|
||||
|
||||
self.embedding = vector_store_config.embedding_fn
|
||||
self.fields = []
|
||||
self.alias = "default"
|
||||
self.alias = milvus_vector_config.get("alias") or "default"
|
||||
|
||||
# use HNSW by default.
|
||||
self.index_params = {
|
||||
"metric_type": "L2",
|
||||
"index_type": "HNSW",
|
||||
"metric_type": "COSINE",
|
||||
"params": {"M": 8, "efConstruction": 64},
|
||||
}
|
||||
|
||||
# use HNSW by default.
|
||||
self.index_params_map = {
|
||||
"IVF_FLAT": {"params": {"nprobe": 10}},
|
||||
"IVF_SQ8": {"params": {"nprobe": 10}},
|
||||
"IVF_PQ": {"params": {"nprobe": 10}},
|
||||
"HNSW": {"params": {"ef": 10}},
|
||||
"HNSW": {"params": {"M": 8, "efConstruction": 64}},
|
||||
"RHNSW_FLAT": {"params": {"ef": 10}},
|
||||
"RHNSW_SQ": {"params": {"ef": 10}},
|
||||
"RHNSW_PQ": {"params": {"ef": 10}},
|
||||
@@ -55,10 +117,10 @@ class MilvusStore(VectorStoreBase):
|
||||
"ANNOY": {"params": {"search_k": 10}},
|
||||
}
|
||||
# default collection schema
|
||||
self.primary_field = "pk_id"
|
||||
self.vector_field = "vector"
|
||||
self.text_field = "content"
|
||||
self.metadata_field = "metadata"
|
||||
self.primary_field = milvus_vector_config.get("primary_field") or "pk_id"
|
||||
self.vector_field = milvus_vector_config.get("embedding_field") or "vector"
|
||||
self.text_field = milvus_vector_config.get("text_field") or "content"
|
||||
self.metadata_field = milvus_vector_config.get("metadata_field") or "metadata"
|
||||
|
||||
if (self.username is None) != (self.password is None):
|
||||
raise ValueError(
|
||||
@@ -75,13 +137,13 @@ class MilvusStore(VectorStoreBase):
|
||||
# secure=self.secure,
|
||||
)
|
||||
|
||||
def init_schema_and_load(self, vector_name, documents):
|
||||
def init_schema_and_load(self, vector_name, documents) -> List[str]:
|
||||
"""Create a Milvus collection, indexes it with HNSW, load document.
|
||||
Args:
|
||||
vector_name (Embeddings): your collection name.
|
||||
documents (List[str]): Text to insert.
|
||||
Returns:
|
||||
VectorStore: The MilvusStore vector store.
|
||||
List[str]: document ids.
|
||||
"""
|
||||
try:
|
||||
from pymilvus import (
|
||||
@@ -105,7 +167,7 @@ class MilvusStore(VectorStoreBase):
|
||||
alias="default"
|
||||
# secure=self.secure,
|
||||
)
|
||||
texts = [d.page_content for d in documents]
|
||||
texts = [d.content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
embeddings = self.embedding.embed_query(texts[0])
|
||||
|
||||
@@ -183,7 +245,7 @@ class MilvusStore(VectorStoreBase):
|
||||
import numpy as np
|
||||
|
||||
text_vector = self.embedding.embed_documents(list(texts))
|
||||
insert_dict[self.vector_field] = self._normalization_vectors(text_vector)
|
||||
insert_dict[self.vector_field] = text_vector
|
||||
except NotImplementedError:
|
||||
insert_dict[self.vector_field] = [
|
||||
self.embedding.embed_query(x) for x in texts
|
||||
@@ -204,12 +266,11 @@ class MilvusStore(VectorStoreBase):
|
||||
self.col.flush()
|
||||
return res.primary_keys
|
||||
|
||||
def load_document(self, documents) -> None:
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""load document in vector database."""
|
||||
# self.init_schema_and_load(self.collection_name, documents)
|
||||
batch_size = 500
|
||||
batched_list = [
|
||||
documents[i : i + batch_size] for i in range(0, len(documents), batch_size)
|
||||
chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)
|
||||
]
|
||||
doc_ids = []
|
||||
for doc_batch in batched_list:
|
||||
@@ -217,7 +278,7 @@ class MilvusStore(VectorStoreBase):
|
||||
doc_ids = [str(doc_id) for doc_id in doc_ids]
|
||||
return doc_ids
|
||||
|
||||
def similar_search(self, text, topk):
|
||||
def similar_search(self, text, topk) -> List[Chunk]:
|
||||
from pymilvus import Collection, DataType
|
||||
|
||||
"""similar_search in vector database."""
|
||||
@@ -232,17 +293,16 @@ class MilvusStore(VectorStoreBase):
|
||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||
self.vector_field = x.name
|
||||
_, docs_and_scores = self._search(text, topk)
|
||||
from langchain.schema import Document
|
||||
|
||||
return [
|
||||
Document(
|
||||
Chunk(
|
||||
metadata=json.loads(doc.metadata.get("metadata", "")),
|
||||
page_content=doc.page_content,
|
||||
content=doc.content,
|
||||
)
|
||||
for doc, _, _ in docs_and_scores
|
||||
]
|
||||
|
||||
def similar_search_with_scores(self, text, topk, score_threshold):
|
||||
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
|
||||
"""Perform a search on a query string and return results with score.
|
||||
|
||||
For more information about the search parameters, take a look at the pymilvus
|
||||
@@ -286,7 +346,12 @@ class MilvusStore(VectorStoreBase):
|
||||
|
||||
if score_threshold is not None:
|
||||
docs_and_scores = [
|
||||
(doc, score)
|
||||
Chunk(
|
||||
metadata=doc.metadata,
|
||||
content=doc.content,
|
||||
score=score,
|
||||
chunk_id=id,
|
||||
)
|
||||
for doc, score, id in docs_and_scores
|
||||
if score >= score_threshold
|
||||
]
|
||||
@@ -308,22 +373,19 @@ class MilvusStore(VectorStoreBase):
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
self.col.load()
|
||||
# use default index params.
|
||||
if param is None:
|
||||
index_type = self.col.indexes[0].params["index_type"]
|
||||
param = self.index_params_map[index_type]
|
||||
param = self.index_params_map[index_type].get("params")
|
||||
# query text embedding.
|
||||
query_vector = self.embedding.embed_query(query)
|
||||
data = [self._normalization_vectors(query_vector)]
|
||||
# Determine result metadata fields.
|
||||
output_fields = self.fields[:]
|
||||
output_fields.remove(self.vector_field)
|
||||
# milvus search.
|
||||
res = self.col.search(
|
||||
data,
|
||||
[query_vector],
|
||||
self.vector_field,
|
||||
param,
|
||||
k,
|
||||
@@ -339,13 +401,13 @@ class MilvusStore(VectorStoreBase):
|
||||
meta = {x: result.entity.get(x) for x in output_fields}
|
||||
ret.append(
|
||||
(
|
||||
Document(page_content=meta.pop(self.text_field), metadata=meta),
|
||||
self._default_relevance_score_fn(result.distance),
|
||||
Chunk(content=meta.pop(self.text_field), metadata=meta),
|
||||
result.distance,
|
||||
result.id,
|
||||
)
|
||||
)
|
||||
|
||||
return data[0], ret
|
||||
return ret[0], ret
|
||||
|
||||
def vector_name_exists(self):
|
||||
from pymilvus import utility
|
||||
|
@@ -1,6 +1,10 @@
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
import logging
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -8,21 +12,29 @@ logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class PGVectorConfig(VectorStoreConfig):
|
||||
"""PG vector store config."""
|
||||
|
||||
connection_string: str = Field(
|
||||
default=None,
|
||||
description="the connection string of vector store, if not set, will use the default connection string.",
|
||||
)
|
||||
|
||||
|
||||
class PGVectorStore(VectorStoreBase):
|
||||
"""`Postgres.PGVector` vector store.
|
||||
|
||||
To use this, you should have the ``pgvector`` python package installed.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: dict) -> None:
|
||||
def __init__(self, vector_store_config: PGVectorConfig) -> None:
|
||||
"""init pgvector storage"""
|
||||
|
||||
from langchain.vectorstores import PGVector
|
||||
|
||||
self.ctx = ctx
|
||||
self.connection_string = ctx.get("connection_string", None)
|
||||
self.embeddings = ctx.get("embeddings", None)
|
||||
self.collection_name = ctx.get("vector_store_name", None)
|
||||
self.connection_string = vector_store_config.connection_string
|
||||
self.embeddings = vector_store_config.embedding_fn
|
||||
self.collection_name = vector_store_config.name
|
||||
|
||||
self.vector_store_client = PGVector(
|
||||
embedding_function=self.embeddings,
|
||||
@@ -41,8 +53,9 @@ class PGVectorStore(VectorStoreBase):
|
||||
logger.error("vector_name_exists error", e.message)
|
||||
return False
|
||||
|
||||
def load_document(self, documents) -> None:
|
||||
return self.vector_store_client.from_documents(documents)
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks]
|
||||
return self.vector_store_client.from_documents(lc_documents)
|
||||
|
||||
def delete_vector_name(self, vector_name):
|
||||
return self.vector_store_client.delete_collection()
|
||||
|
@@ -1,19 +1,36 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import Document
|
||||
from pydantic import Field
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class WeaviateVectorConfig(VectorStoreConfig):
|
||||
"""Weaviate vector store config."""
|
||||
|
||||
weaviate_url: str = Field(
|
||||
default=os.getenv("WEAVIATE_URL", None),
|
||||
description="weaviate url address, if not set, will use the default url.",
|
||||
)
|
||||
persist_path: str = Field(
|
||||
default=os.getenv("WEAVIATE_PERSIST_PATH", None),
|
||||
description="weaviate persist path.",
|
||||
)
|
||||
|
||||
|
||||
class WeaviateStore(VectorStoreBase):
|
||||
"""Weaviate database"""
|
||||
|
||||
def __init__(self, ctx: dict) -> None:
|
||||
def __init__(self, vector_store_config: WeaviateVectorConfig) -> None:
|
||||
"""Initialize with Weaviate client."""
|
||||
try:
|
||||
import weaviate
|
||||
@@ -23,12 +40,11 @@ class WeaviateStore(VectorStoreBase):
|
||||
"Please install it with `pip install weaviate-client`."
|
||||
)
|
||||
|
||||
self.ctx = ctx
|
||||
self.weaviate_url = ctx.get("WEAVIATE_URL", os.getenv("WEAVIATE_URL"))
|
||||
self.embedding = ctx.get("embeddings", None)
|
||||
self.vector_name = ctx["vector_store_name"]
|
||||
self.weaviate_url = vector_store_config.weaviate_url
|
||||
self.embedding = vector_store_config.embedding_fn
|
||||
self.vector_name = vector_store_config.name
|
||||
self.persist_dir = os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb"
|
||||
vector_store_config.persist_path, vector_store_config.name + ".vectordb"
|
||||
)
|
||||
|
||||
self.vector_store_client = weaviate.Client(self.weaviate_url)
|
||||
@@ -120,11 +136,11 @@ class WeaviateStore(VectorStoreBase):
|
||||
# Create the schema in Weaviate
|
||||
self.vector_store_client.schema.create(schema)
|
||||
|
||||
def load_document(self, documents: list) -> None:
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Load documents into Weaviate"""
|
||||
logger.info("Weaviate load document")
|
||||
texts = [doc.page_content for doc in documents]
|
||||
metadatas = [doc.metadata for doc in documents]
|
||||
texts = [doc.content for doc in chunks]
|
||||
metadatas = [doc.metadata for doc in chunks]
|
||||
|
||||
# Import data
|
||||
with self.vector_store_client.batch as batch:
|
||||
@@ -134,7 +150,7 @@ class WeaviateStore(VectorStoreBase):
|
||||
for i in range(len(texts)):
|
||||
properties = {
|
||||
"metadata": metadatas[i]["source"],
|
||||
"page_content": texts[i],
|
||||
"content": texts[i],
|
||||
}
|
||||
|
||||
self.vector_store_client.batch.add_data_object(
|
||||
|
Reference in New Issue
Block a user