refactor: rename knowledge embedding API

1. Replace `KnowledgeEmbedding` (knowledge_embedding module) with `EmbeddingEngine` (embedding_engine module)
This commit is contained in:
aries_ckt
2023-07-10 15:50:34 +08:00
parent 259a9a6e12
commit 312f243382
15 changed files with 28 additions and 30 deletions

View File

@@ -11,9 +11,9 @@ vector_store_config = {
file_path = "your file path" file_path = "your file path"
knowledge_embedding_client = KnowledgeEmbedding(file_path=file_path, model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config) embedding_engine = EmbeddingEngine(file_path=file_path, model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config)
knowledge_embedding_client.knowledge_embedding() embedding_engine.knowledge_embedding()
``` ```
@@ -37,7 +37,7 @@ vector_store_config = {
query = "your query" query = "your query"
knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config) embedding_engine = EmbeddingEngine(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config=vector_store_config)
knowledge_embedding_client.similar_search(query, 10) embedding_engine.similar_search(query, 10)
``` ```

View File

@@ -1,3 +1,4 @@
from pilot.embedding_engine.source_embedding import SourceEmbedding, register from pilot.embedding_engine.source_embedding import SourceEmbedding, register
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
__all__ = ["SourceEmbedding", "register"] __all__ = ["SourceEmbedding", "register", "EmbeddingEngine"]

View File

@@ -10,7 +10,7 @@ from pilot.vector_store.connector import VectorStoreConnector
CFG = Config() CFG = Config()
class KnowledgeEmbedding: class EmbeddingEngine:
def __init__( def __init__(
self, self,
model_name, model_name,

View File

@@ -8,7 +8,6 @@ from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSpl
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter
CFG = Config() CFG = Config()

View File

@@ -2,7 +2,7 @@ from typing import List
from langchain.schema import Document from langchain.schema import Document
from pilot import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
class StringEmbedding(SourceEmbedding): class StringEmbedding(SourceEmbedding):

View File

@@ -6,9 +6,7 @@ from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
from pilot.embedding_engine import SourceEmbedding, register from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter
CFG = Config() CFG = Config()

View File

@@ -17,7 +17,7 @@ from pilot.configs.model_config import (
) )
from pilot.scene.chat_knowledge.custom.prompt import prompt from pilot.scene.chat_knowledge.custom.prompt import prompt
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.embedding_engine import EmbeddingEngine
CFG = Config() CFG = Config()
@@ -40,7 +40,7 @@ class ChatNewKnowledge(BaseChat):
"text_field": "content", "text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
self.knowledge_embedding_client = KnowledgeEmbedding( self.knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG["text2vec"], model_name=LLM_MODEL_CONFIG["text2vec"],
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
) )

View File

@@ -19,7 +19,7 @@ from pilot.configs.model_config import (
) )
from pilot.scene.chat_knowledge.default.prompt import prompt from pilot.scene.chat_knowledge.default.prompt import prompt
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.embedding_engine import EmbeddingEngine
CFG = Config() CFG = Config()
@@ -40,7 +40,7 @@ class ChatDefaultKnowledge(BaseChat):
"vector_store_name": "default", "vector_store_name": "default",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
self.knowledge_embedding_client = KnowledgeEmbedding( self.knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG["text2vec"], model_name=LLM_MODEL_CONFIG["text2vec"],
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
) )

View File

@@ -18,7 +18,7 @@ from pilot.configs.model_config import (
) )
from pilot.scene.chat_knowledge.url.prompt import prompt from pilot.scene.chat_knowledge.url.prompt import prompt
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.embedding_engine import EmbeddingEngine
CFG = Config() CFG = Config()
@@ -40,7 +40,7 @@ class ChatUrlKnowledge(BaseChat):
"vector_store_name": url.replace(":", ""), "vector_store_name": url.replace(":", ""),
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
self.knowledge_embedding_client = KnowledgeEmbedding( self.knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
knowledge_type=KnowledgeType.URL.value, knowledge_type=KnowledgeType.URL.value,

View File

@@ -19,7 +19,7 @@ from pilot.configs.model_config import (
) )
from pilot.scene.chat_knowledge.v1.prompt import prompt from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.embedding_engine import EmbeddingEngine
CFG = Config() CFG = Config()
@@ -40,7 +40,7 @@ class ChatKnowledge(BaseChat):
"vector_store_name": knowledge_space, "vector_store_name": knowledge_space,
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
} }
self.knowledge_embedding_client = KnowledgeEmbedding( self.knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
) )

View File

@@ -10,7 +10,7 @@ from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.openapi.api_v1.api_view_model import Result from pilot.openapi.api_v1.api_view_model import Result
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.server.knowledge.service import KnowledgeService from pilot.server.knowledge.service import KnowledgeService
from pilot.server.knowledge.request.request import ( from pilot.server.knowledge.request.request import (
@@ -143,7 +143,7 @@ def document_list(space_name: str, query_request: ChunkQueryRequest):
@router.post("/knowledge/{vector_name}/query") @router.post("/knowledge/{vector_name}/query")
def similar_query(space_name: str, query_request: KnowledgeQueryRequest): def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
print(f"Received params: {space_name}, {query_request}") print(f"Received params: {space_name}, {query_request}")
client = KnowledgeEmbedding( client = EmbeddingEngine(
model_name=embeddings, vector_store_config={"vector_store_name": space_name} model_name=embeddings, vector_store_config={"vector_store_name": space_name}
) )
docs = client.similar_search(query_request.query, query_request.top_k) docs = client.similar_search(query_request.query, query_request.top_k)

View File

@@ -3,7 +3,7 @@ from datetime import datetime
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.logs import logger from pilot.logs import logger
from pilot.server.knowledge.chunk_db import ( from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity, DocumentChunkEntity,
@@ -122,7 +122,7 @@ class KnowledgeService:
raise Exception( raise Exception(
f" doc:{doc.doc_name} status is {doc.status}, can not sync" f" doc:{doc.doc_name} status is {doc.status}, can not sync"
) )
client = KnowledgeEmbedding( client = EmbeddingEngine(
knowledge_source=doc.content, knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(), knowledge_type=doc.doc_type.upper(),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],

View File

@@ -37,7 +37,7 @@ from pilot.conversation import (
from pilot.server.gradio_css import code_highlight_css from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.vector_store.extract_tovec import ( from pilot.vector_store.extract_tovec import (
get_vector_storelist, get_vector_storelist,
@@ -659,7 +659,7 @@ def knowledge_embedding_store(vs_id, files):
shutil.move( shutil.move(
file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename) file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
) )
knowledge_embedding_client = KnowledgeEmbedding( knowledge_embedding_client = EmbeddingEngine(
knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
knowledge_type=KnowledgeType.DOCUMENT.value, knowledge_type=KnowledgeType.DOCUMENT.value,
model_name=LLM_MODEL_CONFIG["text2vec"], model_name=LLM_MODEL_CONFIG["text2vec"],

View File

@@ -7,7 +7,7 @@ from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.string_embedding import StringEmbedding from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.summary.mysql_db_summary import MysqlSummary from pilot.summary.mysql_db_summary import MysqlSummary
from pilot.scene.chat_factory import ChatFactory from pilot.scene.chat_factory import ChatFactory
@@ -74,7 +74,7 @@ class DBSummaryClient:
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_profile", "vector_store_name": dbname + "_profile",
} }
knowledge_embedding_client = KnowledgeEmbedding( knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
) )
@@ -87,7 +87,7 @@ class DBSummaryClient:
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_summary", "vector_store_name": dbname + "_summary",
} }
knowledge_embedding_client = KnowledgeEmbedding( knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config, vector_store_config=vector_store_config,
) )
@@ -110,7 +110,7 @@ class DBSummaryClient:
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_" + table + "_ts", "vector_store_name": dbname + "_" + table + "_ts",
} }
knowledge_embedding_client = KnowledgeEmbedding( knowledge_embedding_client = EmbeddingEngine(
file_path="", file_path="",
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config, vector_store_config=vector_store_config,

View File

@@ -16,7 +16,7 @@ from pilot.configs.model_config import (
DATASETS_DIR, DATASETS_DIR,
LLM_MODEL_CONFIG, LLM_MODEL_CONFIG,
) )
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.embedding_engine import EmbeddingEngine
knowledge_space_service = KnowledgeService() knowledge_space_service = KnowledgeService()
@@ -37,7 +37,7 @@ class LocalKnowledgeInit:
for root, _, files in os.walk(file_path, topdown=False): for root, _, files in os.walk(file_path, topdown=False):
for file in files: for file in files:
filename = os.path.join(root, file) filename = os.path.join(root, file)
ke = KnowledgeEmbedding( ke = EmbeddingEngine(
knowledge_source=filename, knowledge_source=filename,
knowledge_type=KnowledgeType.DOCUMENT.value, knowledge_type=KnowledgeType.DOCUMENT.value,
model_name=self.model_name, model_name=self.model_name,