diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index 8cd915d90..71a310bc3 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -1,16 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import json -import os +from typing import List -from bs4 import BeautifulSoup -from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader -from langchain.vectorstores import Milvus, Chroma -from pymilvus import connections - -from pilot.server.vicuna_server import embeddings -from pilot.source_embedding.text_to_vector import TextToVector -# from vector_store import ESVectorStore +from langchain.document_loaders import PyPDFLoader +from langchain.schema import Document from pilot.source_embedding import SourceEmbedding, register @@ -19,7 +12,7 @@ class PDFEmbedding(SourceEmbedding): """yuque embedding for read yuque document.""" def __init__(self, file_path, model_name, vector_store_config): - """Initialize with YuqueLoader url.""" + """Initialize with pdf path.""" self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config @@ -27,28 +20,16 @@ class PDFEmbedding(SourceEmbedding): @register def read(self): """Load from pdf path.""" - docs = [] - # loader = UnstructuredFileLoader(self.file_path) - loader = UnstructuredPDFLoader(self.file_path, mode="elements") - return loader.load()[0] + loader = PyPDFLoader(self.file_path) + return loader.load() @register - def text_to_vector(self, docs): - """Load from yuque url.""" - for doc in docs: - doc["vector"] = TextToVector.textToVector(doc["content"])[0] - return docs + def data_process(self, documents: List[Document]): + i = 0 + for d in documents: + documents[i].page_content = d.page_content.replace(" ", "").replace("\n", "") + i += 1 + return documents - @register - def index_to_store(self, docs): - """index into vector store.""" - # vector_db = Milvus.add_texts( - # docs, - # embeddings, - # connection_args={"host": "127.0.0.1", "port": "19530"}, - # ) - db = Chroma.from_documents(docs, embeddings) - - return Chroma.from_documents(docs, embeddings) diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index 05e8de338..ec66e302c 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -1,11 +1,10 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - +import os from abc import ABC, abstractmethod -from pymilvus import connections, FieldSchema, DataType, CollectionSchema - -from pilot.source_embedding.text_to_vector import TextToVector +from langchain.embeddings import HuggingFaceEmbeddings +from langchain.vectorstores import Chroma from typing import List @@ -30,9 +29,6 @@ class SourceEmbedding(ABC): self.model_name = model_name self.vector_store_config = vector_store_config - # Sub-classes should implement this method - # as return list(self.lazy_load()). - # This method returns a List which is materialized in memory. @abstractmethod @register def read(self) -> List[ABC]: @@ -49,61 +45,23 @@ class SourceEmbedding(ABC): @register def text_to_vector(self, docs): """transform vector""" - for doc in docs: - doc["vector"] = TextToVector.textToVector(doc["content"])[0] - return docs + pass @register - def index_to_store(self): + def index_to_store(self, docs): """index to vector store""" - milvus = connections.connect( - alias="default", - host='localhost', - port="19530" - ) - doc_id = FieldSchema( - name="doc_id", - dtype=DataType.INT64, - is_primary=True, - ) - doc_vector = FieldSchema( - name="doc_vector", - dtype=DataType.FLOAT_VECTOR, - dim=self.vector_store_config["dim"] - ) - schema = CollectionSchema( - fields=[doc_id, doc_vector], - description=self.vector_store_config["description"] - ) + embeddings = HuggingFaceEmbeddings(model_name=self.model_name) - @register - def index_to_store(self): - """index to vector store""" - milvus = connections.connect( - alias="default", - host='localhost', - port="19530" - ) - doc_id = FieldSchema( - name="doc_id", - dtype=DataType.INT64, - is_primary=True, - ) - doc_vector = FieldSchema( - name="doc_vector", - dtype=DataType.FLOAT_VECTOR, - dim=self.vector_store_config["dim"] - ) - schema = CollectionSchema( - fields=[doc_id, doc_vector], - description=self.vector_store_config["description"] - ) + persist_dir = os.path.join(self.vector_store_config["vector_store_path"], + self.vector_store_config["vector_store_name"] + ".vectordb") + vector_store = Chroma.from_documents(docs, embeddings, persist_directory=persist_dir) + vector_store.persist() def source_embedding(self): if 'read' in registered_methods: text = self.read() - if 'process' in registered_methods: - self.process(text) + if 'data_process' in registered_methods: + text = self.data_process(text) if 'text_split' in registered_methods: self.text_split(text) if 'text_to_vector' in registered_methods: