fix: URL embedding

This commit is contained in:
aries-ckt
2023-06-01 22:07:33 +08:00
parent 8dd25815e1
commit 1d432e4d29
3 changed files with 18 additions and 8 deletions

View File

@@ -53,9 +53,11 @@ class BaseOutputParser(ABC):
""" """
if data["error_code"] == 0: if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL: if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len + 11:].strip() # output = data["text"][skip_echo_len + 11:].strip()
output = data["text"][skip_echo_len:].strip()
elif "guanaco" in CFG.LLM_MODEL: elif "guanaco" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip() # output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
output = data["text"][skip_echo_len:].replace("<s>", "").strip()
else: else:
output = data["text"].strip() output = data["text"].strip()

View File

@@ -11,7 +11,7 @@ from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
CFG = Config() CFG = Config()
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. PROMPT_SCENE_DEFINE = """A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """ The assistant gives helpful, detailed, professional and polite answers to the user's questions. """

View File

@@ -5,9 +5,12 @@ from langchain.document_loaders import WebBaseLoader
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter from langchain.text_splitter import CharacterTextSplitter
from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.source_embedding import SourceEmbedding, register from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
CFG = Config()
class URLEmbedding(SourceEmbedding): class URLEmbedding(SourceEmbedding):
"""url embedding for read url document.""" """url embedding for read url document."""
@@ -22,10 +25,15 @@ class URLEmbedding(SourceEmbedding):
def read(self): def read(self):
"""Load from url path.""" """Load from url path."""
loader = WebBaseLoader(web_path=self.file_path) loader = WebBaseLoader(web_path=self.file_path)
text_splitor = CharacterTextSplitter( if CFG.LANGUAGE == "en":
chunk_size=100, chunk_overlap=20, length_function=len text_splitter = CharacterTextSplitter(
) chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE, chunk_overlap=20, length_function=len
return loader.load_and_split(text_splitor) )
else:
text_splitter = CHNDocumentSplitter(
pdf=True, sentence_size=1000
)
return loader.load_and_split(text_splitter)
@register @register
def data_process(self, documents: List[Document]): def data_process(self, documents: List[Document]):