diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 9e6542db9..1c4a52c35 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -150,6 +150,9 @@ class Config(metaclass=Singleton): self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) + self.WEAVIATE_URL = os.getenv("WEAVIATE_URL", "http://127.0.0.1:8080") + + # QLoRA self.QLoRA = os.getenv("QUANTIZE_QLORA", "True") diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index a56b2a098..f9f27f603 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -53,7 +53,7 @@ class ChatNewKnowledge(BaseChat): docs = self.knowledge_embedding_client.similar_search( self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE ) - context = [d.page_content for d in docs] + context = [d["page_content"] for d in docs] context = context[:2000] input_values = {"context": context, "question": self.current_user_input} return input_values diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index 6672d3d23..8ba6df253 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -1,8 +1,9 @@ from pilot.vector_store.chroma_store import ChromaStore -# from pilot.vector_store.milvus_store import MilvusStore +from pilot.vector_store.milvus_store import MilvusStore +from pilot.vector_store.weaviate_store import WeaviateStore -connector = {"Chroma": ChromaStore, "Milvus": None} +connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore} class VectorStoreConnector: diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py index e208dde35..f4e35bcfd 100644 --- a/pilot/vector_store/weaviate_store.py +++ b/pilot/vector_store/weaviate_store.py @@ -2,15 +2,19 @@ import os import json import weaviate from langchain.vectorstores import Weaviate + +from 
pilot.configs.config import Config from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.logs import logger from pilot.vector_store.vector_store_base import VectorStoreBase +CFG = Config() + class WeaviateStore(VectorStoreBase): """Weaviate database""" - def __init__(self, ctx: dict, weaviate_url: str) -> None: + def __init__(self, ctx: dict) -> None: """Initialize with Weaviate client.""" try: import weaviate @@ -21,9 +25,11 @@ class WeaviateStore(VectorStoreBase): ) self.ctx = ctx - self.weaviate_url = weaviate_url + self.weaviate_url = CFG.WEAVIATE_URL + self.embedding = ctx.get("embeddings", None) + self.vector_name = ctx["vector_store_name"] self.persist_dir = os.path.join( - KNOWLEDGE_UPLOAD_ROOT_PATH, ctx["vector_store_name"] + ".vectordb" + KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb" ) self.vector_store_client = weaviate.Client(self.weaviate_url) @@ -31,26 +37,26 @@ class WeaviateStore(VectorStoreBase): def similar_search(self, text: str, topk: int) -> None: """Perform similar search in Weaviate""" logger.info("Weaviate similar search") - nearText = { - "concepts": [text], - "distance": 0.75, # prior to v1.14 use "certainty" instead of "distance" - } + # nearText = { + # "concepts": [text], + # "distance": 0.75, # prior to v1.14 use "certainty" instead of "distance" + # } + # vector = self.embedding.embed_query(text) response = ( - self.vector_store_client.query.get("Document", ["metadata", "text"]) - .with_near_vector({"vector": nearText}) + self.vector_store_client.query.get(self.vector_name, ["metadata", "page_content"]) + # .with_near_vector({"vector": vector}) .with_limit(topk) - .with_additional(["distance"]) .do() ) - - return json.dumps(response, indent=2) + docs = response['data']['Get'][list(response['data']['Get'].keys())[0]] + return docs def vector_name_exists(self) -> bool: """Check if a vector name exists for a given class in Weaviate. Returns: bool: True if the vector name exists, False otherwise. 
""" - if self.vector_store_client.schema.get("Document"): + if self.vector_store_client.schema.get(self.vector_name): return True return False @@ -62,39 +68,39 @@ class WeaviateStore(VectorStoreBase): schema = { "classes": [ { - "class": "Document", + "class": self.vector_name, "description": "A document with metadata and text", - "moduleConfig": { - "text2vec-transformers": { - "poolingStrategy": "masked_mean", - "vectorizeClassName": False, - } - }, + # "moduleConfig": { + # "text2vec-transformers": { + # "poolingStrategy": "masked_mean", + # "vectorizeClassName": False, + # } + # }, "properties": [ { "dataType": ["text"], - "moduleConfig": { - "text2vec-transformers": { - "skip": False, - "vectorizePropertyName": False, - } - }, + # "moduleConfig": { + # "text2vec-transformers": { + # "skip": False, + # "vectorizePropertyName": False, + # } + # }, "description": "Metadata of the document", "name": "metadata", }, { "dataType": ["text"], - "moduleConfig": { - "text2vec-transformers": { - "skip": False, - "vectorizePropertyName": False, - } - }, + # "moduleConfig": { + # "text2vec-transformers": { + # "skip": False, + # "vectorizePropertyName": False, + # } + # }, "description": "Text content of the document", - "name": "text", - }, + "name": "page_content", + } ], - "vectorizer": "text2vec-transformers", + # "vectorizer": "text2vec-transformers", } ] } @@ -114,6 +120,7 @@ class WeaviateStore(VectorStoreBase): # Batch import all documents for i in range(len(texts)): - properties = {"metadata": metadatas[i], "text": texts[i]} + properties = {"metadata": metadatas[i]['source'], "page_content": texts[i]} - self.vector_store_client.batch.add_data_object(properties, "Document") + self.vector_store_client.batch.add_data_object(data_object=properties, class_name=self.vector_name) + self.vector_store_client.batch.flush() diff --git a/requirements.txt b/requirements.txt index 555592f98..594d5bfae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -59,12 +59,13 @@ nltk 
python-dotenv==1.0.0 # pymilvus==2.2.1 vcrpy -chromadb +chromadb==0.3.22 markdown2 colorama playsound distro pypdf +weaviate-client # Testing dependencies pytest