# diff --git a/pilot/connections/rdbms/py_study/test_cls_1.py b/pilot/connections/rdbms/py_study/test_cls_1.py new file mode 100644 index 000000000..66c07de78 --- /dev/null +++ b/pilot/connections/rdbms/py_study/test_cls_1.py @@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase


class Test1(TestBase):
    """Study model that fills the inherited ``test_values`` with strings."""

    def write(self):
        """Append the sample strings "x", "y" and "g" to ``test_values``."""
        self.test_values.extend(["x", "y", "g"])


# diff --git a/pilot/connections/rdbms/py_study/test_cls_2.py b/pilot/connections/rdbms/py_study/test_cls_2.py new file mode 100644 index 000000000..c0fdbb305 --- /dev/null +++ b/pilot/connections/rdbms/py_study/test_cls_2.py @@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union


class Test2(TestBase):
    """Study model that owns a second list next to the inherited one."""

    test_2_values: List = []

    def write(self):
        """Append 1, 2, 3 to ``test_values`` and "x", "y", "z" to ``test_2_values``."""
        self.test_values.extend([1, 2, 3])
        self.test_2_values.extend(["x", "y", "z"])


# diff --git a/pilot/connections/rdbms/py_study/test_cls_base.py b/pilot/connections/rdbms/py_study/test_cls_base.py new file mode 100644 index 000000000..9a04a48b3 --- /dev/null +++ b/pilot/connections/rdbms/py_study/test_cls_base.py @@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union


class TestBase(BaseModel, ABC):
    """Base study model holding a list that subclasses fill in ``write``."""

    test_values: List = []

    def test(self):
        """Print the concrete class name followed by the collected values."""
        print(f"{self.__class__.__name__}:")
        print(self.test_values)


# diff --git a/pilot/prompts/example_base.py b/pilot/prompts/example_base.py new file mode 100644 index 000000000..4d876aa51 --- /dev/null +++ b/pilot/prompts/example_base.py @@ -0,0 +1,38 @@
# from abc import ABC,  (import statement continues on the next source line)
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

from pilot.common.schema import ExampleType


class ExampleSelector(BaseModel, ABC):
    """Hold prompt examples and hand out one or several of them.

    The stored records are populated through the ``examples`` constructor
    keyword (a pydantic alias), e.g. ``ExampleSelector(examples=[...])``.
    They live in ``examples_record`` because the public ``examples()``
    method would otherwise replace a same-named field in the class body,
    and ``self.examples[:count]`` would then slice the method itself.
    """

    # Backing storage for the records; populated via the "examples" keyword.
    examples_record: List[List] = Field(..., alias="examples")
    use_example: bool = False
    type: str = ExampleType.ONE_SHOT.value

    def examples(self, count: int = 2):
        """Return the examples to use for the configured ``type``.

        Args:
            count: number of examples for the few-shot case (ignored for
                one-shot).

        Returns:
            A list with the selected records, or ``None`` when
            ``use_example`` is false.
        """
        if ExampleType.ONE_SHOT.value == self.type:
            return self.__one_show_context()
        else:
            return self.__few_shot_context(count)

    def __few_shot_context(self, count: int = 2) -> Optional[List[List]]:
        """
        Use 2 or more examples, default 2
        Returns: the first ``count`` records, or None when examples are disabled
        """
        if self.use_example:
            return self.examples_record[:count]
        return None

    def __one_show_context(self) -> Optional[List]:
        """
        Use one example
        Returns: a list holding the first record, or None when examples are disabled
        """
        if self.use_example:
            return self.examples_record[:1]
        return None


# diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py new file mode 100644 index 000000000..6cd71b39c --- /dev/null +++ b/pilot/scene/chat_execution/example.py @@ -0,0 +1,9 @@
from pilot.prompts.example_base import ExampleSelector

## Two examples are defined by default
EXAMPLES = [
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
]

example = ExampleSelector(examples=EXAMPLES, use_example=True)

# diff --git a/pilot/scene/chat_execution/prompt_v2.py b/pilot/scene/chat_execution/prompt_v2.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/pilot/scene/chat_execution/prompt_v2.py @@ -0,0 +1 @@
# (prompt_v2.py is added with a single blank line)

# diff --git a/pilot/server/api_v1/api_v1.py b/pilot/server/api_v1/api_v1.py new file mode 100644 index 000000000..19f4e765c --- /dev/null +++ b/pilot/server/api_v1/api_v1.py @@ -0,0 +1,146 @@
import uuid

from fastapi import APIRouter, Request, Body, status

from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
# from  (import statement continues on the next source line)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from typing import List

from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR)
from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage)

router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")


async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Flatten request-validation errors into one failure ``Result``.

    Each error becomes ``<loc parts joined by '.'>:<msg>;``. Location parts
    are passed through ``str`` because they can be list indexes (ints),
    which ``str.join`` rejects.
    """
    message = "".join(
        ".".join(str(part) for part in error.get("loc", ()))
        + ":"
        + str(error.get("msg"))
        + ";"
        for error in exc.errors()
    )
    return Result.faild(message)


# ``Result.data`` is declared as ``List[T]``, so the item type is the parameter.
@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
async def dialogue_list(user_id: str):
    """Return the conversations of ``user_id`` (placeholder data for now)."""
    #### TODO

    conversations = [ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
                     ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]")]

    return Result[ConversationVo].succ(conversations)


@router.post('/v1/chat/dialogue/new', response_model=Result[str])
async def dialogue_new(user_id: str):
    """Create a new conversation id and return it as the single data item."""
    unique_id = uuid.uuid1()
    # The payload is a list of strings (``Result[str]``); a bare UUID object
    # matches neither the list shape nor the item type.
    return Result.succ([str(unique_id)])


@router.post('/v1/chat/dialogue/delete')
async def dialogue_delete(con_uid: str, user_id: str):
    """Delete a conversation (not implemented yet; always reports success)."""
    #### TODO
    return Result.succ(None)


@router.post('/v1/chat/completions', response_model=Result[MessageVo])
async def chat_completions(dialogue: ConversationVo = Body()):
    """Run one chat turn for the scene named by ``dialogue.chat_mode``.

    Returns a failure ``Result`` for an unsupported mode, a streaming
    response when the scene's prompt template streams, and a plain
    ``Result`` otherwise.
    """
    logger.info(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")

    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        # Report the problem to the client; raising StopAsyncIteration inside
        # a coroutine does not produce an API error response.
        return Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")

    chat_param = {
        "chat_session_id": dialogue.conv_uid,
        "user_input": dialogue.user_input,
    }

    # ``dict.update("key", value)`` raises TypeError; assign the key directly.
    # NOTE(review): these compare ChatScene members with the incoming string;
    # confirm ChatScene members compare equal to their mode names.
    if ChatScene.ChatWithDbExecute == dialogue.chat_mode:
        chat_param["db_name"] = dialogue.select_param
    elif ChatScene.ChatWithDbQA == dialogue.chat_mode:
        chat_param["db_name"] = dialogue.select_param
    elif ChatScene.ChatExecution == dialogue.chat_mode:
        chat_param["plugin_selector"] = dialogue.select_param
    elif ChatScene.ChatNewKnowledge == dialogue.chat_mode:
        chat_param["knowledge_name"] = dialogue.select_param
    elif ChatScene.ChatUrlKnowledge == dialogue.chat_mode:
        chat_param["url"] = dialogue.select_param

    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
    if not chat.prompt_template.stream_out:
        return non_stream_response(chat)
    else:
        return stream_response(chat)


def stream_generator(chat):
    """Yield one JSON-encoded ``Result`` per non-empty model chunk.

    Chunks are split on NUL bytes. Every parsed chunk is added to the
    current message, and the full message list is sent again each time.
    """
    model_response = chat.stream_call()
    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
        chat.current_message.add_ai_message(msg)
        message_vos = [message2Vo(element) for element in chat.current_message.messages]
        # A streaming body must be made of str/bytes, so serialise the model.
        yield Result.succ(message_vos).json()


def stream_response(chat):
    """Wrap ``stream_generator`` in a JSON ``StreamingResponse``."""
    logger.info("stream out start!")
    api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
    return api_response


def message2Vo(message: BaseMessage) -> MessageVo:
    """Convert a chat message into its view object.

    ``time_stamp`` falls back to 0 when the message carries none.
    """
    # assumes additional_kwargs is a dict (the original indexed it by key) — TODO confirm
    extras = message.additional_kwargs or {}
    return MessageVo(
        role=message.type,
        context=message.content,
        time_stamp=extras.get("time_stamp") or 0,
    )


def non_stream_response(chat):
    """Run the model call to completion and return all messages at once."""
    logger.info("not stream out, wait model response!")
    chat.nostream_call()
    message_vos = [message2Vo(element) for element in chat.current_message.messages]
    return Result.succ(message_vos)


@router.get('/v1/db/types', response_model=Result[str])
async def db_types():
    """List the supported database types."""
    return Result.succ(["mysql", "duckdb"])


@router.get('/v1/db/list', response_model=Result[str])
async def db_list():
    """List the databases reported by the configured local connection."""
    db = CFG.local_db
    dbs = db.get_database_list()
    return Result.succ(dbs)


@router.get('/v1/knowledge/list')
async def knowledge_list():
    """Placeholder endpoint; returns fixed sample names."""
    return ["test1", "test2"]


@router.post('/v1/knowledge/add')
async def knowledge_add():
    """Placeholder endpoint; returns fixed sample names."""
    return ["test1", "test2"]


@router.post('/v1/knowledge/delete')
async def knowledge_delete():
    """Placeholder endpoint; returns fixed sample names."""
    return ["test1", "test2"]


@router.get('/v1/knowledge/types')
async def knowledge_types():
    """Placeholder endpoint; returns fixed sample names."""
    return ["test1", "test2"]


@router.get('/v1/knowledge/detail')
async def knowledge_detail():
    """Placeholder endpoint; returns fixed sample names."""
    return ["test1", "test2"]


# diff --git a/pilot/server/api_v1/api_view_model.py b/pilot/server/api_v1/api_view_model.py new file mode 100644 index 000000000..938ce22ec --- /dev/null +++ b/pilot/server/api_v1/api_view_model.py @@ -0,0 +1,57 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Optional

T = TypeVar('T')


class Result(Generic[T], BaseModel):
    """Uniform API envelope: a success flag, optional error info, and data."""

    success: bool
    # Error details are absent on success, and data is absent on failure,
    # so all three must accept None.
    err_code: Optional[str] = None
    err_msg: Optional[str] = None
    data: Optional[List[T]] = None

    @classmethod
    def succ(cls, data: Optional[List[T]]):
        """Build a successful result carrying ``data``."""
        # pydantic models only accept keyword arguments.
        return Result(success=True, err_code=None, err_msg=None, data=data)

    @classmethod
    def faild(cls, code, msg=None):
        """Build a failed result.

        Accepts ``faild(msg)`` (error code defaults to ``"E000X"``) as well
        as ``faild(code, msg)``. The two original definitions shared one
        name, so only the two-argument form survived.
        """
        if msg is None:
            code, msg = "E000X", code
        return Result(success=False, err_code=code, err_msg=msg, data=None)


class ConversationVo(BaseModel):
    """Request body describing one chat turn."""

    # dialogue uid
    conv_uid: str = Field(..., description="dialogue uid")
    # user input
    user_input: str
    # the scene of chat
    chat_mode: str = Field(..., description="the scene of chat ")
    # chat scene select param
    select_param: str


class MessageVo(BaseModel):
    """One chat message as returned to the client."""

    # role that sends out the current message
    role: str
    # current message
    context: str
    # time the current message was sent
    time_stamp: float


# diff --git  (diff header continues on the next source line)
# a/pilot/server/webserver_base.py b/pilot/server/webserver_base.py new file mode 100644 index 000000000..0aa2ac3f9 --- /dev/null +++ b/pilot/server/webserver_base.py @@ -0,0 +1,60 @@
import signal
import os
import threading
import traceback
import sys

from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
from pilot.configs.model_config import (
    DATASETS_DIR,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LLM_MODEL_CONFIG,
    LOGDIR,
)
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

logger = build_logger("webserver", LOGDIR + "webserver.log")


def signal_handler(sig, frame):
    """Exit the process immediately, skipping ``atexit`` handlers."""
    print("in order to avoid chroma db atexit problem")
    os._exit(0)


def async_db_summery():
    """Start ``DBSummaryClient.init_db_summary`` on a background thread."""
    summary_client = DBSummaryClient()
    threading.Thread(target=summary_client.init_db_summary).start()


def server_init(args):
    """Prepare config, plugins, signal handling and the command registry.

    Args:
        args: start-up arguments; only logged here.
    """
    logger.info(f"args: {args}")

    # init config
    cfg = Config()

    load_native_plugins(cfg)
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()
    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

    # Built-in command modules, minus the ones disabled in the config.
    builtin_categories = (
        "pilot.commands.built_in.audio_text",
        "pilot.commands.built_in.image_gen",
    )
    enabled_categories = []
    for category in builtin_categories:
        if category in cfg.disabled_command_categories:
            continue
        enabled_categories.append(category)

    registry = CommandRegistry()
    for category in enabled_categories:
        registry.import_commands(category)

    cfg.command_registry = registry


# diff --git a/pilot/source_embedding/EncodeTextLoader.py b/pilot/source_embedding/EncodeTextLoader.py new file mode 100644 index 000000000..2b7344f18 --- /dev/null +++ b/pilot/source_embedding/EncodeTextLoader.py @@ -0,0 +1,26 @@
# from  (import statement continues on the next source line)
from typing import List, Optional
import chardet

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader


class EncodeTextLoader(BaseLoader):
    """Load a text file, detecting its encoding when none is given."""

    def __init__(self, file_path: str, encoding: Optional[str] = None):
        """Initialize with file path.

        Args:
            file_path: path of the text file to read.
            encoding: explicit encoding; when ``None`` it is detected with
                ``chardet``.
        """
        self.file_path = file_path
        self.encoding = encoding

    def load(self) -> List[Document]:
        """Read the file and return it as a single ``Document``.

        The bytes are decoded with the explicit ``encoding`` when one was
        supplied (it used to be ignored), otherwise with the encoding
        ``chardet`` detects, falling back to UTF-8 when detection fails.
        """
        with open(self.file_path, "rb") as f:
            raw_text = f.read()
        encoding = self.encoding
        if encoding is None:
            encoding = chardet.detect(raw_text)["encoding"] or "utf-8"
        text = raw_text.decode(encoding)
        metadata = {"source": self.file_path}
        return [Document(page_content=text, metadata=metadata)]


# diff --git a/pilot/source_embedding/__init__.py b/pilot/source_embedding/__init__.py new file mode 100644 index 000000000..464ff11b1 --- /dev/null +++ b/pilot/source_embedding/__init__.py @@ -0,0 +1,3 @@
from pilot.source_embedding.source_embedding import SourceEmbedding, register

__all__ = ["SourceEmbedding", "register"]

# diff --git a/pilot/source_embedding/chn_document_splitter.py b/pilot/source_embedding/chn_document_splitter.py new file mode 100644 index 000000000..5bf06ea8c --- /dev/null +++ b/pilot/source_embedding/chn_document_splitter.py @@ -0,0 +1,55 @@
import re
from typing import List

from langchain.text_splitter import CharacterTextSplitter


class CHNDocumentSplitter(CharacterTextSplitter):
    """Sentence splitter for Chinese-style punctuation.

    Text is first cut after sentence-ending punctuation. Any piece longer
    than ``sentence_size`` is then cut again on commas, on runs of
    whitespace, and finally on single spaces.
    """

    def __init__(self, pdf: bool = False, sentence_size: int = None, **kwargs):
        """
        Args:
            pdf: when true, whitespace is normalised before splitting.
            sentence_size: maximum piece length before finer splitting is
                tried; ``None`` disables the finer splitting.
            **kwargs: forwarded to ``CharacterTextSplitter``.
        """
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into sentence-like pieces (empty pieces dropped)."""
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            # Raw string: "\s" in a plain literal is an invalid escape.
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)

        text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text)
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text)
        text = text.rstrip()
        ls = [i for i in text.split("\n") if i]
        if self.sentence_size is None:
            # No size limit configured: comparing a length with None would
            # raise TypeError, so skip the finer splitting.
            return ls
        # Each loop walks the list object it started with; the rebinding
        # inside only changes what is spliced back afterwards.
        for ele in ls:
            if len(ele) > self.sentence_size:
                ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        ele_ele2 = re.sub(
                            r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r"\1\n\2", ele_ele1
                        )
                        ele2_ls = ele_ele2.split("\n")
                        for piece in ele2_ls:
                            if len(piece) > self.sentence_size:
                                ele_ele3 = re.sub(
                                    '( ["’”」』]{0,2})([^ ])', r"\1\n\2", piece
                                )
                                ele2_id = ele2_ls.index(piece)
                                ele2_ls = (
                                    ele2_ls[:ele2_id]
                                    + [i for i in ele_ele3.split("\n") if i]
                                    + ele2_ls[ele2_id + 1 :]
                                )
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = (
                            ele1_ls[:ele_id]
                            + [i for i in ele2_ls if i]
                            + ele1_ls[ele_id + 1 :]
                        )

                ele_idx = ls.index(ele)
                ls = ls[:ele_idx] + [i for i in ele1_ls if i] + ls[ele_idx + 1 :]
        return ls


# diff --git a/pilot/source_embedding/csv_embedding.py b/pilot/source_embedding/csv_embedding.py new file mode 100644 index 000000000..0e69574b4 --- /dev/null +++ b/pilot/source_embedding/csv_embedding.py @@ -0,0 +1,36 @@
from typing import Dict, List, Optional

from langchain.document_loaders import CSVLoader
from langchain.schema import Document

from pilot.source_embedding import SourceEmbedding, register


class CSVEmbedding(SourceEmbedding):
    """csv embedding for read csv document."""

    def __init__(
        self,
        file_path,
        vector_store_config,
        embedding_args: Optional[Dict] = None,
    ):
        """Initialize with csv path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config
        self.embedding_args = embedding_args

    @register
    def read(self):
        """Load from csv path."""
        loader = CSVLoader(file_path=self.file_path)
        return loader.load()

    @register
    def data_process(self, documents: List[Document]):
        """Strip newlines from every document in place and return the list."""
        for d in documents:
            d.page_content = d.page_content.replace("\n", "")
        return documents


# diff --git a/pilot/source_embedding/external/__init__.py b/pilot/source_embedding/external/__init__.py new file mode 100644 index 000000000..e69de29bb
# diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py new file mode 100644 index 000000000..97b515897 --- /dev/null +++ b/pilot/source_embedding/knowledge_embedding.py @@ -0,0 +1,89 @@
from typing import Optional

from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings

from pilot.configs.config import Config
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.ppt_embedding import PPTEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding
from pilot.source_embedding.word_embedding import WordEmbedding
from pilot.vector_store.connector import VectorStoreConnector

CFG = Config()

# File extension -> (embedding class, extra constructor keyword arguments).
KnowledgeEmbeddingType = {
    ".txt": (MarkdownEmbedding, {}),
    ".md": (MarkdownEmbedding, {}),
    ".pdf": (PDFEmbedding, {}),
    ".doc": (WordEmbedding, {}),
    ".docx": (WordEmbedding, {}),
    ".csv": (CSVEmbedding, {}),
    ".ppt": (PPTEmbedding, {}),
    ".pptx": (PPTEmbedding, {}),
}


class KnowledgeEmbedding:
    """Pick the embedding pipeline matching a file or URL and run it."""

    def __init__(
        self,
        model_name,
        vector_store_config,
        file_type: Optional[str] = "default",
        file_path: Optional[str] = None,
    ):
        """Initialize with Loader url, model_name, vector_store_config"""
        self.file_path = file_path
        self.model_name = model_name
        self.vector_store_config = vector_store_config
        self.file_type = file_type
        self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
        self.vector_store_config["embeddings"] = self.embeddings
        # Created on first use so read()/knowledge_embedding_batch() also
        # work without a prior knowledge_embedding() call.
        self.knowledge_embedding_client = None

    def _get_client(self):
        """Return the source-specific client, creating it if necessary."""
        if self.knowledge_embedding_client is None:
            self.knowledge_embedding_client = self.init_knowledge_embedding()
        return self.knowledge_embedding_client

    def knowledge_embedding(self):
        """Build a fresh client and run its full embedding pipeline."""
        self.knowledge_embedding_client = self.init_knowledge_embedding()
        self.knowledge_embedding_client.source_embedding()

    def knowledge_embedding_batch(self, docs):
        """Index already prepared ``docs`` into the vector store."""
        # docs = self.knowledge_embedding_client.read_batch()
        self._get_client().index_to_store(docs)

    def read(self):
        """Read and pre-process the source without indexing it."""
        return self._get_client().read_batch()

    def init_knowledge_embedding(self):
        """Create the embedding object for ``file_type`` / ``file_path``.

        Raises:
            ValueError: if no file path was given for a non-URL source, or
                the file extension is not in ``KnowledgeEmbeddingType``.
        """
        if self.file_type == "url":
            embedding = URLEmbedding(
                file_path=self.file_path,
                vector_store_config=self.vector_store_config,
            )
            return embedding
        if not self.file_path:
            raise ValueError("file_path is required for non-url knowledge sources")
        # Lower-cased so that e.g. "REPORT.PDF" is recognised as well.
        extension = "." + self.file_path.rsplit(".", 1)[-1].lower()
        if extension in KnowledgeEmbeddingType:
            knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
            embedding = knowledge_class(
                self.file_path,
                vector_store_config=self.vector_store_config,
                **knowledge_args,
            )
            return embedding
        raise ValueError(f"Unsupported knowledge file type '{extension}'")

    def similar_search(self, text, topk):
        """Search the vector store, retrying with one hit if it holds too few."""
        vector_client = VectorStoreConnector(
            CFG.VECTOR_STORE_TYPE, self.vector_store_config
        )
        try:
            ans = vector_client.similar_search(text, topk)
        except NotEnoughElementsException:
            ans = vector_client.similar_search(text, 1)
        return ans

    def vector_exist(self):
        """Return whatever the store's ``vector_name_exists`` reports."""
        vector_client = VectorStoreConnector(
            CFG.VECTOR_STORE_TYPE, self.vector_store_config
        )
        return vector_client.vector_name_exists()


# diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py new file mode 100644 index 000000000..d8caee959 --- /dev/null +++ b/pilot/source_embedding/markdown_embedding.py @@ -0,0 +1,51 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from typing import List

import markdown
from bs4 import BeautifulSoup
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter

from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
# from  (import statement continues on the next source line)
from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

CFG = Config()


class MarkdownEmbedding(SourceEmbedding):
    """markdown embedding for read markdown document."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with markdown path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config
        # self.encoding = encoding

    @register
    def read(self):
        """Load from markdown path and split it with the spaCy splitter."""
        loader = EncodeTextLoader(self.file_path)
        textsplitter = SpacyTextSplitter(
            pipeline="zh_core_web_sm",
            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
            chunk_overlap=100,
        )
        return loader.load_and_split(textsplitter)

    @register
    def data_process(self, documents: List[Document]):
        """Render each chunk to HTML, keep only its text, and flatten newlines.

        Documents are modified in place and the same list is returned.
        """
        for d in documents:
            content = markdown.markdown(d.page_content)
            soup = BeautifulSoup(content, "html.parser")
            for tag in soup(["!doctype", "meta", "i.fa"]):
                tag.extract()
            d.page_content = soup.get_text().replace("\n", " ")
        return documents


# diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py new file mode 100644 index 000000000..dd8c39c03 --- /dev/null +++ b/pilot/source_embedding/pdf_embedding.py @@ -0,0 +1,44 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List

from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter

from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register

CFG = Config()


class PDFEmbedding(SourceEmbedding):
    """pdf embedding for read pdf document."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with pdf path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Load from pdf path and split it with the spaCy splitter."""
        loader = PyPDFLoader(self.file_path)
        # textsplitter = CHNDocumentSplitter(
        #     pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
        # )
        textsplitter = SpacyTextSplitter(
            pipeline="zh_core_web_sm",
            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
            chunk_overlap=100,
        )
        return loader.load_and_split(textsplitter)

    @register
    def data_process(self, documents: List[Document]):
        """Strip newlines from every document in place and return the list."""
        for d in documents:
            d.page_content = d.page_content.replace("\n", "")
        return documents


# diff --git a/pilot/source_embedding/pdf_loader.py b/pilot/source_embedding/pdf_loader.py new file mode 100644 index 000000000..bbeead0cd --- /dev/null +++ b/pilot/source_embedding/pdf_loader.py @@ -0,0 +1,55 @@
"""Loader that loads image files."""
import os
from typing import List

import fitz
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR


class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
    """Load a PDF by dumping its page text plus OCR text of embedded images."""

    def _get_elements(self) -> List:
        def pdf_ocr_txt(filepath, dir_path="tmp_files"):
            """Write text and image OCR output to a .txt file; return its path.

            The file is created in ``dir_path`` next to the PDF. A temporary
            ``.tmp.png`` in the same directory holds each image for OCR.
            """
            full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
            os.makedirs(full_dir_path, exist_ok=True)
            filename = os.path.split(filepath)[-1]
            ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
            doc = fitz.open(filepath)
            txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
            img_name = os.path.join(full_dir_path, ".tmp.png")
            with open(txt_file_path, "w", encoding="utf-8") as fout:
                for i in range(doc.page_count):
                    page = doc[i]
                    text = page.get_text("")
                    fout.write(text)
                    fout.write("\n")

                    img_list = page.get_images()
                    for img in img_list:
                        pix = fitz.Pixmap(doc, img[0])
                        # PNG output needs gray or RGB; convert other
                        # colorspaces (e.g. CMYK) before saving.
                        if pix.n - pix.alpha >= 4:
                            pix = fitz.Pixmap(fitz.csRGB, pix)
                        pix.save(img_name)

                        result = ocr.ocr(img_name)
                        # Skip empty OCR results instead of iterating None.
                        ocr_result = [
                            i[1][0] for line in (result or []) if line for i in line
                        ]
                        fout.write("\n".join(ocr_result))
            # The temp image only exists if some page contained an image;
            # removing it unconditionally raised FileNotFoundError otherwise.
            if os.path.exists(img_name):
                os.remove(img_name)
            return txt_file_path

        txt_file_path = pdf_ocr_txt(self.file_path)
        from unstructured.partition.text import partition_text

        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)


if __name__ == "__main__":
    filepath = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test_py.pdf"
    )
    loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
    docs = loader.load()
    for doc in docs:
        print(doc)

# diff --git a/pilot/source_embedding/ppt_embedding.py b/pilot/source_embedding/ppt_embedding.py new file mode 100644 index 000000000..583b29ed1 --- /dev/null +++ b/pilot/source_embedding/ppt_embedding.py @@ -0,0 +1,41 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List

from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter

from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register

CFG = Config()


class PPTEmbedding(SourceEmbedding):
    """ppt embedding for read ppt document."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with ppt path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Load from ppt path and split it with the spaCy splitter."""
        loader = UnstructuredPowerPointLoader(self.file_path)
        textsplitter = SpacyTextSplitter(
            pipeline="zh_core_web_sm",
            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
            chunk_overlap=200,
        )
        return loader.load_and_split(textsplitter)

    @register
    def data_process(self, documents: List[Document]):
        """Strip newlines from every document in place and return the list."""
        for d in documents:
            d.page_content = d.page_content.replace("\n", "")
        return documents


# diff --git a/pilot/source_embedding/search_milvus.py  (diff header continues on the next source line)
b/pilot/source_embedding/search_milvus.py new file mode 100644 index 000000000..aa02c1f61 --- /dev/null +++ b/pilot/source_embedding/search_milvus.py @@ -0,0 +1,61 @@ +# from langchain.embeddings import HuggingFaceEmbeddings +# from langchain.vectorstores import Milvus +# from pymilvus import Collection,utility +# from pymilvus import connections, DataType, FieldSchema, CollectionSchema +# +# # milvus = connections.connect( +# # alias="default", +# # host='localhost', +# # port="19530" +# # ) +# # collection = Collection("book") +# +# +# # Get an existing collection. +# # collection.load() +# # +# # search_params = {"metric_type": "L2", "params": {}, "offset": 5} +# # +# # results = collection.search( +# # data=[[0.1, 0.2]], +# # anns_field="book_intro", +# # param=search_params, +# # limit=10, +# # expr=None, +# # output_fields=['book_id'], +# # consistency_level="Strong" +# # ) +# # +# # # get the IDs of all returned hits +# # results[0].ids +# # +# # # get the distances to the query vector from all returned hits +# # results[0].distances +# # +# # # get the value of an output field specified in the search request. +# # # vector fields are not supported yet. 
# # hit = results[0][0]
# # hit.entity.get('title')
#
# # milvus = connections.connect(
# # alias="default",
# # host='localhost',
# # port="19530"
# # )
# from pilot.vector_store.milvus_store import MilvusStore
#
# data = ["aaa", "bbb"]
# model_name = "xx/all-MiniLM-L6-v2"
# embeddings = HuggingFaceEmbeddings(model_name=model_name)
#
# # text_embeddings = Text2Vectors()
# mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"})
#
# mivuls.insert(["textc","tezt2"])
# print("success")
# ct
# # mivuls.from_texts(texts=data, embedding=embeddings)
# # docs,
# # embedding=embeddings,
# # connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"}
# # )

# diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py new file mode 100644 index 000000000..3d881fcdf --- /dev/null +++ b/pilot/source_embedding/source_embedding.py @@ -0,0 +1,95 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from chromadb.errors import NotEnoughElementsException
from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector

# Names of the pipeline steps announced through ``register``.
registered_methods = []
CFG = Config()


def register(method):
    """Record ``method``'s name as a pipeline step and return it unchanged.

    Each name is stored once. Every subclass decorates its own ``read`` and
    ``data_process``, which used to append the same names again each time.
    """
    name = method.__name__
    if name not in registered_methods:
        registered_methods.append(name)
    return method


class SourceEmbedding(ABC):
    """base class for read data source embedding pipeline.
    include data read, data process, data split, data to vector, data index vector store
    Implementations should implement the method
    """

    def __init__(
        self,
        file_path,
        vector_store_config,
        embedding_args: Optional[Dict] = None,
    ):
        """Initialize with Loader url, model_name, vector_store_config"""
        self.file_path = file_path
        self.vector_store_config = vector_store_config
        self.embedding_args = embedding_args
        self.embeddings = vector_store_config["embeddings"]
        self.vector_client = VectorStoreConnector(
            CFG.VECTOR_STORE_TYPE, vector_store_config
        )

    @abstractmethod
    @register
    def read(self) -> List[ABC]:
        """read datasource into document objects."""

    @register
    def data_process(self, text):
        """pre process data.

        The default passes the documents through unchanged. Returning
        ``None`` here would make the pipeline index ``None`` for any
        subclass that does not override this step.
        """
        return text

    @register
    def text_split(self, text):
        """text split chunk"""
        pass

    @register
    def text_to_vector(self, docs):
        """transform vector"""
        pass

    @register
    def index_to_store(self, docs):
        """index to vector store"""
        self.vector_client.load_document(docs)

    @register
    def similar_search(self, doc, topk):
        """vector store similarity_search

        Falls back to a single hit when the store holds fewer than ``topk``
        elements.
        """
        try:
            ans = self.vector_client.similar_search(doc, topk)
        except NotEnoughElementsException:
            ans = self.vector_client.similar_search(doc, 1)
        return ans

    def vector_name_exist(self):
        """Return whatever the store's ``vector_name_exists`` reports."""
        return self.vector_client.vector_name_exists()

    def source_embedding(self):
        """Run read -> process -> split -> vectorise -> index in order.

        The results of ``text_split`` and ``text_to_vector`` are not used;
        the processed documents are what gets indexed.
        """
        if "read" in registered_methods:
            text = self.read()
        if "data_process" in registered_methods:
            text = self.data_process(text)
        if "text_split" in registered_methods:
            self.text_split(text)
        if "text_to_vector" in registered_methods:
            self.text_to_vector(text)
        if "index_to_store" in registered_methods:
            self.index_to_store(text)

    def read_batch(self):
        """Run read -> process -> split and return the processed documents."""
        if "read" in registered_methods:
            text = self.read()
        if "data_process" in registered_methods:
            text = self.data_process(text)
        if "text_split" in registered_methods:
            self.text_split(text)
        return text


# diff --git  (diff header continues on the next source line)
# a/pilot/source_embedding/string_embedding.py b/pilot/source_embedding/string_embedding.py new file mode 100644 index 000000000..a1d18ee82 --- /dev/null +++ b/pilot/source_embedding/string_embedding.py @@ -0,0 +1,29 @@
from typing import List

from langchain.schema import Document

from pilot import SourceEmbedding, register


class StringEmbedding(SourceEmbedding):
    """Embedding pipeline whose source is an in-memory string."""

    def __init__(self, file_path, vector_store_config):
        """Store the text (passed as ``file_path``) and the store config."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Wrap the stored string in a single ``Document``."""
        return [
            Document(page_content=self.file_path, metadata={"source": "db_summary"})
        ]

    @register
    def data_process(self, documents: List[Document]):
        """Remove newlines from each document in place; return the list."""
        for document in documents:
            document.page_content = document.page_content.replace("\n", "")
        return documents


# diff --git a/pilot/source_embedding/url_embedding.py b/pilot/source_embedding/url_embedding.py new file mode 100644 index 000000000..a315e6e45 --- /dev/null +++ b/pilot/source_embedding/url_embedding.py @@ -0,0 +1,49 @@
from typing import List

from bs4 import BeautifulSoup
from langchain.document_loaders import WebBaseLoader
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter

from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

CFG = Config()


class URLEmbedding(SourceEmbedding):
    """Embedding pipeline whose source is a web page."""

    def __init__(self, file_path, vector_store_config):
        """Store the URL (passed as ``file_path``) and the store config."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Fetch the page and split it.

        A character splitter is used when ``CFG.LANGUAGE`` is ``"en"``,
        otherwise the Chinese sentence splitter.
        """
        loader = WebBaseLoader(web_path=self.file_path)
        if CFG.LANGUAGE != "en":
            splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000)
        else:
            splitter = CharacterTextSplitter(
                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
                chunk_overlap=20,
                length_function=len,
            )
        return loader.load_and_split(splitter)

    @register
    def data_process(self, documents: List[Document]):
        """Drop newlines and markup from each document in place; return the list."""
        for document in documents:
            flattened = document.page_content.replace("\n", "")
            soup = BeautifulSoup(flattened, "html.parser")
            for tag in soup(["!doctype", "meta"]):
                tag.extract()
            document.page_content = soup.get_text()
        return documents


# diff --git a/pilot/source_embedding/word_embedding.py b/pilot/source_embedding/word_embedding.py new file mode 100644 index 000000000..1f30f241c --- /dev/null +++ b/pilot/source_embedding/word_embedding.py @@ -0,0 +1,39 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List

from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader
from langchain.schema import Document

from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

CFG = Config()


class WordEmbedding(SourceEmbedding):
    """Embedding pipeline whose source is a Word document."""

    def __init__(self, file_path, vector_store_config):
        """Store the document path and the store config."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Load the Word file and split it with the Chinese sentence splitter."""
        splitter = CHNDocumentSplitter(
            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
        )
        loader = UnstructuredWordDocumentLoader(self.file_path)
        return loader.load_and_split(splitter)

    @register
    def data_process(self, documents: List[Document]):
        """Remove newlines from each document in place; return the list."""
        for document in documents:
            document.page_content = document.page_content.replace("\n", "")
        return documents