Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-23 20:26:15 +00:00).
Commit: vector data visualization for Chat Data (#2172).
Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
This commit is contained in:
parent
433550b71f
commit
ed96b95efc
@ -513,7 +513,7 @@ class BaseChat(ABC):
|
|||||||
},
|
},
|
||||||
# {"response_data_text":" the default display method, suitable for single-line or simple content display"},
|
# {"response_data_text":" the default display method, suitable for single-line or simple content display"},
|
||||||
{
|
{
|
||||||
"response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc."
|
"response_scatter_chart": "Suitable for exploring relationships between variables, detecting outliers, etc."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
|
"response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
|
||||||
@ -527,6 +527,9 @@ class BaseChat(ABC):
|
|||||||
{
|
{
|
||||||
"response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
|
"response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"response_vector_chart": "Suitable for projecting high-dimensional vector data onto a two-dimensional plot through the PCA algorithm."
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
return "\n".join(
|
return "\n".join(
|
||||||
|
@ -3,6 +3,8 @@ import logging
|
|||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Dict, NamedTuple
|
from typing import Dict, NamedTuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import sqlparse
|
import sqlparse
|
||||||
|
|
||||||
from dbgpt._private.config import Config
|
from dbgpt._private.config import Config
|
||||||
@ -68,6 +70,52 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
logger.error(f"json load failed:{clean_str}")
|
logger.error(f"json load failed:{clean_str}")
|
||||||
return SqlAction("", clean_str, "", "")
|
return SqlAction("", clean_str, "", "")
|
||||||
|
|
||||||
|
def parse_vector_data_with_pca(self, df):
    """Project a vector column of ``df`` onto 2-D via PCA for scatter display.

    Scans the columns for the first one whose cells hold a vector — either a
    Python ``list`` or JSON-encoded ``bytes`` (e.g. ``b"[1.0, 2.0]"``). If one
    is found and PCA is applicable, returns a new DataFrame where the vector
    column is replaced by ``__x`` / ``__y`` coordinates from a 2-component PCA.

    :param df: input pandas DataFrame (not modified).
    :return: ``(new_df, True)`` on success, otherwise ``(df, False)`` when
        there is no usable vector column or fewer than 2 rows/dimensions.
    :raises ImportError: if scikit-learn is missing *and* a vector column was
        actually found (the import is deferred so non-vector results never
        require scikit-learn).
    """

    def _as_vector(value):
        # Normalize a cell to a list, or None when it is not a vector cell.
        if isinstance(value, list):
            return value
        if isinstance(value, bytes):
            try:
                decoded = json.loads(value.decode())
            except (ValueError, UnicodeDecodeError):
                # Arbitrary binary / non-JSON bytes: not a vector column.
                return None
            if isinstance(decoded, list):
                return decoded
        return None

    nrow, ncol = df.shape
    if nrow == 0 or ncol == 0:
        return df, False

    # Detect the first vector-like column using positional indexing so a
    # non-default index (e.g. after filtering) does not break the lookup.
    vec_col = -1
    for i_col in range(ncol):
        if _as_vector(df.iloc[0, i_col]) is not None:
            vec_col = i_col
            break
    if vec_col == -1:
        return df, False

    # Decode every cell; bail out if any row is not a vector of the same
    # dimension (a ragged column cannot be stacked for PCA).
    vectors = [_as_vector(cell) for cell in df.iloc[:, vec_col]]
    vec_dim = len(vectors[0]) if vectors[0] is not None else 0
    if any(v is None or len(v) != vec_dim for v in vectors):
        return df, False
    # PCA to 2 components needs at least 2 samples and 2 dimensions.
    if min(nrow, vec_dim) < 2:
        return df, False

    try:
        from sklearn.decomposition import PCA
    except ImportError:
        raise ImportError(
            "Could not import scikit-learn package. "
            "Please install it with `pip install scikit-learn`."
        )

    X = np.array(vectors)
    X_pca = PCA(n_components=2).fit_transform(X)

    # Rebuild the frame without the vector column, then append coordinates.
    new_df = pd.DataFrame()
    for i_col in range(ncol):
        if i_col == vec_col:
            continue
        col_name = df.columns[i_col]
        new_df[col_name] = df[col_name]
    new_df["__x"] = [pos[0] for pos in X_pca]
    new_df["__y"] = [pos[1] for pos in X_pca]
    return new_df, True
|
||||||
|
|
||||||
def parse_view_response(self, speak, data, prompt_response) -> str:
|
def parse_view_response(self, speak, data, prompt_response) -> str:
|
||||||
param = {}
|
param = {}
|
||||||
api_call_element = ET.Element("chart-view")
|
api_call_element = ET.Element("chart-view")
|
||||||
@ -83,6 +131,11 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
if prompt_response.sql:
|
if prompt_response.sql:
|
||||||
df = data(prompt_response.sql)
|
df = data(prompt_response.sql)
|
||||||
param["type"] = prompt_response.display
|
param["type"] = prompt_response.display
|
||||||
|
|
||||||
|
if param["type"] == "response_vector_chart":
|
||||||
|
df, visualizable = self.parse_vector_data_with_pca(df)
|
||||||
|
param["type"] = "response_scatter_chart" if visualizable else "response_table"
|
||||||
|
|
||||||
param["sql"] = prompt_response.sql
|
param["sql"] = prompt_response.sql
|
||||||
param["data"] = json.loads(
|
param["data"] = json.loads(
|
||||||
df.to_json(orient="records", date_format="iso", date_unit="s")
|
df.to_json(orient="records", date_format="iso", date_unit="s")
|
||||||
|
@ -1,12 +1,111 @@
|
|||||||
"""OB Dialect support."""
|
"""OB Dialect support."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from sqlalchemy import util
|
||||||
from sqlalchemy.dialects import registry
|
from sqlalchemy.dialects import registry
|
||||||
from sqlalchemy.dialects.mysql import pymysql
|
from sqlalchemy.dialects.mysql import pymysql
|
||||||
|
from sqlalchemy.dialects.mysql.reflection import MySQLTableDefinitionParser, _re_compile
|
||||||
|
|
||||||
|
|
||||||
|
class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
    """OceanBase table definition parser.

    Extends SQLAlchemy's MySQL ``SHOW CREATE TABLE`` reflection parser with
    patterns adjusted for OceanBase DDL output: VECTOR/SPATIAL key lines, an
    optional ``BLOCK_SIZE ... LOCAL`` suffix on index definitions, and
    schema-qualified foreign-key targets.
    """

    def __init__(self, dialect, preparer, *, default_schema=None):
        """Initialize OceanBaseTableDefinitionParser.

        :param dialect: owning SQLAlchemy dialect instance.
        :param preparer: identifier preparer supplying the quoting rules.
        :param default_schema: current schema name; foreign-key targets
            qualified with this schema are collapsed to bare table names
            in :meth:`_parse_constraints`.
        """
        MySQLTableDefinitionParser.__init__(self, dialect, preparer)
        self.default_schema = default_schema

    def _prep_regexes(self):
        # Build the stock MySQL regexes first, then override the KEY and
        # FK-constraint patterns below.
        super()._prep_regexes()

        # Rebuild the quote-escape table locally: the parent computes an
        # equivalent dict inside its own _prep_regexes but does not keep it.
        _final = self.preparer.final_quote
        quotes = dict(
            zip(
                ("iq", "fq", "esc_fq"),
                [
                    re.escape(s)
                    for s in (
                        self.preparer.initial_quote,
                        _final,
                        self.preparer._escape_identifier(_final),
                    )
                ],
            )
        )

        # Index/KEY lines. Differences from the stock MySQL pattern (kept
        # below as a comment): SPATIAL and VECTOR are accepted before KEY
        # without being captured as the index "type", and BLOCK_SIZE may be
        # spelled with or without the KEY_ prefix and followed by LOCAL —
        # presumably matching OceanBase's DDL output; confirm against a
        # real SHOW CREATE TABLE sample.
        self._re_key = _re_compile(
            r" "
            r"(?:(SPATIAL|VECTOR|(?P<type>\S+)) )?KEY"
            # r"(?:(?P<type>\S+) )?KEY"
            r"(?: +{iq}(?P<name>(?:{esc_fq}|[^{fq}])+){fq})?"
            r"(?: +USING +(?P<using_pre>\S+))?"
            r" +\((?P<columns>.+?)\)"
            r"(?: +USING +(?P<using_post>\S+))?"
            r"(?: +(KEY_)?BLOCK_SIZE *[ =]? *(?P<keyblock>\S+) *(LOCAL)?)?"
            r"(?: +WITH PARSER +(?P<parser>\S+))?"
            r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?"
            r"(?: +/\*(?P<version_sql>.+)\*/ *)?"
            r",?$".format(iq=quotes["iq"], esc_fq=quotes["esc_fq"], fq=quotes["fq"])
        )

        # FOREIGN KEY constraint lines. The <table> group optionally accepts
        # a schema-qualified target (`schema`.`table`), which
        # _parse_constraints later strips when it names the default schema.
        kw = quotes.copy()
        kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
        self._re_fk_constraint = _re_compile(
            r" "
            r"CONSTRAINT +"
            r"{iq}(?P<name>(?:{esc_fq}|[^{fq}])+){fq} +"
            r"FOREIGN KEY +"
            r"\((?P<local>[^\)]+?)\) REFERENCES +"
            r"(?P<table>{iq}[^{fq}]+{fq}"
            r"(?:\.{iq}[^{fq}]+{fq})?) *"
            r"\((?P<foreign>(?:{iq}[^{fq}]+{fq}(?: *, *)?)+)\)"
            r"(?: +(?P<match>MATCH \w+))?"
            r"(?: +ON UPDATE (?P<onupdate>{on}))?"
            r"(?: +ON DELETE (?P<ondelete>{on}))?".format(
                iq=quotes["iq"], esc_fq=quotes["esc_fq"], fq=quotes["fq"], on=kw["on"]
            )
        )

    def _parse_constraints(self, line):
        """Parse a CONSTRAINT line.

        Post-processes the parent's result:

        * partition specs are returned untouched,
        * a foreign-key target qualified with ``self.default_schema`` is
          reduced to the bare table name,
        * ``RESTRICT`` on-update/on-delete actions are normalized to ``None``
          (presumably because RESTRICT is the implicit default — confirm).
        """
        ret = super()._parse_constraints(line)
        if ret:
            tp, spec = ret
            if tp == "partition":
                # do not handle partition
                return ret
            # logger.info(f"{tp} {spec}")
            if (
                tp == "fk_constraint"
                and len(spec["table"]) == 2
                and spec["table"][0] == self.default_schema
            ):
                # Drop the leading schema component so reflection matches
                # tables addressed without a schema prefix.
                spec["table"] = spec["table"][1:]
            # NOTE(review): spec.get("onupdate", "") yields None (not "")
            # when the key exists with a None value, and .lower() would then
            # raise — this assumes upstream parsing never leaves None here;
            # TODO confirm.
            if spec.get("onupdate", "").lower() == "restrict":
                spec["onupdate"] = None
            if spec.get("ondelete", "").lower() == "restrict":
                spec["ondelete"] = None
        return ret
|
||||||
|
|
||||||
|
|
||||||
class OBDialect(pymysql.MySQLDialect_pymysql):
|
class OBDialect(pymysql.MySQLDialect_pymysql):
|
||||||
"""OBDialect expend."""
|
"""OBDialect expend."""
|
||||||
|
|
||||||
|
supports_statement_cache = True
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
    """Set up the OceanBase dialect and register the VECTOR column type."""
    try:
        from pyobvector import VECTOR  # type: ignore
    except ImportError:
        msg = (
            "Could not import pyobvector package. "
            "Please install it with `pip install pyobvector`."
        )
        raise ImportError(msg)
    super().__init__(**kwargs)
    # Teach the MySQL reflection machinery about OceanBase's VECTOR type.
    self.ischema_names["VECTOR"] = VECTOR
|
||||||
|
|
||||||
def initialize(self, connection):
|
def initialize(self, connection):
|
||||||
"""Ob dialect initialize."""
|
"""Ob dialect initialize."""
|
||||||
super(OBDialect, self).initialize(connection)
|
super(OBDialect, self).initialize(connection)
|
||||||
@ -22,5 +121,18 @@ class OBDialect(pymysql.MySQLDialect_pymysql):
|
|||||||
self.server_version_info = (5, 7, 19)
|
self.server_version_info = (5, 7, 19)
|
||||||
return super(OBDialect, self).get_isolation_level(dbapi_connection)
|
return super(OBDialect, self).get_isolation_level(dbapi_connection)
|
||||||
|
|
||||||
|
@util.memoized_property
def _tabledef_parser(self):
    """Build (once) the OceanBase-aware table definition parser.

    Creation is deferred so that the dialect has already retrieved the
    server version information before the parser is constructed.
    """
    return OceanBaseTableDefinitionParser(
        self,
        self.identifier_preparer,
        default_schema=self.default_schema_name,
    )
||||||
|
|
||||||
|
|
||||||
registry.register("mysql.ob", __name__, "OBDialect")
|
registry.register("mysql.ob", __name__, "OBDialect")
|
||||||
|
@ -24,7 +24,7 @@ def default_chart_type_prompt() -> str:
|
|||||||
"non-numeric columns"
|
"non-numeric columns"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"response_scatter_plot": "Suitable for exploring relationships between "
|
"response_scatter_chart": "Suitable for exploring relationships between "
|
||||||
"variables, detecting outliers, etc."
|
"variables, detecting outliers, etc."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user