From 312f24338234fdf078ce90b736c640152abc4ac4 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Mon, 10 Jul 2023 15:50:34 +0800 Subject: [PATCH] refactor:rename knowledge embedding api 1.replace knowledge_embedding of embedding_engine --- docs/use_cases/knownledge_based_qa.md | 8 ++++---- pilot/embedding_engine/__init__.py | 3 ++- .../{knowledge_embedding.py => embedding_engine.py} | 2 +- pilot/embedding_engine/ppt_embedding.py | 1 - pilot/embedding_engine/string_embedding.py | 2 +- pilot/embedding_engine/url_embedding.py | 2 -- pilot/scene/chat_knowledge/custom/chat.py | 4 ++-- pilot/scene/chat_knowledge/default/chat.py | 4 ++-- pilot/scene/chat_knowledge/url/chat.py | 4 ++-- pilot/scene/chat_knowledge/v1/chat.py | 4 ++-- pilot/server/knowledge/api.py | 4 ++-- pilot/server/knowledge/service.py | 4 ++-- pilot/server/webserver.py | 4 ++-- pilot/summary/db_summary_client.py | 8 ++++---- tools/knowledge_init.py | 4 ++-- 15 files changed, 28 insertions(+), 30 deletions(-) rename pilot/embedding_engine/{knowledge_embedding.py => embedding_engine.py} (98%) diff --git a/docs/use_cases/knownledge_based_qa.md b/docs/use_cases/knownledge_based_qa.md index 3a357aaad..0e2731ec0 100644 --- a/docs/use_cases/knownledge_based_qa.md +++ b/docs/use_cases/knownledge_based_qa.md @@ -11,9 +11,9 @@ vector_store_config = { file_path = "your file path" -knowledge_embedding_client = KnowledgeEmbedding(file_path=file_path, model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config) +embedding_engine = EmbeddingEngine(file_path=file_path, model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config) -knowledge_embedding_client.knowledge_embedding() +embedding_engine.knowledge_embedding() ``` @@ -37,7 +37,7 @@ vector_store_config = { query = "your query" -knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config) +embedding_engine = EmbeddingEngine(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config) -knowledge_embedding_client.similar_search(query, 10) +embedding_engine.similar_search(query, 10) ``` \ No newline at end of file diff --git a/pilot/embedding_engine/__init__.py b/pilot/embedding_engine/__init__.py index ac54efd20..dbece4638 100644 --- a/pilot/embedding_engine/__init__.py +++ b/pilot/embedding_engine/__init__.py @@ -1,3 +1,4 @@ from pilot.embedding_engine.source_embedding import SourceEmbedding, register +from pilot.embedding_engine.embedding_engine import EmbeddingEngine -__all__ = ["SourceEmbedding", "register"] +__all__ = ["SourceEmbedding", "register", "EmbeddingEngine"] diff --git a/pilot/embedding_engine/knowledge_embedding.py b/pilot/embedding_engine/embedding_engine.py similarity index 98% rename from pilot/embedding_engine/knowledge_embedding.py rename to pilot/embedding_engine/embedding_engine.py index fd28b938b..066611732 100644 --- a/pilot/embedding_engine/knowledge_embedding.py +++ b/pilot/embedding_engine/embedding_engine.py @@ -10,7 +10,7 @@ from pilot.vector_store.connector import VectorStoreConnector CFG = Config() -class KnowledgeEmbedding: +class EmbeddingEngine: def __init__( self, model_name, diff --git a/pilot/embedding_engine/ppt_embedding.py b/pilot/embedding_engine/ppt_embedding.py index 219f3d0d5..a181f8d37 100644 --- a/pilot/embedding_engine/ppt_embedding.py +++ b/pilot/embedding_engine/ppt_embedding.py @@ -8,7 +8,6 @@ from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSpl from pilot.configs.config import Config from pilot.embedding_engine import SourceEmbedding, register -from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() diff --git a/pilot/embedding_engine/string_embedding.py b/pilot/embedding_engine/string_embedding.py index a1d18ee82..5839290fe 100644 --- a/pilot/embedding_engine/string_embedding.py +++ b/pilot/embedding_engine/string_embedding.py @@ -2,7 +2,7 @@ from typing import List from langchain.schema import Document -from pilot import SourceEmbedding, register +from pilot.embedding_engine import SourceEmbedding, register class StringEmbedding(SourceEmbedding): diff --git a/pilot/embedding_engine/url_embedding.py b/pilot/embedding_engine/url_embedding.py index 8fa570fbc..8b8976d03 100644 --- a/pilot/embedding_engine/url_embedding.py +++ b/pilot/embedding_engine/url_embedding.py @@ -6,9 +6,7 @@ from langchain.schema import Document from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter from pilot.configs.config import Config -from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE from pilot.embedding_engine import SourceEmbedding, register -from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter CFG = Config() diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index 2c3992542..4a887164c 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -17,7 +17,7 @@ from pilot.configs.model_config import ( ) from pilot.scene.chat_knowledge.custom.prompt import prompt -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.embedding_engine import EmbeddingEngine CFG = Config() @@ -40,7 +40,7 @@ class ChatNewKnowledge(BaseChat): "text_field": "content", "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - self.knowledge_embedding_client = KnowledgeEmbedding( + self.knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config, ) diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py index 972b7f88c..3b0c2fa1e 100644 --- a/pilot/scene/chat_knowledge/default/chat.py +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -19,7 +19,7 @@ from pilot.configs.model_config import ( ) from pilot.scene.chat_knowledge.default.prompt import prompt -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.embedding_engine import EmbeddingEngine CFG = Config() @@ -40,7 +40,7 @@ class ChatDefaultKnowledge(BaseChat): "vector_store_name": "default", "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - self.knowledge_embedding_client = KnowledgeEmbedding( + self.knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config, ) diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index d27edf6bf..8903400a2 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -18,7 +18,7 @@ from pilot.configs.model_config import ( ) from pilot.scene.chat_knowledge.url.prompt import prompt -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.embedding_engine import EmbeddingEngine CFG = Config() @@ -40,7 +40,7 @@ class ChatUrlKnowledge(BaseChat): "vector_store_name": url.replace(":", ""), "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - self.knowledge_embedding_client = KnowledgeEmbedding( + self.knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config=vector_store_config, knowledge_type=KnowledgeType.URL.value, diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 646672f1f..0bb80d97b 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -19,7 +19,7 @@ from pilot.configs.model_config import ( ) from pilot.scene.chat_knowledge.v1.prompt import prompt -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.embedding_engine import EmbeddingEngine CFG = Config() @@ -40,7 +40,7 @@ class ChatKnowledge(BaseChat): "vector_store_name": knowledge_space, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } - self.knowledge_embedding_client = KnowledgeEmbedding( + self.knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config=vector_store_config, ) diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py index 4e5503f78..f7960f80c 100644 --- a/pilot/server/knowledge/api.py +++ b/pilot/server/knowledge/api.py @@ -10,7 +10,7 @@ from pilot.configs.config import Config from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.openapi.api_v1.api_view_model import Result -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.server.knowledge.service import KnowledgeService from pilot.server.knowledge.request.request import ( @@ -143,7 +143,7 @@ def document_list(space_name: str, query_request: ChunkQueryRequest): @router.post("/knowledge/{vector_name}/query") def similar_query(space_name: str, query_request: KnowledgeQueryRequest): print(f"Received params: {space_name}, {query_request}") - client = KnowledgeEmbedding( + client = EmbeddingEngine( model_name=embeddings, vector_store_config={"vector_store_name": space_name} ) docs = client.similar_search(query_request.query, query_request.top_k) diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index 10c20ff31..0c13ae1cf 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -3,7 +3,7 @@ from datetime import datetime from pilot.configs.config import Config from pilot.configs.model_config import LLM_MODEL_CONFIG -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.logs import logger from pilot.server.knowledge.chunk_db import ( DocumentChunkEntity, @@ -122,7 +122,7 @@ class KnowledgeService: raise Exception( f" doc:{doc.doc_name} status is {doc.status}, can not sync" ) - client = KnowledgeEmbedding( + client = EmbeddingEngine( knowledge_source=doc.content, knowledge_type=doc.doc_type.upper(), model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 7a88ed958..6161984bb 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -37,7 +37,7 @@ from pilot.conversation import ( from pilot.server.gradio_css import code_highlight_css from pilot.server.gradio_patch import Chatbot as grChatbot -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.utils import build_logger from pilot.vector_store.extract_tovec import ( get_vector_storelist, @@ -659,7 +659,7 @@ def knowledge_embedding_store(vs_id, files): shutil.move( file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename) ) - knowledge_embedding_client = KnowledgeEmbedding( + knowledge_embedding_client = EmbeddingEngine( knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), knowledge_type=KnowledgeType.DOCUMENT.value, model_name=LLM_MODEL_CONFIG["text2vec"], diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 2a15b55c5..b1346a097 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -7,7 +7,7 @@ from pilot.configs.config import Config from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.scene.base import ChatScene from pilot.scene.base_chat import BaseChat -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.string_embedding import StringEmbedding from pilot.summary.mysql_db_summary import MysqlSummary from pilot.scene.chat_factory import ChatFactory @@ -74,7 +74,7 @@ class DBSummaryClient: vector_store_config = { "vector_store_name": dbname + "_profile", } - knowledge_embedding_client = KnowledgeEmbedding( + knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config=vector_store_config, ) @@ -87,7 +87,7 @@ class DBSummaryClient: vector_store_config = { "vector_store_name": dbname + "_summary", } - knowledge_embedding_client = KnowledgeEmbedding( + knowledge_embedding_client = EmbeddingEngine( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config=vector_store_config, ) @@ -110,7 +110,7 @@ class DBSummaryClient: vector_store_config = { "vector_store_name": dbname + "_" + table + "_ts", } - knowledge_embedding_client = KnowledgeEmbedding( + knowledge_embedding_client = EmbeddingEngine( file_path="", model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config=vector_store_config, diff --git a/tools/knowledge_init.py b/tools/knowledge_init.py index a2b3b2d82..2f18fcf93 100644 --- a/tools/knowledge_init.py +++ b/tools/knowledge_init.py @@ -16,7 +16,7 @@ from pilot.configs.model_config import ( DATASETS_DIR, LLM_MODEL_CONFIG, ) -from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.embedding_engine import EmbeddingEngine knowledge_space_service = KnowledgeService() @@ -37,7 +37,7 @@ class LocalKnowledgeInit: for root, _, files in os.walk(file_path, topdown=False): for file in files: filename = os.path.join(root, file) - ke = KnowledgeEmbedding( + ke = EmbeddingEngine( knowledge_source=filename, knowledge_type=KnowledgeType.DOCUMENT.value, model_name=self.model_name,