mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-30 05:49:25 +00:00
79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
|
|
from langchain.embeddings import HuggingFaceEmbeddings
|
|
from langchain.vectorstores import Chroma
|
|
|
|
from typing import List, Optional, Dict
|
|
|
|
registered_methods = []
|
|
|
|
|
|
def register(method):
|
|
registered_methods.append(method.__name__)
|
|
return method
|
|
|
|
|
|
class SourceEmbedding(ABC):
|
|
"""base class for read data source embedding pipeline.
|
|
include data read, data process, data split, data to vector, data index vector store
|
|
Implementations should implement the method
|
|
"""
|
|
|
|
def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
|
|
"""Initialize with Loader url, model_name, vector_store_config"""
|
|
self.file_path = file_path
|
|
self.model_name = model_name
|
|
self.vector_store_config = vector_store_config
|
|
self.embedding_args = embedding_args
|
|
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
|
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
|
self.vector_store_config["vector_store_name"] + ".vectordb")
|
|
self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
|
|
|
@abstractmethod
|
|
@register
|
|
def read(self) -> List[ABC]:
|
|
"""read datasource into document objects."""
|
|
@register
|
|
def data_process(self, text):
|
|
"""pre process data."""
|
|
|
|
@register
|
|
def text_split(self, text):
|
|
"""text split chunk"""
|
|
pass
|
|
|
|
@register
|
|
def text_to_vector(self, docs):
|
|
"""transform vector"""
|
|
pass
|
|
|
|
@register
|
|
def index_to_store(self, docs):
|
|
"""index to vector store"""
|
|
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
|
|
self.vector_store_config["vector_store_name"] + ".vectordb")
|
|
self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
|
|
self.vector_store.persist()
|
|
|
|
@register
|
|
def similar_search(self, doc, topk):
|
|
"""vector store similarity_search"""
|
|
|
|
return self.vector_store_client.similarity_search(doc, topk)
|
|
|
|
def source_embedding(self):
|
|
if 'read' in registered_methods:
|
|
text = self.read()
|
|
if 'data_process' in registered_methods:
|
|
text = self.data_process(text)
|
|
if 'text_split' in registered_methods:
|
|
self.text_split(text)
|
|
if 'text_to_vector' in registered_methods:
|
|
self.text_to_vector(text)
|
|
if 'index_to_store' in registered_methods:
|
|
self.index_to_store(text)
|