mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 13:40:54 +00:00
update config
This commit is contained in:
@@ -7,11 +7,12 @@ import nltk
|
|||||||
|
|
||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||||
VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
|
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
|
||||||
|
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
|
||||||
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
||||||
DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")
|
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||||
|
|
||||||
nltk.data.path = [os.path.join(ROOT_PATH, "pilot/nltk_data")] + nltk.data.path
|
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||||
|
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
LLM_MODEL_CONFIG = {
|
LLM_MODEL_CONFIG = {
|
||||||
|
@@ -28,16 +28,18 @@ class KnownLedge2Vector:
|
|||||||
def init_vector_store(self):
|
def init_vector_store(self):
|
||||||
documents = self.load_knownlege()
|
documents = self.load_knownlege()
|
||||||
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
|
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
|
||||||
|
print("向量数据库持久化地址: ", persist_dir)
|
||||||
if os.path.exists(persist_dir):
|
if os.path.exists(persist_dir):
|
||||||
# 从本地持久化文件中Load
|
# 从本地持久化文件中Load
|
||||||
pass
|
vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
||||||
|
vector_store.add_documents(documents=documents)
|
||||||
else:
|
else:
|
||||||
# 重新初始化
|
# 重新初始化
|
||||||
vector_store = Chroma.from_documents(documents=documents,
|
vector_store = Chroma.from_documents(documents=documents,
|
||||||
embedding=self.embeddings,
|
embedding=self.embeddings,
|
||||||
persist_directory=persist_dir)
|
persist_directory=persist_dir)
|
||||||
vector_store.persist()
|
vector_store.persist()
|
||||||
|
vector_store = None
|
||||||
return persist_dir
|
return persist_dir
|
||||||
|
|
||||||
def load_knownlege(self):
|
def load_knownlege(self):
|
||||||
@@ -45,7 +47,6 @@ class KnownLedge2Vector:
|
|||||||
for root, _, files in os.walk(DATASETS_DIR, topdown=False):
|
for root, _, files in os.walk(DATASETS_DIR, topdown=False):
|
||||||
for file in files:
|
for file in files:
|
||||||
filename = os.path.join(root, file)
|
filename = os.path.join(root, file)
|
||||||
print(filename)
|
|
||||||
docs = self._load_file(filename)
|
docs = self._load_file(filename)
|
||||||
# 更新metadata数据
|
# 更新metadata数据
|
||||||
new_docs = []
|
new_docs = []
|
||||||
@@ -53,8 +54,7 @@ class KnownLedge2Vector:
|
|||||||
doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")}
|
doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")}
|
||||||
print("文档2向量初始化中, 请稍等...", doc.metadata)
|
print("文档2向量初始化中, 请稍等...", doc.metadata)
|
||||||
new_docs.append(doc)
|
new_docs.append(doc)
|
||||||
docments += docs
|
docments += new_docs
|
||||||
|
|
||||||
return docments
|
return docments
|
||||||
|
|
||||||
def _load_file(self, filename):
|
def _load_file(self, filename):
|
||||||
@@ -75,5 +75,5 @@ class KnownLedge2Vector:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
k2v = KnownLedge2Vector()
|
k2v = KnownLedge2Vector()
|
||||||
k2v.load_knownlege()
|
k2v.init_vector_store()
|
||||||
|
|
Reference in New Issue
Block a user