[hotfix] Fix examples/awel default loading model text2vec-large-chinese issue (#1095)

Co-authored-by: xiuzhu <edy@dodge-pro.local>
This commit is contained in:
xiuzhu9527
2024-01-20 05:48:48 -06:00
committed by GitHub
parent 425f4ab48b
commit ba7248adbb
3 changed files with 6 additions and 3 deletions

View File

@@ -40,6 +40,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
def _create_vector_connector(): def _create_vector_connector():
"""Create vector connector.""" """Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default( return VectorStoreConnector.from_default(
"Chroma", "Chroma",
vector_store_config=ChromaVectorConfig( vector_store_config=ChromaVectorConfig(
@@ -47,7 +48,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"), persist_path=os.path.join(PILOT_PATH, "data"),
), ),
embedding_fn=DefaultEmbeddingFactory( embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), default_model_name=os.path.join(MODEL_PATH, model_name),
).create(), ).create(),
) )

View File

@@ -28,6 +28,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
def _create_vector_connector() -> VectorStoreConnector: def _create_vector_connector() -> VectorStoreConnector:
"""Create vector connector.""" """Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default( return VectorStoreConnector.from_default(
"Chroma", "Chroma",
vector_store_config=ChromaVectorConfig( vector_store_config=ChromaVectorConfig(
@@ -35,7 +36,7 @@ def _create_vector_connector() -> VectorStoreConnector:
persist_path=os.path.join(PILOT_PATH, "data"), persist_path=os.path.join(PILOT_PATH, "data"),
), ),
embedding_fn=DefaultEmbeddingFactory( embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), default_model_name=os.path.join(MODEL_PATH, model_name),
).create(), ).create(),
) )

View File

@@ -75,6 +75,7 @@ def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
def _create_vector_connector(): def _create_vector_connector():
"""Create vector connector.""" """Create vector connector."""
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
return VectorStoreConnector.from_default( return VectorStoreConnector.from_default(
"Chroma", "Chroma",
vector_store_config=ChromaVectorConfig( vector_store_config=ChromaVectorConfig(
@@ -82,7 +83,7 @@ def _create_vector_connector():
persist_path=os.path.join(PILOT_PATH, "data"), persist_path=os.path.join(PILOT_PATH, "data"),
), ),
embedding_fn=DefaultEmbeddingFactory( embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), default_model_name=os.path.join(MODEL_PATH, model_name),
).create(), ).create(),
) )