fix: Fix examples/awel get model_name from model_config (#1112)

Co-authored-by: xiuzhu <edy@dodge-pro.local>
This commit is contained in:
xiuzhu9527
2024-01-23 22:24:34 -06:00
committed by GitHub
parent f8c0064576
commit 8f18478fa5
3 changed files with 16 additions and 8 deletions

View File

@@ -3,7 +3,8 @@ from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.chunk import Chunk
@@ -38,9 +39,11 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""
CFG = Config()
def _create_vector_connector():
"""Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
@@ -48,7 +51,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(),
)

View File

@@ -3,7 +3,8 @@ from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.knowledge.base import KnowledgeType
@@ -25,10 +26,11 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
}'
"""
CFG = Config()
def _create_vector_connector() -> VectorStoreConnector:
"""Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
@@ -36,7 +38,7 @@ def _create_vector_connector() -> VectorStoreConnector:
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(),
)

View File

@@ -4,7 +4,8 @@ from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.chunk import Chunk
@@ -43,6 +44,8 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
}'
"""
CFG = Config()
class TriggerReqBody(BaseModel):
query: str = Field(..., description="User query")
@@ -83,7 +86,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(),
)