fix: Fix examples/awel get model_name from model_config (#1112)

Co-authored-by: xiuzhu <edy@dodge-pro.local>
This commit is contained in:
xiuzhu9527 2024-01-23 22:24:34 -06:00 committed by GitHub
parent f8c0064576
commit 8f18478fa5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 8 deletions

View File

@ -3,7 +3,8 @@ from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.chunk import Chunk
@ -38,9 +39,11 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""
CFG = Config()
def _create_vector_connector():
"""Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
@ -48,7 +51,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(),
)

View File

@ -3,7 +3,8 @@ from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.knowledge.base import KnowledgeType
@ -25,10 +26,11 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
}'
"""
CFG = Config()
def _create_vector_connector() -> VectorStoreConnector:
"""Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
@ -36,7 +38,7 @@ def _create_vector_connector() -> VectorStoreConnector:
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(),
)

View File

@ -4,7 +4,8 @@ from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.chunk import Chunk
@ -43,6 +44,8 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
}'
"""
CFG = Config()
class TriggerReqBody(BaseModel):
query: str = Field(..., description="User query")
@ -83,7 +86,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name),
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(),
)