mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-22 01:22:34 +00:00
init doc2vector
This commit is contained in:
@@ -1,28 +1,61 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import copy
|
||||
from typing import Optional, List, Dict
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader
|
||||
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader
|
||||
from langchain.chains import VectorDBQA
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG
|
||||
|
||||
VECTOR_SEARCH_TOP_K = 5
|
||||
|
||||
class BaseKnownLedgeQA:
|
||||
|
||||
llm: object = None
|
||||
embeddings: object = None
|
||||
class KnownLedge2Vector:
|
||||
|
||||
embeddings: object = None
|
||||
model_name = LLM_MODEL_CONFIG["sentence-transforms"]
|
||||
top_k: int = VECTOR_SEARCH_TOP_K
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, model_name=None) -> None:
|
||||
if not model_name:
|
||||
# use default embedding model
|
||||
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
||||
|
||||
def init_vector_store(self):
|
||||
pass
|
||||
documents = self.load_knownlege()
|
||||
persist_dir = os.path.join(VECTORE_PATH, ".vectordb")
|
||||
if os.path.exists(persist_dir):
|
||||
# 从本地持久化文件中Load
|
||||
pass
|
||||
else:
|
||||
# 重新初始化
|
||||
vector_store = Chroma.from_documents(documents=documents,
|
||||
embedding=self.embeddings,
|
||||
persist_directory=persist_dir)
|
||||
vector_store.persist()
|
||||
|
||||
return persist_dir
|
||||
|
||||
def load_knownlege(self):
|
||||
pass
|
||||
docments = []
|
||||
for root, _, files in os.walk(DATASETS_DIR, topdown=False):
|
||||
for file in files:
|
||||
filename = os.path.join(root, file)
|
||||
print(filename)
|
||||
docs = self._load_file(filename)
|
||||
# 更新metadata数据
|
||||
new_docs = []
|
||||
for doc in docs:
|
||||
doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")}
|
||||
print("文档2向量初始化中, 请稍等...", doc.metadata)
|
||||
new_docs.append(doc)
|
||||
docments += docs
|
||||
|
||||
return docments
|
||||
|
||||
def _load_file(self, filename):
|
||||
# 加载文件
|
||||
@@ -38,4 +71,9 @@ class BaseKnownLedgeQA:
|
||||
|
||||
def _load_from_url(self, url):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
k2v = KnownLedge2Vector()
|
||||
k2v.load_knownlege()
|
||||
|
Reference in New Issue
Block a user