fix: Weaviate document format.

1. similar search: docs format
2. conf: SUMMARY_CONFIG
This commit is contained in:
aries-ckt 2023-06-19 16:44:18 +08:00
parent b95084b89f
commit 05a74d89cd
3 changed files with 19 additions and 6 deletions

View File

@ -161,7 +161,7 @@ class Config(metaclass=Singleton):
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR")
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""

View File

@ -53,7 +53,8 @@ class ChatNewKnowledge(BaseChat):
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
)
context = [d["page_content"] for d in docs]
context = [d.page_content for d in docs]
self.metadata = [d.metadata for d in docs]
context = context[:2000]
input_values = {"context": context, "question": self.current_user_input}
return input_values

View File

@ -1,7 +1,9 @@
import os
import json
import weaviate
from langchain.schema import Document
from langchain.vectorstores import Weaviate
from weaviate.exceptions import WeaviateBaseError
from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
@ -48,7 +50,13 @@ class WeaviateStore(VectorStoreBase):
.with_limit(topk)
.do()
)
docs = response['data']['Get'][list(response['data']['Get'].keys())[0]]
res = response['data']['Get'][list(response['data']['Get'].keys())[0]]
docs = []
for r in res:
docs.append(Document(
page_content=r['page_content'],
metadata={"metadata": r['metadata']},
))
return docs
def vector_name_exists(self) -> bool:
@ -56,9 +64,13 @@ class WeaviateStore(VectorStoreBase):
Returns:
bool: True if the vector name exists, False otherwise.
"""
if self.vector_store_client.schema.get(self.vector_name):
return True
return False
try:
if self.vector_store_client.schema.get(self.vector_name):
return True
return False
except WeaviateBaseError as e:
logger.error("vector_name_exists error", e.message)
return False
def _default_schema(self) -> None:
"""