From ba7248adbb98812329d85046358a2f460c63b839 Mon Sep 17 00:00:00 2001 From: xiuzhu9527 <1406823834@QQ.COM> Date: Sat, 20 Jan 2024 05:48:48 -0600 Subject: [PATCH] [hotfix] Fix examples/awel default loading model text2vec-large-chinese issue (#1095) Co-authored-by: xiuzhu --- examples/awel/simple_dbschema_retriever_example.py | 3 ++- examples/awel/simple_rag_embedding_example.py | 3 ++- examples/awel/simple_rag_retriever_example.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/awel/simple_dbschema_retriever_example.py b/examples/awel/simple_dbschema_retriever_example.py index 744d7b763..1edbd2100 100644 --- a/examples/awel/simple_dbschema_retriever_example.py +++ b/examples/awel/simple_dbschema_retriever_example.py @@ -40,6 +40,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector def _create_vector_connector(): """Create vector connector.""" + model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -47,7 +48,7 @@ def _create_vector_connector(): persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + default_model_name=os.path.join(MODEL_PATH, model_name), ).create(), ) diff --git a/examples/awel/simple_rag_embedding_example.py b/examples/awel/simple_rag_embedding_example.py index a2a6f961b..1f1763af1 100644 --- a/examples/awel/simple_rag_embedding_example.py +++ b/examples/awel/simple_rag_embedding_example.py @@ -28,6 +28,7 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector def _create_vector_connector() -> VectorStoreConnector: """Create vector connector.""" + model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -35,7 +36,7 @@ def _create_vector_connector() -> VectorStoreConnector: persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + default_model_name=os.path.join(MODEL_PATH, model_name), ).create(), ) diff --git a/examples/awel/simple_rag_retriever_example.py b/examples/awel/simple_rag_retriever_example.py index 1d48d5478..4fda20281 100644 --- a/examples/awel/simple_rag_retriever_example.py +++ b/examples/awel/simple_rag_retriever_example.py @@ -75,6 +75,7 @@ def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict: def _create_vector_connector(): """Create vector connector.""" + model_name = os.getenv("EMBEDDING_MODEL", "text2vec") return VectorStoreConnector.from_default( "Chroma", vector_store_config=ChromaVectorConfig( @@ -82,7 +83,7 @@ def _create_vector_connector(): persist_path=os.path.join(PILOT_PATH, "data"), ), embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + default_model_name=os.path.join(MODEL_PATH, model_name), ).create(), )