Files
DB-GPT/pilot/source_embedding/source_embedding.py
chenketing 365319a86c embedding
2023-05-10 20:58:35 +08:00

113 lines
3.2 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from pymilvus import connections, FieldSchema, DataType, CollectionSchema
from pilot.source_embedding.text_to_vector import TextToVector
from typing import List
# Names of the functions decorated with ``register``, in first-registration
# order.  ``SourceEmbedding.source_embedding`` consults this list to decide
# which pipeline stages to run.
registered_methods = []


def register(method):
    """Record ``method.__name__`` in ``registered_methods``.

    The function itself is returned unchanged, so ``register`` can be used as
    a plain decorator.  A name that is already present is not appended again,
    which keeps the registry free of duplicates when the same method name is
    decorated more than once.

    Args:
        method: the function (or any object with a ``__name__``) to record.

    Returns:
        The very same ``method`` object that was passed in.
    """
    name = method.__name__
    if name not in registered_methods:
        registered_methods.append(name)
    return method
class SourceEmbedding(ABC):
    """Base class for the data-source embedding pipeline.

    The pipeline stages are: read the data, pre-process it, split it into
    chunks, turn the chunks into vectors and index the vectors in the vector
    store.  Sub-classes must implement :meth:`read`; the other stages may be
    overridden as needed.
    """

    def __init__(self, yuque_path, model_name, vector_store_config):
        """Initialize with YuqueLoader url, model_name, vector_store_config.

        Args:
            yuque_path: YuqueLoader url of the data source.
            model_name: model name; only stored by this base class.
            vector_store_config: mapping read by :meth:`index_to_store`, which
                looks up the ``"dim"`` and ``"description"`` keys.
        """
        self.yuque_path = yuque_path
        self.model_name = model_name
        self.vector_store_config = vector_store_config

    # Sub-classes should implement this method, e.g. as
    # ``return list(self.lazy_load())``.
    # This method returns a List which is materialized in memory.
    @abstractmethod
    @register
    def read(self) -> List[ABC]:
        """Read the data source into document objects."""

    @register
    def data_process(self, text):
        """Pre-process the data.

        The base implementation is a no-op hook and returns ``None``.
        """

    @register
    def text_split(self, text):
        """Split the text into chunks.

        The base implementation is a no-op hook and returns ``None``.
        """

    @register
    def text_to_vector(self, docs):
        """Attach a vector to every document.

        Args:
            docs: iterable of items supporting item access; each one is read
                at ``"content"`` and written at ``"vector"``.

        Returns:
            The same ``docs`` object, with its items updated in place.
        """
        for doc in docs:
            # Only the first element of the helper's result is stored.
            doc["vector"] = TextToVector.textToVector(doc["content"])[0]
        return docs

    @register
    def index_to_store(self, docs=None):
        """Connect to the vector store and build the collection schema.

        Args:
            docs: documents handed over by :meth:`source_embedding`.  Accepted
                so the pipeline call succeeds; this base implementation does
                not use them.

        Returns:
            The ``CollectionSchema`` made of an INT64 primary key ``doc_id``
            and a FLOAT_VECTOR field ``doc_vector``.
        """
        # NOTE(review): no collection is created and nothing is inserted here;
        # the schema is returned so a caller or sub-class can finish the job.
        connections.connect(
            alias="default",
            host='localhost',
            port="19530"
        )
        doc_id = FieldSchema(
            name="doc_id",
            dtype=DataType.INT64,
            is_primary=True,
        )
        doc_vector = FieldSchema(
            name="doc_vector",
            dtype=DataType.FLOAT_VECTOR,
            dim=self.vector_store_config["dim"]
        )
        return CollectionSchema(
            fields=[doc_id, doc_vector],
            description=self.vector_store_config["description"]
        )

    def source_embedding(self):
        """Run the registered pipeline stages in order.

        ``read`` produces the initial data, which is then passed through
        ``data_process``, ``text_split`` and ``text_to_vector`` and finally
        handed to ``index_to_store``.  A stage is run only when its name is
        present in ``registered_methods``.
        """
        docs = None
        if 'read' in registered_methods:
            docs = self.read()
        for stage in ('data_process', 'text_split', 'text_to_vector'):
            if stage not in registered_methods:
                continue
            result = getattr(self, stage)(docs)
            # The base hooks return None; treat that as "data unchanged" so a
            # stage that is not overridden does not wipe out the pipeline data.
            if result is not None:
                docs = result
        if 'index_to_store' in registered_methods:
            self.index_to_store(docs)