diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 9e6542db9..594b8b4ae 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -150,6 +150,8 @@ class Config(metaclass=Singleton): self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) + self.WEAVIATE_URL = os.getenv("WEAVIATE_URL", "http://127.0.0.1:8080") + # QLoRA self.QLoRA = os.getenv("QUANTIZE_QLORA", "True") @@ -158,7 +160,7 @@ class Config(metaclass=Singleton): self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100)) self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5)) ### SUMMARY_CONFIG Configuration - self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR") + self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST") def set_debug_mode(self, value: bool) -> None: """Set the debug mode value""" diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index a56b2a098..f6582d343 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -54,6 +54,7 @@ class ChatNewKnowledge(BaseChat): self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE ) context = [d.page_content for d in docs] + self.metadata = [d.metadata for d in docs] context = context[:2000] input_values = {"context": context, "question": self.current_user_input} return input_values diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index 6672d3d23..8ba6df253 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -1,8 +1,9 @@ from pilot.vector_store.chroma_store import ChromaStore -# from pilot.vector_store.milvus_store import MilvusStore +from pilot.vector_store.milvus_store import MilvusStore +from pilot.vector_store.weaviate_store import WeaviateStore -connector = {"Chroma": ChromaStore, "Milvus": None} +connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore} class VectorStoreConnector: diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py index e208dde35..fc5455672 100644 --- a/pilot/vector_store/weaviate_store.py +++ b/pilot/vector_store/weaviate_store.py @@ -1,16 +1,22 @@ import os import json import weaviate +from langchain.schema import Document from langchain.vectorstores import Weaviate +from weaviate.exceptions import WeaviateBaseError + +from pilot.configs.config import Config from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.logs import logger from pilot.vector_store.vector_store_base import VectorStoreBase +CFG = Config() + class WeaviateStore(VectorStoreBase): """Weaviate database""" - def __init__(self, ctx: dict, weaviate_url: str) -> None: + def __init__(self, ctx: dict) -> None: """Initialize with Weaviate client.""" try: import weaviate @@ -21,9 +27,11 @@ class WeaviateStore(VectorStoreBase): ) self.ctx = ctx - self.weaviate_url = weaviate_url + self.weaviate_url = CFG.WEAVIATE_URL + self.embedding = ctx.get("embeddings", None) + self.vector_name = ctx["vector_store_name"] self.persist_dir = os.path.join( - KNOWLEDGE_UPLOAD_ROOT_PATH, ctx["vector_store_name"] + ".vectordb" + KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb" ) self.vector_store_client = weaviate.Client(self.weaviate_url) @@ -31,28 +39,41 @@ class WeaviateStore(VectorStoreBase): def similar_search(self, text: str, topk: int) -> None: """Perform similar search in Weaviate""" logger.info("Weaviate similar search") - nearText = { - "concepts": [text], - "distance": 0.75, # prior to v1.14 use "certainty" instead of "distance" - } + # nearText = { + # "concepts": [text], + # "distance": 0.75, # prior to v1.14 use "certainty" instead of "distance" + # } + # vector = self.embedding.embed_query(text) response = ( - self.vector_store_client.query.get("Document", ["metadata", "text"]) - .with_near_vector({"vector": nearText}) - .with_limit(topk) - .with_additional(["distance"]) - .do() + self.vector_store_client.query.get( + self.vector_name, ["metadata", "page_content"] + ) + # .with_near_vector({"vector": vector}) + .with_limit(topk).do() ) - - return json.dumps(response, indent=2) + res = response["data"]["Get"][list(response["data"]["Get"].keys())[0]] + docs = [] + for r in res: + docs.append( + Document( + page_content=r["page_content"], + metadata={"metadata": r["metadata"]}, + ) + ) + return docs def vector_name_exists(self) -> bool: """Check if a vector name exists for a given class in Weaviate. Returns: bool: True if the vector name exists, False otherwise. """ - if self.vector_store_client.schema.get("Document"): - return True - return False + try: + if self.vector_store_client.schema.get(self.vector_name): + return True + return False + except WeaviateBaseError as e: + logger.error("vector_name_exists error", e.message) + return False def _default_schema(self) -> None: """ @@ -62,39 +83,39 @@ class WeaviateStore(VectorStoreBase): schema = { "classes": [ { - "class": "Document", + "class": self.vector_name, "description": "A document with metadata and text", - "moduleConfig": { - "text2vec-transformers": { - "poolingStrategy": "masked_mean", - "vectorizeClassName": False, - } - }, + # "moduleConfig": { + # "text2vec-transformers": { + # "poolingStrategy": "masked_mean", + # "vectorizeClassName": False, + # } + # }, "properties": [ { "dataType": ["text"], - "moduleConfig": { - "text2vec-transformers": { - "skip": False, - "vectorizePropertyName": False, - } - }, + # "moduleConfig": { + # "text2vec-transformers": { + # "skip": False, + # "vectorizePropertyName": False, + # } + # }, "description": "Metadata of the document", "name": "metadata", }, { "dataType": ["text"], - "moduleConfig": { - "text2vec-transformers": { - "skip": False, - "vectorizePropertyName": False, - } - }, + # "moduleConfig": { + # "text2vec-transformers": { + # "skip": False, + # "vectorizePropertyName": False, + # } + # }, "description": "Text content of the document", - "name": "text", + "name": "page_content", }, ], - "vectorizer": "text2vec-transformers", + # "vectorizer": "text2vec-transformers", } ] } @@ -114,6 +135,12 @@ class WeaviateStore(VectorStoreBase): # Batch import all documents for i in range(len(texts)): - properties = {"metadata": metadatas[i], "text": texts[i]} + properties = { + "metadata": metadatas[i]["source"], + "page_content": texts[i], + } - self.vector_store_client.batch.add_data_object(properties, "Document") + self.vector_store_client.batch.add_data_object( + data_object=properties, class_name=self.vector_name + ) + self.vector_store_client.batch.flush() diff --git a/requirements.txt b/requirements.txt index 555592f98..594d5bfae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -59,12 +59,13 @@ nltk python-dotenv==1.0.0 # pymilvus==2.2.1 vcrpy -chromadb +chromadb=0.3.22 markdown2 colorama playsound distro pypdf +weaviate-client # Testing dependencies pytest