diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 12c7e33da..819152845 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -20,6 +20,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 LLM_MODEL_CONFIG = {
     "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
     "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
+    "text2vec": os.path.join(MODEL_PATH, "text2vec"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2")
 }
 
@@ -28,7 +29,7 @@ VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 4096
-VICUNA_MODEL_SERVER = "http://121.41.167.183:8000"
+VICUNA_MODEL_SERVER = "http://121.41.227.141:8000"
 
 # Load model config
 ISLOAD_8BIT = True
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 139caab4d..e139ff09b 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -242,10 +242,10 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     if mode == conversation_types["custome"] and not db_selector:
         persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
         print("向量数据库持久化地址: ", persist_dir)
-        knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={"vector_store_name": vector_store_name["vs_name"],
+        knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config={"vector_store_name": vector_store_name["vs_name"],
                                                                                                                                     "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
         query = state.messages[-2][1]
-        docs = knowledge_embedding_client.similar_search(query, 1)
+        docs = knowledge_embedding_client.similar_search(query, 10)
         context = [d.page_content for d in docs]
         prompt_template = PromptTemplate(
             template=conv_qa_prompt_template,
@@ -254,6 +254,18 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         result = prompt_template.format(context="\n".join(context), question=query)
         state.messages[-2][1] = result
         prompt = state.get_prompt()
+        if len(prompt) > 4000:
+            logger.info("prompt length greater than 4000, rebuild")
+            docs = knowledge_embedding_client.similar_search(query, 5)
+            context = [d.page_content for d in docs]
+            prompt_template = PromptTemplate(
+                template=conv_qa_prompt_template,
+                input_variables=["context", "question"]
+            )
+            result = prompt_template.format(context="\n".join(context), question=query)
+            state.messages[-2][1] = result
+            prompt = state.get_prompt()
+            print(len(prompt))
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
@@ -420,7 +432,7 @@ def build_single_model_ui():
                 max_output_tokens = gr.Slider(
                     minimum=0,
                     maximum=1024,
-                    value=1024,
+                    value=512,
                     step=64,
                     interactive=True,
                     label="最大输出Token数",
@@ -570,7 +582,7 @@ def knowledge_embedding_store(vs_id, files):
             shutil.move(file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename))
         knowledge_embedding_client = KnowledgeEmbedding(
             file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
-            model_name=LLM_MODEL_CONFIG["sentence-transforms"],
+            model_name=LLM_MODEL_CONFIG["text2vec"],
             vector_store_config={
                 "vector_store_name": vector_store_name["vs_name"],
                 "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py
index 617190fe5..e162aefd8 100644
--- a/pilot/source_embedding/pdf_embedding.py
+++ b/pilot/source_embedding/pdf_embedding.py
@@ -17,8 +17,6 @@ class PDFEmbedding(SourceEmbedding):
         self.file_path = file_path
         self.model_name = model_name
         self.vector_store_config = vector_store_config
-        # SourceEmbedding(file_path =file_path, );
-        SourceEmbedding(file_path, model_name, vector_store_config)
 
     @register
     def read(self):
@@ -30,7 +28,7 @@ class PDFEmbedding(SourceEmbedding):
     def data_process(self, documents: List[Document]):
         i = 0
         for d in documents:
-            documents[i].page_content = d.page_content.replace(" ", "").replace("\n", "")
+            documents[i].page_content = d.page_content.replace("\n", "")
             i += 1
         return documents
diff --git a/requirements.txt b/requirements.txt
index 5654dba6f..3bca421f6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -72,4 +72,5 @@ chromadb
 markdown2
 colorama
 playsound
-distro
\ No newline at end of file
+distro
+pypdf
\ No newline at end of file
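
Notes on the patch. The embedding model for knowledge upload and retrieval switches from the "sentence-transforms" MiniLM checkpoint to a "text2vec" checkpoint, which this change expects to live under MODEL_PATH/text2vec (downloading it there is not handled by the patch). A minimal query-side usage sketch mirroring the http_bot call site; import paths are assumed from the repo layout and the vector store name is hypothetical:

# Sketch only: import paths assumed; a text2vec checkpoint must already
# exist at MODEL_PATH/text2vec for LLM_MODEL_CONFIG["text2vec"] to resolve.
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

client = KnowledgeEmbedding(
    file_path="",  # empty file_path: query-only client, as in http_bot
    model_name=LLM_MODEL_CONFIG["text2vec"],
    vector_store_config={
        "vector_store_name": "my_docs",  # hypothetical store name
        "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
    },
)
docs = client.similar_search("What does chapter 3 cover?", 10)
context = [d.page_content for d in docs]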
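The retrieval branch now pulls the top 10 chunks instead of 1 and, when the assembled prompt exceeds 4000 characters, rebuilds it once with only 5 chunks so it stays within the model's 4096 context window (MAX_POSITION_EMBEDDINGS); character length is used as a rough proxy for token count. A hypothetical generalization that keeps halving top_k until the prompt fits (the patch itself retries exactly once, 10 -> 5):

# Illustrative only, not what the patch does: halve the number of
# retrieved chunks until the formatted prompt fits the character budget.
def fit_prompt(client, template, query, budget=4000, top_k=10):
    while top_k >= 1:
        docs = client.similar_search(query, top_k)
        context = "\n".join(d.page_content for d in docs)
        prompt = template.format(context=context, question=query)
        if len(prompt) <= budget:
            return prompt
        top_k //= 2
    # Even a single chunk is too large: fall back to the bare question.
    return template.format(context="", question=query)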
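In pdf_embedding.py, the stray re-instantiation of SourceEmbedding in __init__ is dropped (the subclass already stores its own fields), and data_process stops stripping every space from page text, which mangled English content, removing only newlines. pypdf joins requirements.txt presumably because PDFEmbedding reads PDFs through LangChain's PyPDFLoader, which imports pypdf at runtime; without it, uploading a PDF fails with an ImportError. A quick sanity check that the dependency works (example.pdf is a hypothetical local file):

# pypdf 3.x API; PdfReader is the class PyPDFLoader uses internally.
from pypdf import PdfReader

reader = PdfReader("example.pdf")  # hypothetical local file
print(len(reader.pages), reader.pages[0].extract_text()[:80])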