[hotfix] Fix examples/awel default loading model text2vec-large-chinese issue (#1095)

Co-authored-by: xiuzhu <edy@dodge-pro.local>
This commit is contained in:
xiuzhu9527
2024-01-20 05:48:48 -06:00
committed by GitHub
parent 425f4ab48b
commit ba7248adbb
3 changed files with 6 additions and 3 deletions

View File

@@ -40,6 +40,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
def _create_vector_connector():
"""Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
@@ -47,7 +48,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
default_model_name=os.path.join(MODEL_PATH, model_name),
).create(),
)

View File

@@ -28,6 +28,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
def _create_vector_connector() -> VectorStoreConnector:
"""Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
@@ -35,7 +36,7 @@ def _create_vector_connector() -> VectorStoreConnector:
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
default_model_name=os.path.join(MODEL_PATH, model_name),
).create(),
)

View File

@@ -75,6 +75,7 @@ def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
def _create_vector_connector():
"""Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
@@ -82,7 +83,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
default_model_name=os.path.join(MODEL_PATH, model_name),
).create(),
)