fix: Fix examples/awel get model_name from model_config (#1112)

Co-authored-by: xiuzhu <edy@dodge-pro.local>
This commit is contained in:
xiuzhu9527 2024-01-23 22:24:34 -06:00 committed by GitHub
parent f8c0064576
commit 8f18478fa5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 8 deletions

View File

@ -3,7 +3,8 @@ from typing import Dict, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.rag.chunk import Chunk from dbgpt.rag.chunk import Chunk
@ -38,9 +39,11 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
""" """
CFG = Config()
def _create_vector_connector(): def _create_vector_connector():
"""Create vector connector.""" """Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default( return VectorStoreConnector.from_default(
"Chroma", "Chroma",
vector_store_config=ChromaVectorConfig( vector_store_config=ChromaVectorConfig(
@ -48,7 +51,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"), persist_path=os.path.join(PILOT_PATH, "data"),
), ),
embedding_fn=DefaultEmbeddingFactory( embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name), default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(), ).create(),
) )

View File

@ -3,7 +3,8 @@ from typing import Dict, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.knowledge.base import KnowledgeType from dbgpt.rag.knowledge.base import KnowledgeType
@ -25,10 +26,11 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
}' }'
""" """
CFG = Config()
def _create_vector_connector() -> VectorStoreConnector: def _create_vector_connector() -> VectorStoreConnector:
"""Create vector connector.""" """Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default( return VectorStoreConnector.from_default(
"Chroma", "Chroma",
vector_store_config=ChromaVectorConfig( vector_store_config=ChromaVectorConfig(
@ -36,7 +38,7 @@ def _create_vector_connector() -> VectorStoreConnector:
persist_path=os.path.join(PILOT_PATH, "data"), persist_path=os.path.join(PILOT_PATH, "data"),
), ),
embedding_fn=DefaultEmbeddingFactory( embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name), default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(), ).create(),
) )

View File

@ -4,7 +4,8 @@ from typing import Dict, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH from dbgpt._private.config import Config
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.model.proxy import OpenAILLMClient from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.chunk import Chunk from dbgpt.rag.chunk import Chunk
@ -43,6 +44,8 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
}' }'
""" """
CFG = Config()
class TriggerReqBody(BaseModel): class TriggerReqBody(BaseModel):
query: str = Field(..., description="User query") query: str = Field(..., description="User query")
@ -83,7 +86,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"), persist_path=os.path.join(PILOT_PATH, "data"),
), ),
embedding_fn=DefaultEmbeddingFactory( embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, model_name), default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
).create(), ).create(),
) )