mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-07 03:14:42 +00:00
WEB API independent
This commit is contained in:
parent 04b4f6adf9
commit c57c2e60e8
12 pilot/connections/rdbms/py_study/test_cls_1.py Normal file
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase


class Test1(TestBase):

    def write(self):
        self.test_values.append("x")
        self.test_values.append("y")
        self.test_values.append("g")
15 pilot/connections/rdbms/py_study/test_cls_2.py Normal file
@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union


class Test2(TestBase):
    test_2_values: List = []

    def write(self):
        self.test_values.append(1)
        self.test_values.append(2)
        self.test_values.append(3)
        self.test_2_values.append("x")
        self.test_2_values.append("y")
        self.test_2_values.append("z")
12 pilot/connections/rdbms/py_study/test_cls_base.py Normal file
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union


class TestBase(BaseModel, ABC):
    test_values: List = []

    def test(self):
        print(self.__class__.__name__ + ":")
        print(self.test_values)
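These three py_study files read like a scratch experiment on pydantic field defaults. A short hedged sketch of what running them would show (my reading, not part of the commit): pydantic v1 copies a mutable default per instance, so the lists are not shared between objects.

    t1 = Test1()
    t1.write()
    t1.test()      # Test1: ['x', 'y', 'g']

    t2 = Test1()
    t2.test()      # Test1: [] -- a fresh copy, not t1's list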
38 pilot/prompts/example_base.py Normal file
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

from pilot.common.schema import ExampleType


class ExampleSelector(BaseModel, ABC):
    examples: List[List]
    use_example: bool = False
    type: str = ExampleType.ONE_SHOT.value

    # Renamed from ``examples`` -- a method with the same name would clash
    # with the pydantic field declared above.
    def get_examples(self, count: int = 2):
        if ExampleType.ONE_SHOT.value == self.type:
            return self.__one_shot_context()
        else:
            return self.__few_shot_context(count)

    def __few_shot_context(self, count: int = 2) -> List[List]:
        """
        Use 2 or more examples, 2 by default.
        Returns: example text
        """
        if self.use_example:
            return self.examples[:count]
        return None

    def __one_shot_context(self) -> List:
        """
        Use one example.
        Returns: example text
        """
        if self.use_example:
            return self.examples[:1]
        return None
9 pilot/scene/chat_execution/example.py Normal file
@@ -0,0 +1,9 @@
from pilot.prompts.example_base import ExampleSelector

# Two examples are defined by default
EXAMPLES = [
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
]

example = ExampleSelector(examples=EXAMPLES, use_example=True)
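A quick hedged sketch of how this selector behaves (using the ``get_examples`` method renamed above; ``FEW_SHOT`` is an assumed ExampleType member, not shown in this commit):

    example.get_examples()                        # ONE_SHOT by default -> EXAMPLES[:1]
    example.type = ExampleType.FEW_SHOT.value     # assumed enum member
    example.get_examples(count=2)                 # -> EXAMPLES[:2]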
1 pilot/scene/chat_execution/prompt_v2.py Normal file
@@ -0,0 +1 @@
146 pilot/server/api_v1/api_v1.py Normal file
@@ -0,0 +1,146 @@
import uuid
from typing import List

from fastapi import APIRouter, Request, Body, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError

from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.base_message import BaseMessage

router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")


async def validation_exception_handler(request: Request, exc: RequestValidationError):
    message = ""
    for error in exc.errors():
        # ``loc`` entries can be ints (e.g. list indices), so stringify them.
        message += ".".join(str(loc) for loc in error.get("loc")) + ":" + error.get("msg") + ";"
    # Exception handlers must return a Response, not a bare pydantic model.
    return JSONResponse(content=jsonable_encoder(Result.faild(message)))


@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]])
async def dialogue_list(user_id: str):
    # TODO: replace the placeholder data with a real lookup for ``user_id``.
    conversations = [
        ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
        ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
    ]
    return Result.succ(conversations)


@router.post('/v1/chat/dialogue/new', response_model=Result[str])
async def dialogue_new(user_id: str):
    unique_id = uuid.uuid1()
    return Result.succ(str(unique_id))


@router.post('/v1/chat/dialogue/delete')
async def dialogue_delete(con_uid: str, user_id: str):
    # TODO
    return Result.succ(None)


@router.post('/v1/chat/completions', response_model=Result[List[MessageVo]])
async def chat_completions(dialogue: ConversationVo = Body()):
    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")

    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        # Return a normal failure payload; raising StopAsyncIteration here
        # would surface as a server error instead of an API response.
        return Result.faild("Unsupported Chat Mode: " + dialogue.chat_mode + "!")

    chat_param = {
        "chat_session_id": dialogue.conv_uid,
        "user_input": dialogue.user_input,
    }

    # dict.update() takes a mapping, not a key/value pair.
    if ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
        chat_param.update({"db_name": dialogue.select_param})
    elif ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
        chat_param.update({"db_name": dialogue.select_param})
    elif ChatScene.ChatExecution.value == dialogue.chat_mode:
        chat_param.update({"plugin_selector": dialogue.select_param})
    elif ChatScene.ChatNewKnowledge.value == dialogue.chat_mode:
        chat_param.update({"knowledge_name": dialogue.select_param})
    elif ChatScene.ChatUrlKnowledge.value == dialogue.chat_mode:
        chat_param.update({"url": dialogue.select_param})

    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
    if not chat.prompt_template.stream_out:
        return non_stream_response(chat)
    else:
        return stream_response(chat)


def stream_generator(chat):
    model_response = chat.stream_call()
    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
            chat.current_message.add_ai_message(msg)
            messageVos = [message2Vo(element) for element in chat.current_message.messages]
            # StreamingResponse expects str/bytes chunks, so serialize.
            yield Result.succ(messageVos).json()


def stream_response(chat):
    logger.info("stream out start!")
    return StreamingResponse(stream_generator(chat), media_type="application/json")


def message2Vo(message: BaseMessage) -> MessageVo:
    # The original assigned ``message.content`` to ``role`` and never
    # returned the vo; build it with keyword arguments instead.
    return MessageVo(
        role=message.type,
        context=message.content,
        time_stamp=message.additional_kwargs.get("time_stamp", 0),
    )


def non_stream_response(chat):
    logger.info("not stream out, wait model response!")
    chat.nostream_call()
    messageVos = [message2Vo(element) for element in chat.current_message.messages]
    return Result.succ(messageVos)


@router.get('/v1/db/types', response_model=Result[List[str]])
async def db_types():
    return Result.succ(["mysql", "duckdb"])


@router.get('/v1/db/list', response_model=Result[List[str]])
async def db_list():
    db = CFG.local_db
    dbs = db.get_database_list()
    return Result.succ(dbs)


@router.get('/v1/knowledge/list')
async def knowledge_list():
    return ["test1", "test2"]


@router.post('/v1/knowledge/add')
async def knowledge_add():
    return ["test1", "test2"]


@router.post('/v1/knowledge/delete')
async def knowledge_delete():
    return ["test1", "test2"]


@router.get('/v1/knowledge/types')
async def knowledge_types():
    return ["test1", "test2"]


@router.get('/v1/knowledge/detail')
async def knowledge_detail():
    return ["test1", "test2"]
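The validation handler above is defined but nothing in this file attaches it to an app. A minimal wiring sketch, assuming a plain FastAPI app (the mounting code is not part of this commit):

    from fastapi import FastAPI
    from fastapi.exceptions import RequestValidationError

    from pilot.server.api_v1.api_v1 import router, validation_exception_handler

    app = FastAPI()
    app.include_router(router)
    app.add_exception_handler(RequestValidationError, validation_exception_handler)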
57 pilot/server/api_v1/api_view_model.py Normal file
@@ -0,0 +1,57 @@
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from typing import TypeVar, Union, List, Generic, Optional

T = TypeVar('T')


class Result(GenericModel, Generic[T]):
    # Generic pydantic (v1) models must inherit GenericModel for
    # ``Result[...]`` parametrization to validate correctly.
    success: bool
    err_code: Optional[str] = None
    err_msg: Optional[str] = None
    data: Optional[T] = None

    @classmethod
    def succ(cls, data: T):
        # pydantic models take keyword arguments, not positional ones.
        return Result(success=True, err_code=None, err_msg=None, data=data)

    @classmethod
    def faild(cls, msg, code: str = "E000X"):
        # The two ``faild`` overloads collapsed into one (Python keeps only
        # the last def); a default error code covers both call shapes, and
        # a failure result should report success=False.
        return Result(success=False, err_code=code, err_msg=msg, data=None)


class ConversationVo(BaseModel):
    """
    dialogue uid
    """
    conv_uid: str = Field(..., description="dialogue uid")
    """
    user input
    """
    user_input: str
    """
    the scene of chat
    """
    chat_mode: str = Field(..., description="the scene of chat")
    """
    chat scene select param
    """
    select_param: str


class MessageVo(BaseModel):
    """
    role that sends out the current message
    """
    role: str
    """
    current message
    """
    context: str
    """
    time the current message was sent
    """
    time_stamp: float
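A quick hedged sketch of the payloads this model produces (illustration only, not in the commit):

    ok = Result.succ(["mysql", "duckdb"])
    err = Result.faild("db not found", code="E0001")
    print(ok.json())   # {"success": true, "err_code": null, "err_msg": null, "data": ["mysql", "duckdb"]}
    print(err.json())  # {"success": false, "err_code": "E0001", "err_msg": "db not found", "data": null}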
60 pilot/server/webserver_base.py Normal file
@@ -0,0 +1,60 @@
import signal
import os
import threading
import traceback
import sys

from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
from pilot.configs.model_config import (
    DATASETS_DIR,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LLM_MODEL_CONFIG,
    LOGDIR,
)
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

logger = build_logger("webserver", LOGDIR + "webserver.log")


def signal_handler(sig, frame):
    print("in order to avoid chroma db atexit problem")
    os._exit(0)


def async_db_summery():
    # Build the DB summary in a background thread so startup is not blocked.
    client = DBSummaryClient()
    thread = threading.Thread(target=client.init_db_summary)
    thread.start()


def server_init(args):
    logger.info(f"args: {args}")

    # init config
    cfg = Config()

    load_native_plugins(cfg)
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()
    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

    # Load plugins and commands
    command_categories = [
        "pilot.commands.built_in.audio_text",
        "pilot.commands.built_in.image_gen",
    ]
    # exclude commands
    command_categories = [
        x for x in command_categories if x not in cfg.disabled_command_categories
    ]
    command_registry = CommandRegistry()
    for command_category in command_categories:
        command_registry.import_commands(command_category)

    cfg.command_registry = command_registry
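A hedged bootstrap sketch; the commit does not show the caller, and the argument flags are assumptions:

    import argparse
    from pilot.server.webserver_base import server_init

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=5000)   # assumed flag
    server_init(parser.parse_args())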
26 pilot/source_embedding/EncodeTextLoader.py Normal file
@@ -0,0 +1,26 @@
from typing import List, Optional

import chardet
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader


class EncodeTextLoader(BaseLoader):
    """Load text files, detecting the encoding when none is given."""

    def __init__(self, file_path: str, encoding: Optional[str] = None):
        """Initialize with file path."""
        self.file_path = file_path
        self.encoding = encoding

    def load(self) -> List[Document]:
        """Load from file path."""
        with open(self.file_path, "rb") as f:
            raw_text = f.read()
            if self.encoding is not None:
                # Honor an explicitly requested encoding; the original
                # stored it but never used it.
                text = raw_text.decode(self.encoding)
            else:
                result = chardet.detect(raw_text)
                text = raw_text.decode(result["encoding"] or "utf-8")
            metadata = {"source": self.file_path}
            return [Document(page_content=text, metadata=metadata)]
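A quick usage sketch (the path is assumed; not part of the commit):

    docs = EncodeTextLoader("./docs/demo.md").load()
    print(docs[0].metadata)            # {'source': './docs/demo.md'}
    print(len(docs[0].page_content))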
3 pilot/source_embedding/__init__.py Normal file
@@ -0,0 +1,3 @@
from pilot.source_embedding.source_embedding import SourceEmbedding, register

__all__ = ["SourceEmbedding", "register"]
55 pilot/source_embedding/chn_document_splitter.py Normal file
@@ -0,0 +1,55 @@
import re
from typing import List

from langchain.text_splitter import CharacterTextSplitter


class CHNDocumentSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, sentence_size: int = None, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text(self, text: str) -> List[str]:
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)

        # Insert a newline after sentence-ending punctuation, then split.
        text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text)
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text)
        text = text.rstrip()
        ls = [i for i in text.split("\n") if i]
        # Chunks still longer than sentence_size are split again on commas,
        # then on whitespace runs, as a last resort.
        for ele in ls:
            if len(ele) > self.sentence_size:
                ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        ele_ele2 = re.sub(
                            r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r"\1\n\2", ele_ele1
                        )
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                ele_ele3 = re.sub(
                                    '( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2
                                )
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = (
                                    ele2_ls[:ele2_id]
                                    + [i for i in ele_ele3.split("\n") if i]
                                    + ele2_ls[ele2_id + 1 :]
                                )
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = (
                            ele1_ls[:ele_id]
                            + [i for i in ele2_ls if i]
                            + ele1_ls[ele_id + 1 :]
                        )

                id = ls.index(ele)
                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1 :]
        return ls
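A short behavior sketch (assumed inputs; not from the commit):

    splitter = CHNDocumentSplitter(pdf=False, sentence_size=100)
    chunks = splitter.split_text("今天天气很好。我们去公园散步!然后回家吃饭。")
    # -> ['今天天气很好。', '我们去公园散步!', '然后回家吃饭。']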
36 pilot/source_embedding/csv_embedding.py Normal file
@@ -0,0 +1,36 @@
from typing import Dict, List, Optional

from langchain.document_loaders import CSVLoader
from langchain.schema import Document

from pilot.source_embedding import SourceEmbedding, register


class CSVEmbedding(SourceEmbedding):
    """csv embedding for read csv document."""

    def __init__(
        self,
        file_path,
        vector_store_config,
        embedding_args: Optional[Dict] = None,
    ):
        """Initialize with csv path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config
        self.embedding_args = embedding_args

    @register
    def read(self):
        """Load from csv path."""
        loader = CSVLoader(file_path=self.file_path)
        return loader.load()

    @register
    def data_process(self, documents: List[Document]):
        i = 0
        for d in documents:
            documents[i].page_content = d.page_content.replace("\n", "")
            i += 1
        return documents
0 pilot/source_embedding/external/__init__.py vendored Normal file
89 pilot/source_embedding/knowledge_embedding.py Normal file
@@ -0,0 +1,89 @@
from typing import Optional

from chromadb.errors import NotEnoughElementsException
from langchain.embeddings import HuggingFaceEmbeddings

from pilot.configs.config import Config
from pilot.source_embedding.csv_embedding import CSVEmbedding
from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
from pilot.source_embedding.pdf_embedding import PDFEmbedding
from pilot.source_embedding.ppt_embedding import PPTEmbedding
from pilot.source_embedding.url_embedding import URLEmbedding
from pilot.source_embedding.word_embedding import WordEmbedding
from pilot.vector_store.connector import VectorStoreConnector

CFG = Config()

KnowledgeEmbeddingType = {
    ".txt": (MarkdownEmbedding, {}),
    ".md": (MarkdownEmbedding, {}),
    ".pdf": (PDFEmbedding, {}),
    ".doc": (WordEmbedding, {}),
    ".docx": (WordEmbedding, {}),
    ".csv": (CSVEmbedding, {}),
    ".ppt": (PPTEmbedding, {}),
    ".pptx": (PPTEmbedding, {}),
}


class KnowledgeEmbedding:
    def __init__(
        self,
        model_name,
        vector_store_config,
        file_type: Optional[str] = "default",
        file_path: Optional[str] = None,
    ):
        """Initialize with model_name, vector_store_config and an optional file path."""
        self.file_path = file_path
        self.model_name = model_name
        self.vector_store_config = vector_store_config
        self.file_type = file_type
        self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
        self.vector_store_config["embeddings"] = self.embeddings
        self.knowledge_embedding_client = None

    def knowledge_embedding(self):
        self.knowledge_embedding_client = self.init_knowledge_embedding()
        self.knowledge_embedding_client.source_embedding()

    def knowledge_embedding_batch(self, docs):
        # Guard against calls made before read()/knowledge_embedding()
        # created the client.
        if self.knowledge_embedding_client is None:
            self.knowledge_embedding_client = self.init_knowledge_embedding()
        self.knowledge_embedding_client.index_to_store(docs)

    def read(self):
        self.knowledge_embedding_client = self.init_knowledge_embedding()
        return self.knowledge_embedding_client.read_batch()

    def init_knowledge_embedding(self):
        if self.file_type == "url":
            return URLEmbedding(
                file_path=self.file_path,
                vector_store_config=self.vector_store_config,
            )
        extension = "." + self.file_path.rsplit(".", 1)[-1]
        if extension in KnowledgeEmbeddingType:
            knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
            return knowledge_class(
                self.file_path,
                vector_store_config=self.vector_store_config,
                **knowledge_args,
            )
        # The original had an unreachable ``return embedding`` after this raise.
        raise ValueError(f"Unsupported knowledge file type '{extension}'")

    def similar_search(self, text, topk):
        vector_client = VectorStoreConnector(
            CFG.VECTOR_STORE_TYPE, self.vector_store_config
        )
        try:
            ans = vector_client.similar_search(text, topk)
        except NotEnoughElementsException:
            # Fall back to a single result when the store holds fewer
            # elements than topk.
            ans = vector_client.similar_search(text, 1)
        return ans

    def vector_exist(self):
        vector_client = VectorStoreConnector(
            CFG.VECTOR_STORE_TYPE, self.vector_store_config
        )
        return vector_client.vector_name_exists()
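A hedged end-to-end sketch (the model name, config keys, and paths are assumptions, not from the commit):

    client = KnowledgeEmbedding(
        model_name="sentence-transformers/all-MiniLM-L6-v2",    # assumed model
        vector_store_config={"vector_store_name": "doc_test"},  # assumed keys
        file_path="./docs/demo.md",
    )
    client.knowledge_embedding()                   # read -> process -> index
    print(client.similar_search("what is DB-GPT", topk=3))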
51 pilot/source_embedding/markdown_embedding.py Normal file
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from typing import List

import markdown
from bs4 import BeautifulSoup
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter

from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

CFG = Config()


class MarkdownEmbedding(SourceEmbedding):
    """markdown embedding for read markdown document."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with markdown path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config
        # self.encoding = encoding

    @register
    def read(self):
        """Load from markdown path."""
        loader = EncodeTextLoader(self.file_path)
        textsplitter = SpacyTextSplitter(
            pipeline="zh_core_web_sm",
            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
            chunk_overlap=100,
        )
        return loader.load_and_split(textsplitter)

    @register
    def data_process(self, documents: List[Document]):
        i = 0
        for d in documents:
            # Render markdown to HTML, strip the markup, and flatten newlines.
            content = markdown.markdown(d.page_content)
            soup = BeautifulSoup(content, "html.parser")
            for tag in soup(["!doctype", "meta", "i.fa"]):
                tag.extract()
            documents[i].page_content = soup.get_text()
            documents[i].page_content = documents[i].page_content.replace("\n", " ")
            i += 1
        return documents
44 pilot/source_embedding/pdf_embedding.py Normal file
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List

from langchain.document_loaders import PyPDFLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter

from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register

CFG = Config()


class PDFEmbedding(SourceEmbedding):
    """pdf embedding for read pdf document."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with pdf path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Load from pdf path."""
        loader = PyPDFLoader(self.file_path)
        # textsplitter = CHNDocumentSplitter(
        #     pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
        # )
        textsplitter = SpacyTextSplitter(
            pipeline="zh_core_web_sm",
            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
            chunk_overlap=100,
        )
        return loader.load_and_split(textsplitter)

    @register
    def data_process(self, documents: List[Document]):
        i = 0
        for d in documents:
            documents[i].page_content = d.page_content.replace("\n", "")
            i += 1
        return documents
55 pilot/source_embedding/pdf_loader.py Normal file
@@ -0,0 +1,55 @@
"""Loader that loads image files."""
import os
from typing import List

import fitz
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR


class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
    """Loader that uses unstructured to load image files, such as PNGs and JPGs."""

    def _get_elements(self) -> List:
        def pdf_ocr_txt(filepath, dir_path="tmp_files"):
            full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
            if not os.path.exists(full_dir_path):
                os.makedirs(full_dir_path)
            filename = os.path.split(filepath)[-1]
            ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False)
            doc = fitz.open(filepath)
            txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
            img_name = os.path.join(full_dir_path, ".tmp.png")
            with open(txt_file_path, "w", encoding="utf-8") as fout:
                for i in range(doc.page_count):
                    page = doc[i]
                    text = page.get_text("")
                    fout.write(text)
                    fout.write("\n")

                    # OCR each embedded image and append the recognized text.
                    img_list = page.get_images()
                    for img in img_list:
                        pix = fitz.Pixmap(doc, img[0])
                        pix.save(img_name)
                        result = ocr.ocr(img_name)
                        ocr_result = [i[1][0] for line in result for i in line]
                        fout.write("\n".join(ocr_result))
            # Only remove the temp image if at least one image was rendered.
            if os.path.exists(img_name):
                os.remove(img_name)
            return txt_file_path

        txt_file_path = pdf_ocr_txt(self.file_path)
        from unstructured.partition.text import partition_text

        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)


if __name__ == "__main__":
    filepath = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test_py.pdf"
    )
    loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
    docs = loader.load()
    for doc in docs:
        print(doc)
41 pilot/source_embedding/ppt_embedding.py Normal file
@@ -0,0 +1,41 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List

from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter

from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register

CFG = Config()


class PPTEmbedding(SourceEmbedding):
    """ppt embedding for read ppt document."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with ppt path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Load from ppt path."""
        loader = UnstructuredPowerPointLoader(self.file_path)
        textsplitter = SpacyTextSplitter(
            pipeline="zh_core_web_sm",
            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
            chunk_overlap=200,
        )
        return loader.load_and_split(textsplitter)

    @register
    def data_process(self, documents: List[Document]):
        i = 0
        for d in documents:
            documents[i].page_content = d.page_content.replace("\n", "")
            i += 1
        return documents
61 pilot/source_embedding/search_milvus.py Normal file
@@ -0,0 +1,61 @@
# from langchain.embeddings import HuggingFaceEmbeddings
# from langchain.vectorstores import Milvus
# from pymilvus import Collection,utility
# from pymilvus import connections, DataType, FieldSchema, CollectionSchema
#
# # milvus = connections.connect(
# #     alias="default",
# #     host='localhost',
# #     port="19530"
# # )
# # collection = Collection("book")
#
#
# # Get an existing collection.
# # collection.load()
# #
# # search_params = {"metric_type": "L2", "params": {}, "offset": 5}
# #
# # results = collection.search(
# #     data=[[0.1, 0.2]],
# #     anns_field="book_intro",
# #     param=search_params,
# #     limit=10,
# #     expr=None,
# #     output_fields=['book_id'],
# #     consistency_level="Strong"
# # )
# #
# # # get the IDs of all returned hits
# # results[0].ids
# #
# # # get the distances to the query vector from all returned hits
# # results[0].distances
# #
# # # get the value of an output field specified in the search request.
# # # vector fields are not supported yet.
# # hit = results[0][0]
# # hit.entity.get('title')
#
# # milvus = connections.connect(
# #     alias="default",
# #     host='localhost',
# #     port="19530"
# # )
# from pilot.vector_store.milvus_store import MilvusStore
#
# data = ["aaa", "bbb"]
# model_name = "xx/all-MiniLM-L6-v2"
# embeddings = HuggingFaceEmbeddings(model_name=model_name)
#
# # text_embeddings = Text2Vectors()
# mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"})
#
# mivuls.insert(["textc","tezt2"])
# print("success")
# # mivuls.from_texts(texts=data, embedding=embeddings)
# #     docs,
# #     embedding=embeddings,
# #     connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"}
# # )
95 pilot/source_embedding/source_embedding.py Normal file
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from chromadb.errors import NotEnoughElementsException
from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector

registered_methods = []
CFG = Config()


def register(method):
    # Record the method name so source_embedding() knows which pipeline
    # stages a subclass actually implements.
    registered_methods.append(method.__name__)
    return method


class SourceEmbedding(ABC):
    """Base class for the data source embedding pipeline:
    data read, data process, data split, data to vector, and indexing
    into the vector store. Implementations must provide ``read``.
    """

    def __init__(
        self,
        file_path,
        vector_store_config,
        embedding_args: Optional[Dict] = None,
    ):
        """Initialize with file path and vector_store_config."""
        self.file_path = file_path
        self.vector_store_config = vector_store_config
        self.embedding_args = embedding_args
        self.embeddings = vector_store_config["embeddings"]
        self.vector_client = VectorStoreConnector(
            CFG.VECTOR_STORE_TYPE, vector_store_config
        )

    @abstractmethod
    @register
    def read(self) -> List[ABC]:
        """read datasource into document objects."""

    @register
    def data_process(self, text):
        """pre process data."""

    @register
    def text_split(self, text):
        """text split chunk"""
        pass

    @register
    def text_to_vector(self, docs):
        """transform vector"""
        pass

    @register
    def index_to_store(self, docs):
        """index to vector store"""
        self.vector_client.load_document(docs)

    @register
    def similar_search(self, doc, topk):
        """vector store similarity_search"""
        try:
            ans = self.vector_client.similar_search(doc, topk)
        except NotEnoughElementsException:
            ans = self.vector_client.similar_search(doc, 1)
        return ans

    def vector_name_exist(self):
        return self.vector_client.vector_name_exists()

    def source_embedding(self):
        if "read" in registered_methods:
            text = self.read()
        if "data_process" in registered_methods:
            text = self.data_process(text)
        if "text_split" in registered_methods:
            self.text_split(text)
        if "text_to_vector" in registered_methods:
            self.text_to_vector(text)
        if "index_to_store" in registered_methods:
            self.index_to_store(text)

    def read_batch(self):
        if "read" in registered_methods:
            text = self.read()
        if "data_process" in registered_methods:
            text = self.data_process(text)
        if "text_split" in registered_methods:
            self.text_split(text)
        return text
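How the pieces fit, as a hedged sketch (config keys and paths are assumptions; the ``embeddings`` key must be present because __init__ reads it):

    from langchain.embeddings import HuggingFaceEmbeddings
    from pilot.source_embedding.csv_embedding import CSVEmbedding

    config = {
        "embeddings": HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
        "vector_store_name": "csv_test",   # assumed key
    }
    emb = CSVEmbedding("./data/demo.csv", vector_store_config=config)
    emb.source_embedding()   # read -> data_process -> ... -> index_to_store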
29 pilot/source_embedding/string_embedding.py Normal file
@@ -0,0 +1,29 @@
from typing import List

from langchain.schema import Document

# Import from pilot.source_embedding, matching the other embedding modules;
# the original imported these names from the top-level ``pilot`` package.
from pilot.source_embedding import SourceEmbedding, register


class StringEmbedding(SourceEmbedding):
    """string embedding for read string document."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with the raw string content."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Wrap the string into a single Document."""
        metadata = {"source": "db_summary"}
        return [Document(page_content=self.file_path, metadata=metadata)]

    @register
    def data_process(self, documents: List[Document]):
        i = 0
        for d in documents:
            documents[i].page_content = d.page_content.replace("\n", "")
            i += 1
        return documents
49 pilot/source_embedding/url_embedding.py Normal file
@@ -0,0 +1,49 @@
from typing import List

from bs4 import BeautifulSoup
from langchain.document_loaders import WebBaseLoader
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter

from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

CFG = Config()


class URLEmbedding(SourceEmbedding):
    """url embedding for read url document."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with url path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Load from url path."""
        loader = WebBaseLoader(web_path=self.file_path)
        if CFG.LANGUAGE == "en":
            text_splitter = CharacterTextSplitter(
                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
                chunk_overlap=20,
                length_function=len,
            )
        else:
            text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000)
        return loader.load_and_split(text_splitter)

    @register
    def data_process(self, documents: List[Document]):
        i = 0
        for d in documents:
            content = d.page_content.replace("\n", "")
            soup = BeautifulSoup(content, "html.parser")
            for tag in soup(["!doctype", "meta"]):
                tag.extract()
            documents[i].page_content = soup.get_text()
            i += 1
        return documents
39 pilot/source_embedding/word_embedding.py Normal file
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List

from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader
from langchain.schema import Document

from pilot.configs.config import Config
from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

CFG = Config()


class WordEmbedding(SourceEmbedding):
    """word embedding for read word document."""

    def __init__(self, file_path, vector_store_config):
        """Initialize with word path."""
        super().__init__(file_path, vector_store_config)
        self.file_path = file_path
        self.vector_store_config = vector_store_config

    @register
    def read(self):
        """Load from word path."""
        loader = UnstructuredWordDocumentLoader(self.file_path)
        textsplitter = CHNDocumentSplitter(
            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
        )
        return loader.load_and_split(textsplitter)

    @register
    def data_process(self, documents: List[Document]):
        i = 0
        for d in documents:
            documents[i].page_content = d.page_content.replace("\n", "")
            i += 1
        return documents