diff --git a/examples/awel/simple_dbschema_retriever_example.py b/examples/awel/simple_dbschema_retriever_example.py index e9119fdb4..72c2dfcd3 100644 --- a/examples/awel/simple_dbschema_retriever_example.py +++ b/examples/awel/simple_dbschema_retriever_example.py @@ -3,7 +3,8 @@ from typing import Dict, List from pydantic import BaseModel, Field -from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt._private.config import Config +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect from dbgpt.rag.chunk import Chunk @@ -38,9 +39,11 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector """ +CFG = Config() + + def _create_vector_connector(): """Create vector connector.""" - model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -48,7 +51,7 @@ def _create_vector_connector(): persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, model_name), + default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL], ).create(), ) diff --git a/examples/awel/simple_rag_embedding_example.py b/examples/awel/simple_rag_embedding_example.py index a49475acc..96d47ccc8 100644 --- a/examples/awel/simple_rag_embedding_example.py +++ b/examples/awel/simple_rag_embedding_example.py @@ -3,7 +3,8 @@ from typing import Dict, List from pydantic import BaseModel, Field -from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt._private.config import Config +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH from dbgpt.core.awel import DAG, HttpTrigger, MapOperator from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.knowledge.base import KnowledgeType @@ -25,10 +26,11 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector }' """ +CFG = Config() + def _create_vector_connector() -> VectorStoreConnector: """Create vector connector.""" - model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -36,7 +38,7 @@ def _create_vector_connector() -> VectorStoreConnector: persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, model_name), + default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL], ).create(), ) diff --git a/examples/awel/simple_rag_retriever_example.py b/examples/awel/simple_rag_retriever_example.py index 5470fca36..e04f4ed0c 100644 --- a/examples/awel/simple_rag_retriever_example.py +++ b/examples/awel/simple_rag_retriever_example.py @@ -4,7 +4,8 @@ from typing import Dict, List from pydantic import BaseModel, Field -from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt._private.config import Config +from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.model.proxy import OpenAILLMClient from dbgpt.rag.chunk import Chunk @@ -43,6 +44,8 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector }' """ +CFG = Config() + class TriggerReqBody(BaseModel): query: str = Field(..., description="User query") @@ -83,7 +86,7 @@ def _create_vector_connector(): persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, model_name), + default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL], ).create(), )