mirror of https://github.com/csunny/DB-GPT.git

Commit 365319a86c ("embedding"), parent 1d2083063c
1  .gitignore  vendored
@@ -6,6 +6,7 @@ __pycache__/
 # C extensions
 *.so
 
+.idea
 .vscode
 # Distribution / packaging
 .Python
pilot/__init__.py

@@ -1 +1,11 @@
-__version__ = "0.0.1"
+from pilot.source_embedding import (SourceEmbedding, register)
+from pilot.source_embedding import TextToVector
+from pilot.source_embedding import Text2Vectors
+
+
+__all__ = [
+    "SourceEmbedding",
+    "TextToVector",
+    "Text2Vectors",
+    "register"
+]
17  pilot/source_embedding/Text2Vectors.py  Normal file
@@ -0,0 +1,17 @@
from typing import List

import torch
from langchain.embeddings.base import Embeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

device = "cuda" if torch.cuda.is_available() else "cpu"


class Text2Vectors(Embeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        hfemb = HuggingFaceEmbeddings(model_name="/Users/chenketing/Desktop/project/all-MiniLM-L6-v2")
        # Delegate to the HuggingFace embedder so the Embeddings
        # interface is complete.
        return hfemb.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        hfemb = HuggingFaceEmbeddings(model_name="/Users/chenketing/Desktop/project/all-MiniLM-L6-v2")
        # embed_documents expects a list of strings; wrap the single query.
        return hfemb.embed_documents([text])[0]
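A minimal usage sketch for the class above; it assumes the hard-coded all-MiniLM-L6-v2 path resolves on the local machine (it is the committer's own path):

# Hypothetical usage; the model path inside Text2Vectors must exist locally.
emb = Text2Vectors()
query_vec = emb.embed_query("what is a vector database?")
doc_vecs = emb.embed_documents(["first doc", "second doc"])
print(len(query_vec))  # embedding dimension, 384 for all-MiniLM-L6-v2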
12  pilot/source_embedding/__init__.py  Normal file
@@ -0,0 +1,12 @@
from pilot.source_embedding.source_embedding import SourceEmbedding
from pilot.source_embedding.source_embedding import register
from pilot.source_embedding.text_to_vector import TextToVector
from pilot.source_embedding.Text2Vectors import Text2Vectors


__all__ = [
    "SourceEmbedding",
    "TextToVector",
    "Text2Vectors",
    "register"
]
14  pilot/source_embedding/chroma_test.py  Normal file
@@ -0,0 +1,14 @@
from langchain.document_loaders import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter

from pilot import TextToVector

path = "/Users/chenketing/Downloads/OceanBase-数据库-V4.1.0-OceanBase-介绍.pdf"


loader = UnstructuredFileLoader(path)
text_splitter = CharacterTextSplitter()
docs = loader.load_and_split(text_splitter)


# doc["vector"] = TextToVector.textToVector(doc["content"])[0]
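The script stops before the Chroma step its filename suggests. A minimal sketch of that step, assuming LangChain's Chroma wrapper and the Text2Vectors class from this commit:

# Hypothetical completion: index the split documents into an in-memory
# Chroma collection using this commit's Text2Vectors embeddings.
from langchain.vectorstores import Chroma

from pilot.source_embedding import Text2Vectors

db = Chroma.from_documents(docs, Text2Vectors())
hits = db.similarity_search("What is OceanBase?", k=3)
print(hits[0].page_content)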
54  pilot/source_embedding/pdf_embedding.py  Normal file
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from langchain.document_loaders import UnstructuredPDFLoader
from langchain.vectorstores import Chroma

from pilot.server.vicuna_server import embeddings
from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.text_to_vector import TextToVector
# from vector_store import ESVectorStore


class PDFEmbedding(SourceEmbedding):
    """PDF embedding: read a PDF document and index it into a vector store."""

    def __init__(self, file_path, model_name, vector_store_config):
        """Initialize with the PDF file path, embedding model name, and vector store config."""
        self.file_path = file_path
        self.model_name = model_name
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Load documents from the PDF path."""
        # loader = UnstructuredFileLoader(self.file_path)
        loader = UnstructuredPDFLoader(self.file_path, mode="elements")
        # Return the full document list so downstream steps can iterate.
        return loader.load()

    @register
    def text_to_vector(self, docs):
        """Transform each document into its vector representation."""
        for doc in docs:
            doc["vector"] = TextToVector.textToVector(doc["content"])[0]
        return docs

    @register
    def index_to_store(self, docs):
        """Index documents into the vector store."""
        # Milvus alternative:
        # vector_db = Milvus.add_texts(
        #     docs,
        #     embeddings,
        #     connection_args={"host": "127.0.0.1", "port": "19530"},
        # )
        return Chroma.from_documents(docs, embeddings)
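A sketch of how this class would be driven end to end; the file path, model name, and config values are placeholders, and source_embedding() is the pipeline runner defined in source_embedding.py below:

# Hypothetical driver; all argument values here are placeholders.
pdf_embedding = PDFEmbedding(
    file_path="/path/to/document.pdf",
    model_name="all-MiniLM-L6-v2",
    vector_store_config={"dim": 384, "description": "pdf documents"},
)
pdf_embedding.source_embedding()  # read -> process -> split -> embed -> index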
53  pilot/source_embedding/search_milvus.py  Normal file
@@ -0,0 +1,53 @@
from langchain.vectorstores import Milvus
from pymilvus import Collection, utility
from pymilvus import connections, DataType, FieldSchema, CollectionSchema

from pilot.source_embedding.Text2Vectors import Text2Vectors

# milvus = connections.connect(
#     alias="default",
#     host="localhost",
#     port="19530"
# )
# collection = Collection("book")

# Get an existing collection.
# collection.load()
#
# search_params = {"metric_type": "L2", "params": {}, "offset": 5}
#
# results = collection.search(
#     data=[[0.1, 0.2]],
#     anns_field="book_intro",
#     param=search_params,
#     limit=10,
#     expr=None,
#     output_fields=["book_id"],
#     consistency_level="Strong"
# )
#
# # get the IDs of all returned hits
# results[0].ids
#
# # get the distances to the query vector from all returned hits
# results[0].distances
#
# # get the value of an output field specified in the search request;
# # vector fields are not supported yet.
# hit = results[0][0]
# hit.entity.get("title")

milvus = connections.connect(
    alias="default",
    host="localhost",
    port="19530"
)
data = ["aaa", "bbb"]
text_embeddings = Text2Vectors()
milvus_store = Milvus(
    collection_name="document",
    embedding_function=text_embeddings,
    connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"},
    text_field="",
)

# from_texts is a classmethod: it builds and returns a new Milvus store.
milvus_store = Milvus.from_texts(
    texts=data,
    embedding=text_embeddings,
    connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"},
)
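A hedged sketch of querying the store built above through the same LangChain wrapper; it assumes the Milvus server at 127.0.0.1:19530 is running and the texts were inserted:

# Hypothetical query; similarity_search embeds the query string with
# Text2Vectors and returns the nearest stored texts.
hits = milvus_store.similarity_search("aaa", k=1)
print(hits[0].page_content)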
112  pilot/source_embedding/source_embedding.py  Normal file
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from abc import ABC, abstractmethod

from pymilvus import connections, FieldSchema, DataType, CollectionSchema

from pilot.source_embedding.text_to_vector import TextToVector

from typing import List


registered_methods = []


def register(method):
    # Record the method name at class-definition time so the pipeline
    # knows which steps are available.
    registered_methods.append(method.__name__)
    return method


class SourceEmbedding(ABC):
    """Base class for the data-source embedding pipeline:
    data read, data process, data split, data to vector, and
    indexing into a vector store. Implementations must provide read().
    """

    def __init__(self, yuque_path, model_name, vector_store_config):
        """Initialize with the source path, model_name, and vector_store_config."""
        self.yuque_path = yuque_path
        self.model_name = model_name
        self.vector_store_config = vector_store_config

    # Sub-classes should implement this method; it returns a list of
    # document objects materialized in memory.
    @abstractmethod
    @register
    def read(self) -> List[ABC]:
        """Read the data source into document objects."""

    @register
    def data_process(self, text):
        """Pre-process the data."""

    @register
    def text_split(self, text):
        """Split text into chunks."""
        pass

    @register
    def text_to_vector(self, docs):
        """Transform documents into vectors."""
        for doc in docs:
            doc["vector"] = TextToVector.textToVector(doc["content"])[0]
        return docs

    @register
    def index_to_store(self, docs):
        """Index documents into the vector store.

        Builds the Milvus collection schema; insertion is not yet implemented.
        """
        milvus = connections.connect(
            alias="default",
            host="localhost",
            port="19530"
        )
        doc_id = FieldSchema(
            name="doc_id",
            dtype=DataType.INT64,
            is_primary=True,
        )
        doc_vector = FieldSchema(
            name="doc_vector",
            dtype=DataType.FLOAT_VECTOR,
            dim=self.vector_store_config["dim"]
        )
        schema = CollectionSchema(
            fields=[doc_id, doc_vector],
            description=self.vector_store_config["description"]
        )

    def source_embedding(self):
        # Run each pipeline step that was registered via @register.
        if 'read' in registered_methods:
            text = self.read()
        if 'data_process' in registered_methods:
            self.data_process(text)
        if 'text_split' in registered_methods:
            self.text_split(text)
        if 'text_to_vector' in registered_methods:
            self.text_to_vector(text)
        if 'index_to_store' in registered_methods:
            self.index_to_store(text)
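A minimal sketch of the registration mechanism: @register records method names when the class body is executed, and source_embedding() only runs the steps whose names were recorded. The subclass and document literals here are invented for illustration, and the index step assumes a Milvus server on localhost:19530:

# Hypothetical subclass; read() is the only abstract step, the rest inherit.
class DemoEmbedding(SourceEmbedding):
    @register
    def read(self):
        return [{"content": "hello world"}]

demo = DemoEmbedding("unused_path", "all-MiniLM-L6-v2",
                     {"dim": 384, "description": "demo collection"})
demo.source_embedding()  # read -> data_process -> text_split -> text_to_vector -> index_to_store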
18  pilot/source_embedding/text_to_vector.py  Normal file
@@ -0,0 +1,18 @@
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
import torch


device = "cuda" if torch.cuda.is_available() else "cpu"


class TextToVector:

    @staticmethod
    def textToVector(text):
        hfemb = HuggingFaceEmbeddings(model_name="/Users/chenketing/Desktop/project/all-MiniLM-L6-v2")
        return hfemb.embed_documents([text])

    @staticmethod
    def textlist_to_vector(textlist):
        hfemb = HuggingFaceEmbeddings(model_name="/Users/chenketing/Desktop/project/all-MiniLM-L6-v2")
        return hfemb.embed_documents(textlist)
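A short usage note: textToVector wraps its argument in a single-element list, so callers index [0] to get the vector, while textlist_to_vector embeds a batch. A sketch, assuming the hard-coded model path resolves:

# textToVector returns a one-element list of vectors.
vec = TextToVector.textToVector("hello world")[0]
vecs = TextToVector.textlist_to_vector(["hello", "world"])
assert len(vecs) == 2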
108  pilot/source_embedding/url_embedding.py  Normal file
@@ -0,0 +1,108 @@
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Milvus
from pymilvus import connections, DataType, FieldSchema, CollectionSchema
from pymilvus import Collection

from pilot.source_embedding.text_to_vector import TextToVector


loader = WebBaseLoader([
    "https://milvus.io/docs/overview.md",
])

docs = loader.load()

# Split the documents into smaller chunks
# text_splitter = CharacterTextSplitter(chunk_size=1024, chunk_overlap=0)
# docs = text_splitter.split_documents(docs)

embeddings = TextToVector.textToVector(docs[0].page_content)

milvus = connections.connect(
    alias="default",
    host="localhost",
    port="19530"
)

# Define the collection's fields.
book_id = FieldSchema(
    name="book_id",
    dtype=DataType.INT64,
    is_primary=True,
)
book_name = FieldSchema(
    name="book_name",
    # String titles are inserted below, so this field must be VARCHAR.
    dtype=DataType.VARCHAR,
    max_length=200,
)
word_count = FieldSchema(
    name="word_count",
    dtype=DataType.INT64,
)
book_intro = FieldSchema(
    name="book_intro",
    dtype=DataType.FLOAT_VECTOR,
    dim=2
)
schema = CollectionSchema(
    fields=[book_id, book_name, word_count, book_intro],
    description="Test book search"
)

collection_name = "test_book"

collection = Collection(
    name=collection_name,
    schema=schema,
    using='default',
    shards_num=2
)

# Insert data: one column per field, in schema order
# (book_id, book_name, word_count, book_intro).
entities = [[30, 25, 40], ["test1", "test2", "test3"], [1, 2, 3], [[0.1, 0.2], [0.1, 0.2], [0.1, 0.2]]]

collection.insert(entities)
print("success")

# vector_store = Milvus.from_documents(
#     docs,
#     embedding=embeddings,
#     connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"}
# )
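A hedged follow-up sketch: in Milvus 2.x the vector field needs an index before the collection can be loaded and searched; the index parameters below are one plausible choice and vary by Milvus version:

# Hypothetical search over the 2-d book_intro vectors inserted above.
collection.create_index(
    field_name="book_intro",
    index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}},
)
collection.load()
results = collection.search(
    data=[[0.1, 0.2]],
    anns_field="book_intro",
    param={"metric_type": "L2", "params": {}},
    limit=3,
    output_fields=["book_id"],
)
print(results[0].ids)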