diff --git a/dbgpt/rag/evaluation/retriever.py b/dbgpt/rag/evaluation/retriever.py index 1626f923f..c1d26fc41 100644 --- a/dbgpt/rag/evaluation/retriever.py +++ b/dbgpt/rag/evaluation/retriever.py @@ -79,9 +79,8 @@ class RetrieverMRRMetric(RetrieverEvaluationMetric): def sync_compute( self, - prediction: Optional[List[str]] = None, - contexts: Optional[List[str]] = None, - **kwargs: Any, + prediction: List[str], + contexts: Optional[Sequence[str]] = None, ) -> BaseEvaluationResult: """Compute MRR metric. @@ -118,9 +117,8 @@ class RetrieverHitRateMetric(RetrieverEvaluationMetric): def sync_compute( self, - prediction: Optional[List[str]] = None, - contexts: Optional[List[str]] = None, - **kwargs: Any, + prediction: List[str], + contexts: Optional[Sequence[str]] = None, ) -> BaseEvaluationResult: """Compute HitRate metric. diff --git a/dbgpt/storage/vector_store/oceanbase_store.py b/dbgpt/storage/vector_store/oceanbase_store.py index fa58d7fdf..513a0bfc7 100644 --- a/dbgpt/storage/vector_store/oceanbase_store.py +++ b/dbgpt/storage/vector_store/oceanbase_store.py @@ -239,6 +239,7 @@ class OceanBase: def __init__( self, + database: str, connection_string: str, embedding_function: Embeddings, embedding_dimension: int = _OCEANBASE_DEFAULT_EMBEDDING_DIM, @@ -254,6 +255,7 @@ class OceanBase: enable_normalize_vector: bool = False, ) -> None: """Create OceanBase Vector Store instance.""" + self.database = database self.connection_string = connection_string self.embedding_function = embedding_function self.embedding_dimension = embedding_dimension @@ -339,6 +341,24 @@ class OceanBase: ) conn.execute(text(create_index_query)) + def check_table_exists(self) -> bool: + """Whether table `collection_name` exists.""" + check_table_query = f""" + SELECT COUNT(*) as cnt + FROM information_schema.tables + WHERE table_schema='{self.database}' AND table_name='{self.collection_name}' + """ + try: + with self.engine.connect() as conn, conn.begin(), ob_grwlock.reader_lock(): + table_exists_res = conn.execute(text(check_table_query)) + for row in table_exists_res: + return row.cnt > 0 + # No `cnt` rows? Just return False to pass `make mypy` + return False + except Exception as e: + logger.error(f"check_table_exists error: {e}") + return False + def add_texts( self, texts: Iterable[str], @@ -729,6 +749,7 @@ class OceanBaseStore(VectorStoreBase): self.collection_stat = ob_collection_stats[self.collection_name] self.vector_store_client = OceanBase( + database=self.OB_DATABASE, connection_string=self.connection_string, embedding_function=self.embeddings, collection_name=self.collection_name, @@ -769,12 +790,7 @@ class OceanBaseStore(VectorStoreBase): def vector_name_exists(self): """Whether vector name exists.""" self.logger.info("OceanBase: vector_name_exists..") - try: - self.vector_store_client.create_collection() - return True - except Exception as e: - logger.error("vector_name_exists error", e.message) - return False + return self.vector_store_client.check_table_exists() def load_document(self, chunks: List[Chunk]) -> List[str]: """Load document in vector database."""