chore: Add pylint for storage (#1298)

Fangyin Cheng
2024-03-15 15:42:46 +08:00
committed by GitHub
parent a207640ff2
commit 8897d6e8fd
50 changed files with 784 additions and 667 deletions

View File

@@ -1,3 +1,4 @@
"""Vector Store Module."""
from typing import Any

View File

@@ -1,12 +1,13 @@
"""Vector store base class."""
import logging
import math
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Optional
from pydantic import BaseModel, Field
from typing import List, Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import Embeddings
from dbgpt.rag.chunk import Chunk
logger = logging.getLogger(__name__)
@@ -15,6 +16,11 @@ logger = logging.getLogger(__name__)
class VectorStoreConfig(BaseModel):
"""Vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
name: str = Field(
default="dbgpt_collection",
description="The name of vector store, if not set, will use the default name.",
@@ -28,7 +34,7 @@ class VectorStoreConfig(BaseModel):
description="The password of vector store, if not set, will use the default "
"password.",
)
embedding_fn: Optional[Any] = Field(
embedding_fn: Optional[Embeddings] = Field(
default=None,
description="The embedding function of vector store, if not set, will use the "
"default embedding function.",
@@ -47,27 +53,31 @@ class VectorStoreConfig(BaseModel):
class VectorStoreBase(ABC):
"""base class for vector store database"""
"""Vector store base class."""
@abstractmethod
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""load document in vector database.
"""Load document in vector database.
Args:
- chunks: document chunks.
chunks(List[Chunk]): document chunks.
Return:
- ids: chunks ids.
List[str]: chunk ids.
"""
pass
def load_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
) -> List[str]:
"""load document in vector database with limit.
"""Load document in vector database with specified limit.
Args:
chunks: document chunks.
max_chunks_once_load: Max number of chunks to load at once.
max_threads: Max number of threads to use.
chunks(List[Chunk]): Document chunks.
max_chunks_once_load(int): Max number of chunks to load at once.
max_threads(int): Max number of threads to use.
Return:
List[str]: Chunk ids.
"""
# Group the chunks into chunks of size max_chunks
chunk_groups = [
@@ -96,13 +106,15 @@ class VectorStoreBase(ABC):
return ids
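`load_document_with_limit` above slices the chunks into groups of at most `max_chunks_once_load` and loads the groups through a `ThreadPoolExecutor`. A standalone sketch of that pattern, with a hypothetical `fake_load` standing in for the real `load_document`:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def fake_load(group: List[str]) -> List[str]:
    # Hypothetical stand-in for a vector store's load_document call.
    return [f"id-{chunk}" for chunk in group]


def load_with_limit(
    chunks: List[str], max_chunks_once_load: int = 10, max_threads: int = 1
) -> List[str]:
    # Group the chunks into groups of size max_chunks_once_load.
    chunk_groups = [
        chunks[i : i + max_chunks_once_load]
        for i in range(0, len(chunks), max_chunks_once_load)
    ]
    ids: List[str] = []
    with ThreadPoolExecutor(max_workers=max_threads) as executor:
        # Submit one task per group; iterating the futures list keeps submission order.
        futures = [executor.submit(fake_load, group) for group in chunk_groups]
        for future in futures:
            ids.extend(future.result())
    return ids


print(len(load_with_limit([f"chunk-{i}" for i in range(25)], 10, 2)))  # 25
```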
@abstractmethod
def similar_search(self, text, topk) -> List[Chunk]:
"""similar search in vector database.
def similar_search(self, text: str, topk: int) -> List[Chunk]:
"""Similar search in vector database.
Args:
- text: query text
- topk: topk
text(str): The query text.
topk(int): The number of similar documents to return.
Return:
- chunks: chunks.
List[Chunk]: The similar documents.
"""
pass
@@ -110,38 +122,43 @@ class VectorStoreBase(ABC):
def similar_search_with_scores(
self, text, topk, score_threshold: float
) -> List[Chunk]:
"""similar search in vector database with scores.
"""Similar search with scores in vector database.
Args:
- text: query text
- topk: topk
- score_threshold: score_threshold: Optional, a floating point value between 0 to 1
text(str): The query text.
topk(int): The number of similar documents to return.
score_threshold(float): Optional, a floating point value between 0 and 1.
Return:
- chunks: chunks.
List[Chunk]: The similar documents.
"""
pass
@abstractmethod
def vector_name_exists(self) -> bool:
"""is vector store name exist."""
"""Whether vector name exists."""
return False
@abstractmethod
def delete_by_ids(self, ids):
"""delete vector by ids.
def delete_by_ids(self, ids: str):
"""Delete vectors by ids.
Args:
- ids: vector ids
ids(str): The ids of vectors to delete, separated by comma.
"""
@abstractmethod
def delete_vector_name(self, vector_name):
"""delete vector name.
def delete_vector_name(self, vector_name: str):
"""Delete vector by name.
Args:
- vector_name: vector store name
vector_name(str): The name of vector to delete.
"""
pass
def _normalization_vectors(self, vectors):
"""normalization vectors to scale[0,1]"""
"""Return L2-normalization vectors to scale[0,1].
Normalization vectors to scale[0,1].
"""
import numpy as np
norm = np.linalg.norm(vectors)
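The diff cuts `_normalization_vectors` off right after computing the norm; per its docstring the intent is L2 normalization. A hedged sketch of the likely remainder (the zero-norm guard is an assumption):

```python
import numpy as np


def normalize_vectors(vectors: np.ndarray) -> np.ndarray:
    # Scale the vector(s) by the global L2 norm so magnitudes fall into [0, 1].
    norm = np.linalg.norm(vectors)
    if norm == 0:
        return vectors  # assumption: leave zero vectors untouched
    return vectors / norm


print(normalize_vectors(np.array([3.0, 4.0])))  # [0.6 0.8]
```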

View File

@@ -1,14 +1,18 @@
"""Chroma vector store."""
import logging
import os
from typing import Any, List
from chromadb import PersistentClient
from chromadb.config import Settings
from pydantic import Field
from dbgpt._private.pydantic import Field
from dbgpt.configs.model_config import PILOT_PATH
# TODO: Recycle dependency on rag and storage
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
from .base import VectorStoreBase, VectorStoreConfig
logger = logging.getLogger(__name__)
@@ -16,20 +20,28 @@ logger = logging.getLogger(__name__)
class ChromaVectorConfig(VectorStoreConfig):
"""Chroma vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
persist_path: str = Field(
default=os.getenv("CHROMA_PERSIST_PATH", None),
description="The password of vector store, if not set, will use the default password.",
description="The password of vector store, if not set, will use the default "
"password.",
)
collection_metadata: dict = Field(
default=None,
description="the index metadata of vector store, if not set, will use the default metadata.",
description="the index metadata of vector store, if not set, will use the "
"default metadata.",
)
class ChromaStore(VectorStoreBase):
"""chroma database"""
"""Chroma vector store."""
def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
"""Create a ChromaStore instance."""
from langchain.vectorstores import Chroma
chroma_vector_config = vector_store_config.dict()
@@ -59,6 +71,7 @@ class ChromaStore(VectorStoreBase):
)
def similar_search(self, text, topk, **kwargs: Any) -> List[Chunk]:
"""Search similar documents."""
logger.info("ChromaStore similar search")
lc_documents = self.vector_store_client.similarity_search(text, topk, **kwargs)
return [
@@ -67,14 +80,16 @@ class ChromaStore(VectorStoreBase):
]
def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]:
"""
"""Search similar documents with scores.
Chroma similar_search_with_score.
Return docs and relevance scores in the range [0, 1].
Args:
text(str): query text
topk(int): return docs nums. Defaults to 4.
score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
score_threshold(float): Optional, a floating point value between 0 and 1
used to filter the resulting set of retrieved docs; 0 is dissimilar,
1 is most similar.
"""
logger.info("ChromaStore similar search with scores")
docs_and_scores = (
@@ -87,8 +102,8 @@ class ChromaStore(VectorStoreBase):
for doc, score in docs_and_scores
]
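The `score_threshold` argument filters on the normalized relevance scores described above. A tiny illustration with made-up scores:

```python
# Made-up (doc, score) pairs; scores follow the [0, 1] relevance convention above.
docs_and_scores = [("doc-a", 0.92), ("doc-b", 0.31), ("doc-c", 0.55)]
score_threshold = 0.5
kept = [(doc, score) for doc, score in docs_and_scores if score >= score_threshold]
print(kept)  # [('doc-a', 0.92), ('doc-c', 0.55)]
```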
def vector_name_exists(self):
"""is vector store name exist."""
def vector_name_exists(self) -> bool:
"""Whether vector name exists."""
logger.info(f"Check persist_dir: {self.persist_dir}")
if not os.path.exists(self.persist_dir):
return False
@@ -98,6 +113,7 @@ class ChromaStore(VectorStoreBase):
return len(files) > 0
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document to vector store."""
logger.info("ChromaStore load document")
texts = [chunk.content for chunk in chunks]
metadatas = [chunk.metadata for chunk in chunks]
@@ -105,14 +121,16 @@ class ChromaStore(VectorStoreBase):
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas, ids=ids)
return ids
def delete_vector_name(self, vector_name):
def delete_vector_name(self, vector_name: str):
"""Delete vector name."""
logger.info(f"chroma vector_name:{vector_name} begin delete...")
self.vector_store_client.delete_collection()
self._clean_persist_folder()
return True
def delete_by_ids(self, ids):
logger.info(f"begin delete chroma ids...")
"""Delete vector by ids."""
logger.info(f"begin delete chroma ids: {ids}")
ids = ids.split(",")
if len(ids) > 0:
collection = self.vector_store_client._collection
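As the `ids.split(",")` above shows, `delete_by_ids` expects the ids as one comma-separated string. A sketch of the calling convention (the id values and the store instance are hypothetical):

```python
chunk_ids = ["0b1d3c9e", "7f42aa10"]   # hypothetical ids returned by load_document
id_arg = ",".join(chunk_ids)           # "0b1d3c9e,7f42aa10"
# store.delete_by_ids(id_arg)          # hypothetical ChromaStore instance
print(id_arg.split(","))               # the store splits it back into a list internally
```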

View File

@@ -1,19 +1,26 @@
"""Connector for vector store."""
import os
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Type, cast
from dbgpt.rag.chunk import Chunk
from dbgpt.storage import vector_store
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
connector = {}
connector: Dict[str, Type] = {}
class VectorStoreConnector:
"""The connector for vector store.
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
2.similar_search: similarity search from vector_store
3.similar_search_with_scores: similarity search with similarity score from vector_store
VectorStoreConnector can connect to different vector databases, exposing a
load_document API and similarity-search APIs:
1. load_document: load a knowledge document source into the vector store
(Chroma, Milvus, Weaviate).
2. similar_search: similarity search from the vector store.
3. similar_search_with_scores: similarity search with similarity scores from
the vector store.
code example:
>>> from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -23,9 +30,12 @@ class VectorStoreConnector:
"""
def __init__(
self, vector_store_type: str, vector_store_config: VectorStoreConfig = None
self,
vector_store_type: str,
vector_store_config: Optional[VectorStoreConfig] = None,
) -> None:
"""initialize vector store connector.
"""Create a VectorStoreConnector instance.
Args:
- vector_store_type: vector store type Milvus, Chroma, Weaviate
- vector_store_config: vector store config params.
@@ -34,7 +44,7 @@ class VectorStoreConnector:
self._register()
if self._match(vector_store_type):
self.connector_class = connector.get(vector_store_type)
self.connector_class = connector[vector_store_type]
else:
raise Exception(f"Vector Store Type Not support: {vector_store_type}")
@@ -44,11 +54,11 @@ class VectorStoreConnector:
@classmethod
def from_default(
cls,
vector_store_type: str = None,
vector_store_type: Optional[str] = None,
embedding_fn: Optional[Any] = None,
vector_store_config: Optional[VectorStoreConfig] = None,
) -> "VectorStoreConnector":
"""initialize default vector store connector."""
"""Initialize default vector store connector."""
vector_store_type = vector_store_type or os.getenv(
"VECTOR_STORE_TYPE", "Chroma"
)
@@ -56,22 +66,33 @@ class VectorStoreConnector:
vector_store_config = vector_store_config or ChromaVectorConfig()
vector_store_config.embedding_fn = embedding_fn
return cls(vector_store_type, vector_store_config)
real_vector_store_type = cast(str, vector_store_type)
return cls(real_vector_store_type, vector_store_config)
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""load document in vector database.
"""Load document in vector database.
Args:
- chunks: document chunks.
Return chunk ids.
"""
max_chunks_once_load = (
self._vector_store_config.max_chunks_once_load
if self._vector_store_config
else 10
)
max_threads = (
self._vector_store_config.max_threads if self._vector_store_config else 1
)
return self.client.load_document_with_limit(
chunks,
self._vector_store_config.max_chunks_once_load,
self._vector_store_config.max_threads,
max_chunks_once_load,
max_threads,
)
def similar_search(self, doc: str, topk: int) -> List[Chunk]:
"""similar search in vector database.
"""Similar search in vector database.
Args:
- doc: query text
- topk: the number of similar documents to return.
@@ -83,14 +104,17 @@ class VectorStoreConnector:
def similar_search_with_scores(
self, doc: str, topk: int, score_threshold: float
) -> List[Chunk]:
"""
"""Similar search with scores in vector database.
similar_search_with_score in vector database..
Return docs and relevance scores in the range [0, 1].
Args:
- doc(str): query text
- topk(int): return docs nums. Defaults to 4.
- score_threshold(float): score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs,0 is dissimilar, 1 is most similar.
doc(str): query text
topk(int): return docs nums. Defaults to 4.
score_threshold(float): Optional, a floating point value between 0 and 1
used to filter the resulting set of retrieved docs; 0 is dissimilar,
1 is most similar.
Return:
- chunks: chunks.
"""
@@ -98,32 +122,33 @@ class VectorStoreConnector:
@property
def vector_store_config(self) -> VectorStoreConfig:
"""vector store config."""
"""Return the vector store config."""
if not self._vector_store_config:
raise ValueError("vector store config not set.")
return self._vector_store_config
def vector_name_exists(self):
"""is vector store name exist."""
"""Whether vector name exists."""
return self.client.vector_name_exists()
def delete_vector_name(self, vector_name):
"""vector store delete
def delete_vector_name(self, vector_name: str):
"""Delete vector name.
Args:
- vector_name: vector store name
"""
return self.client.delete_vector_name(vector_name)
def delete_by_ids(self, ids):
"""vector store delete by ids.
"""Delete vector by ids.
Args:
- ids: vector ids
"""
return self.client.delete_by_ids(ids=ids)
def _match(self, vector_store_type) -> bool:
if connector.get(vector_store_type):
return True
else:
return False
return bool(connector.get(vector_store_type))
def _register(self):
for cls in vector_store.__all__:
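A hedged usage sketch for the connector, in the spirit of the `code example` block in its docstring; the import paths, the "Chroma" type string, and `my_embedding_fn` are assumptions drawn from this diff, not verified against the rest of the repository:

```python
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig  # assumed module path
from dbgpt.storage.vector_store.connector import VectorStoreConnector

# Hypothetical: replace with any dbgpt.core.Embeddings implementation.
my_embedding_fn = ...

config = ChromaVectorConfig(name="demo_collection", embedding_fn=my_embedding_fn)
connector = VectorStoreConnector(vector_store_type="Chroma", vector_store_config=config)

ids = connector.load_document([Chunk(content="hello world", metadata={"source": "demo"})])
chunks = connector.similar_search("hello", topk=4)
```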

View File

@@ -1,13 +1,14 @@
"""Milvus vector store."""
from __future__ import annotations
import json
import logging
import os
from typing import Any, Iterable, List, Optional, Tuple
from typing import Any, Iterable, List, Optional
from pydantic import Field
from dbgpt.rag.chunk import Chunk, Document
from dbgpt._private.pydantic import Field
from dbgpt.core import Embeddings
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
from dbgpt.util import string_utils
@@ -17,6 +18,11 @@ logger = logging.getLogger(__name__)
class MilvusVectorConfig(VectorStoreConfig):
"""Milvus vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
uri: str = Field(
default="localhost",
description="The uri of milvus store, if not set, will use the default uri.",
@@ -28,7 +34,8 @@ class MilvusVectorConfig(VectorStoreConfig):
alias: str = Field(
default="default",
description="The alias of milvus store, if not set, will use the default alias.",
description="The alias of milvus store, if not set, will use the default "
"alias.",
)
user: str = Field(
default=None,
@@ -36,35 +43,42 @@ class MilvusVectorConfig(VectorStoreConfig):
)
password: str = Field(
default=None,
description="The password of milvus store, if not set, will use the default password.",
description="The password of milvus store, if not set, will use the default "
"password.",
)
primary_field: str = Field(
default="pk_id",
description="The primary field of milvus store, if not set, will use the default primary field.",
description="The primary field of milvus store, if not set, will use the "
"default primary field.",
)
text_field: str = Field(
default="content",
description="The text field of milvus store, if not set, will use the default text field.",
description="The text field of milvus store, if not set, will use the default "
"text field.",
)
embedding_field: str = Field(
default="vector",
description="The embedding field of milvus store, if not set, will use the default embedding field.",
description="The embedding field of milvus store, if not set, will use the "
"default embedding field.",
)
metadata_field: str = Field(
default="metadata",
description="The metadata field of milvus store, if not set, will use the default metadata field.",
description="The metadata field of milvus store, if not set, will use the "
"default metadata field.",
)
secure: str = Field(
default="",
description="The secure of milvus store, if not set, will use the default secure.",
description="The secure of milvus store, if not set, will use the default "
"secure.",
)
class MilvusStore(VectorStoreBase):
"""Milvus database"""
"""Milvus vector store."""
def __init__(self, vector_store_config: MilvusVectorConfig) -> None:
"""MilvusStore init.
"""Create a MilvusStore instance.
Args:
vector_store_config (MilvusVectorConfig): MilvusStore config.
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
@@ -93,8 +107,11 @@ class MilvusStore(VectorStoreBase):
hex_str = bytes_str.hex()
self.collection_name = hex_str
self.embedding = vector_store_config.embedding_fn
self.fields = []
if not vector_store_config.embedding_fn:
raise ValueError("embedding is required for MilvusStore")
self.embedding: Embeddings = vector_store_config.embedding_fn
self.fields: List = []
self.alias = milvus_vector_config.get("alias") or "default"
# use HNSW by default.
@@ -124,7 +141,8 @@ class MilvusStore(VectorStoreBase):
if (self.username is None) != (self.password is None):
raise ValueError(
"Both username and password must be set to use authentication for Milvus"
"Both username and password must be set to use authentication for "
"Milvus"
)
if self.username:
connect_kwargs["user"] = self.username
@@ -139,7 +157,10 @@ class MilvusStore(VectorStoreBase):
)
def init_schema_and_load(self, vector_name, documents) -> List[str]:
"""Create a Milvus collection, indexes it with HNSW, load document.
"""Create a Milvus collection.
Create a Milvus collection, indexes it with HNSW, load document.
Args:
vector_name (str): your collection name.
documents (List[str]): Text to insert.
@@ -155,7 +176,7 @@ class MilvusStore(VectorStoreBase):
connections,
utility,
)
from pymilvus.orm.types import infer_dtype_bydata
from pymilvus.orm.types import infer_dtype_bydata # noqa: F401
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
@@ -240,10 +261,10 @@ class MilvusStore(VectorStoreBase):
partition_name: Optional[str] = None,
timeout: Optional[int] = None,
) -> List[str]:
"""add text data into Milvus."""
"""Add text data into Milvus."""
insert_dict: Any = {self.text_field: list(texts)}
try:
import numpy as np
import numpy as np # noqa: F401
text_vector = self.embedding.embed_documents(list(texts))
insert_dict[self.vector_field] = text_vector
@@ -268,7 +289,7 @@ class MilvusStore(VectorStoreBase):
return res.primary_keys
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""load document in vector database."""
"""Load document in vector database."""
batch_size = 500
batched_list = [
chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)
@@ -280,6 +301,7 @@ class MilvusStore(VectorStoreBase):
return doc_ids
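`load_document` above splits the chunks into batches of at most 500 before calling `init_schema_and_load` per batch. The slicing pattern in isolation:

```python
batch_size = 500
chunks = list(range(1234))  # stand-in for a list of Chunk objects
batched_list = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
print([len(batch) for batch in batched_list])  # [500, 500, 234]
```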
def similar_search(self, text, topk) -> List[Chunk]:
"""Perform a search on a query string and return results."""
from pymilvus import Collection, DataType
"""similar_search in vector database."""
@@ -409,12 +431,14 @@ class MilvusStore(VectorStoreBase):
return ret[0], ret
def vector_name_exists(self):
"""Whether vector name exists."""
from pymilvus import utility
"""is vector store name exist."""
return utility.has_collection(self.collection_name)
def delete_vector_name(self, vector_name):
def delete_vector_name(self, vector_name: str):
"""Delete vector name."""
from pymilvus import utility
"""milvus delete collection name"""
@@ -423,11 +447,12 @@ class MilvusStore(VectorStoreBase):
return True
def delete_by_ids(self, ids):
"""Delete vector by ids."""
from pymilvus import Collection
self.col = Collection(self.collection_name)
"""milvus delete vectors by ids"""
logger.info(f"begin delete milvus ids...")
# milvus delete vectors by ids
logger.info(f"begin delete milvus ids: {ids}")
delete_ids = ids.split(",")
doc_ids = [int(doc_id) for doc_id in delete_ids]
delet_expr = f"{self.primary_field} in {doc_ids}"
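`delete_by_ids` builds a pymilvus boolean expression over the primary field from the comma-separated id string. A sketch of the expression it produces (`pk_id` is the default `primary_field` from the config above; the id values and the `collection` object are hypothetical):

```python
ids = "101,102,103"                              # comma-separated primary keys
doc_ids = [int(doc_id) for doc_id in ids.split(",")]
delete_expr = f"pk_id in {doc_ids}"
print(delete_expr)                               # pk_id in [101, 102, 103]
# collection.delete(delete_expr)                 # hypothetical pymilvus Collection instance
```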

View File

@@ -1,9 +1,9 @@
"""Postgres vector store."""
import logging
from typing import Any, List
from pydantic import Field
from dbgpt._private.config import Config
from dbgpt._private.pydantic import Field
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
@@ -15,21 +15,26 @@ CFG = Config()
class PGVectorConfig(VectorStoreConfig):
"""PG vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
connection_string: str = Field(
default=None,
description="the connection string of vector store, if not set, will use the default connection string.",
description="the connection string of vector store, if not set, will use the "
"default connection string.",
)
class PGVectorStore(VectorStoreBase):
"""`Postgres.PGVector` vector store.
"""PG vector store.
To use this, you should have the ``pgvector`` python package installed.
"""
def __init__(self, vector_store_config: PGVectorConfig) -> None:
"""init pgvector storage"""
"""Create a PGVectorStore instance."""
from langchain.vectorstores import PGVector
self.connection_string = vector_store_config.connection_string
@@ -42,23 +47,43 @@ class PGVectorStore(VectorStoreBase):
connection_string=self.connection_string,
)
def similar_search(self, text, topk, **kwargs: Any) -> None:
def similar_search(self, text: str, topk: int, **kwargs: Any) -> List[Chunk]:
"""Perform similar search in PGVector."""
return self.vector_store_client.similarity_search(text, topk)
def vector_name_exists(self):
def vector_name_exists(self) -> bool:
"""Check if vector name exists."""
try:
self.vector_store_client.create_collection()
return True
except Exception as e:
logger.error("vector_name_exists error", e.message)
logger.error(f"vector_name_exists error, {str(e)}")
return False
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document to PGVector.
Args:
chunks(List[Chunk]): document chunks.
Return:
List[str]: chunk ids.
"""
lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks]
return self.vector_store_client.from_documents(lc_documents)
def delete_vector_name(self, vector_name):
def delete_vector_name(self, vector_name: str):
"""Delete vector by name.
Args:
vector_name(str): vector name.
"""
return self.vector_store_client.delete_collection()
def delete_by_ids(self, ids):
def delete_by_ids(self, ids: str):
"""Delete vector by ids.
Args:
ids(str): vector ids, separated by comma.
"""
return self.vector_store_client.delete(ids)
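The replaced `logger.error("vector_name_exists error", e.message)` call had two problems: most exceptions have no `.message` attribute, and the second argument is treated as a %-format parameter for a message with no placeholders. A small illustration of the corrected pattern:

```python
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

try:
    raise RuntimeError("connection refused")
except Exception as e:
    # str(e) works for any exception; e.message would raise AttributeError here.
    logger.error(f"vector_name_exists error, {str(e)}")
```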

View File

@@ -1,14 +1,13 @@
"""Weaviate vector store."""
import logging
import os
from typing import List
from langchain.schema import Document
from pydantic import Field
from dbgpt._private.config import Config
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt._private.pydantic import Field
from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
from .base import VectorStoreBase, VectorStoreConfig
logger = logging.getLogger(__name__)
CFG = Config()
@@ -17,6 +16,11 @@ CFG = Config()
class WeaviateVectorConfig(VectorStoreConfig):
"""Weaviate vector store config."""
class Config:
"""Config for BaseModel."""
arbitrary_types_allowed = True
weaviate_url: str = Field(
default=os.getenv("WEAVIATE_URL", None),
description="weaviate url address, if not set, will use the default url.",
@@ -28,7 +32,7 @@ class WeaviateVectorConfig(VectorStoreConfig):
class WeaviateStore(VectorStoreBase):
"""Weaviate database"""
"""Weaviate database."""
def __init__(self, vector_store_config: WeaviateVectorConfig) -> None:
"""Initialize with Weaviate client."""
@@ -49,8 +53,8 @@ class WeaviateStore(VectorStoreBase):
self.vector_store_client = weaviate.Client(self.weaviate_url)
def similar_search(self, text: str, topk: int) -> None:
"""Perform similar search in Weaviate"""
def similar_search(self, text: str, topk: int) -> List[Chunk]:
"""Perform similar search in Weaviate."""
logger.info("Weaviate similar search")
# nearText = {
# "concepts": [text],
@@ -68,15 +72,16 @@ class WeaviateStore(VectorStoreBase):
docs = []
for r in res:
docs.append(
Document(
page_content=r["page_content"],
Chunk(
content=r["page_content"],
metadata={"metadata": r["metadata"]},
)
)
return docs
def vector_name_exists(self) -> bool:
"""Check if a vector name exists for a given class in Weaviate.
"""Whether the vector name exists in Weaviate.
Returns:
bool: True if the vector name exists, False otherwise.
"""
@@ -85,14 +90,15 @@ class WeaviateStore(VectorStoreBase):
return True
return False
except Exception as e:
logger.error("vector_name_exists error", e.message)
logger.error(f"vector_name_exists error, {str(e)}")
return False
def _default_schema(self) -> None:
"""
Create the schema for Weaviate with a Document class containing metadata and text properties.
"""
"""Create default schema in Weaviate.
Create the schema for Weaviate with a Document class containing metadata and
text properties.
"""
schema = {
"classes": [
{
@@ -137,7 +143,7 @@ class WeaviateStore(VectorStoreBase):
self.vector_store_client.schema.create(schema)
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load documents into Weaviate"""
"""Load document to Weaviate."""
logger.info("Weaviate load document")
texts = [doc.content for doc in chunks]
metadatas = [doc.metadata for doc in chunks]
@@ -157,3 +163,5 @@ class WeaviateStore(VectorStoreBase):
data_object=properties, class_name=self.vector_name
)
self.vector_store_client.batch.flush()
# TODO: return ids
return []
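One hedged way to resolve the `# TODO: return ids` above would be to generate object ids client-side and hand them to the batch call; the `uuid=` argument of `add_data_object` in the v3 weaviate client is an assumption here, so treat this as a sketch, not the project's actual fix:

```python
import uuid

texts = ["hello world"]              # stand-in for [doc.content for doc in chunks]
metadatas = [{"source": "demo"}]     # stand-in for [doc.metadata for doc in chunks]

ids = []
for doc, metadata in zip(texts, metadatas):
    object_id = str(uuid.uuid4())
    properties = {"metadata": metadata, "page_content": doc}
    # self.vector_store_client.batch.add_data_object(
    #     data_object=properties, class_name=self.vector_name, uuid=object_id
    # )
    ids.append(object_id)
print(ids)
```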