From beafca5c6cba9f2e82b5192e1a66c05dbe9a2fe2 Mon Sep 17 00:00:00 2001 From: yihong Date: Mon, 31 Mar 2025 19:46:44 +0800 Subject: [PATCH] fix: make tugraph work again (#2551) Signed-off-by: yihong0618 --- .../src/dbgpt_ext/datasource/conn_tugraph.py | 8 +++++-- .../dbgpt_ext/rag/summary/gdbms_db_summary.py | 22 ++++++++++++++----- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/packages/dbgpt-ext/src/dbgpt_ext/datasource/conn_tugraph.py b/packages/dbgpt-ext/src/dbgpt_ext/datasource/conn_tugraph.py index 210a72f2a..3f9183cd8 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/datasource/conn_tugraph.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/datasource/conn_tugraph.py @@ -158,11 +158,15 @@ class TuGraphConnector(BaseConnector): with self._driver.session(database=self._graph) as session: # Run the query to get vertex labels raw_vertex_labels = session.run("CALL db.vertexLabels()").data() - vertex_labels = [table_name["label"] for table_name in raw_vertex_labels] + vertex_labels = [ + table_name["label"] + "_vertex" for table_name in raw_vertex_labels + ] # Run the query to get edge labels raw_edge_labels = session.run("CALL db.edgeLabels()").data() - edge_labels = [table_name["label"] for table_name in raw_edge_labels] + edge_labels = [ + table_name["label"] + "_edge" for table_name in raw_edge_labels + ] return iter(vertex_labels + edge_labels) diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/summary/gdbms_db_summary.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/summary/gdbms_db_summary.py index aaac9b651..3eef4ec9e 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/rag/summary/gdbms_db_summary.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/summary/gdbms_db_summary.py @@ -44,12 +44,14 @@ class GdbmsSummary(DBSummary): tables = self.db.get_table_names() self.table_info_summaries = { "vertex_tables": [ - self.get_table_summary(table_name, "vertex") - for table_name in tables["vertex_tables"] + self.get_table_summary(table_name.split("_")[0], "vertex") + for table_name in tables + if table_name.endswith("_vertex") ], "edge_tables": [ - self.get_table_summary(table_name, "edge") - for table_name in tables["edge_tables"] + self.get_table_summary(table_name.split("_")[0], "edge") + for table_name in tables + if table_name.endswith("_edge") ], } @@ -76,8 +78,16 @@ def _parse_db_summary( table_info_summaries = None if isinstance(conn, TuGraphConnector): table_names = conn.get_table_names() - v_tables = table_names.get("vertex_tables", []) # type: ignore - e_tables = table_names.get("edge_tables", []) # type: ignore + v_tables = [ + table_name.split("_")[0] + for table_name in table_names + if table_name.endswith("_vertex") + ] + e_tables = [ + table_name.split("_")[0] + for table_name in table_names + if table_name.endswith("_edge") + ] table_info_summaries = [ _parse_table_summary(conn, summary_template, table_name, "vertex") for table_name in v_tables