Vector data visualization for chat data (#2172)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
This commit is contained in:
GITHUBear 2024-12-17 00:51:19 -08:00 committed by GitHub
parent 433550b71f
commit ed96b95efc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 170 additions and 2 deletions

View File

@ -513,7 +513,7 @@ class BaseChat(ABC):
}, },
# {"response_data_text":" the default display method, suitable for single-line or simple content display"}, # {"response_data_text":" the default display method, suitable for single-line or simple content display"},
{ {
"response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc." "response_scatter_chart": "Suitable for exploring relationships between variables, detecting outliers, etc."
}, },
{ {
"response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc." "response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
@ -527,6 +527,9 @@ class BaseChat(ABC):
{ {
"response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc." "response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
}, },
{
"response_vector_chart": "Suitable for projecting high-dimensional vector data onto a two-dimensional plot through the PCA algorithm."
},
] ]
return "\n".join( return "\n".join(

View File

@ -3,6 +3,8 @@ import logging
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Dict, NamedTuple from typing import Dict, NamedTuple
import numpy as np
import pandas as pd
import sqlparse import sqlparse
from dbgpt._private.config import Config from dbgpt._private.config import Config
@ -68,6 +70,52 @@ class DbChatOutputParser(BaseOutputParser):
logger.error(f"json load failed:{clean_str}") logger.error(f"json load failed:{clean_str}")
return SqlAction("", clean_str, "", "") return SqlAction("", clean_str, "", "")
def parse_vector_data_with_pca(self, df):
    """Project a vector column of ``df`` onto a 2D plane via PCA.

    Scans the columns (first row only) for the first one holding vector
    data — either a Python ``list`` per cell, or JSON-encoded bytes such
    as ``b"[1.0, 2.0]"`` — and, when found, replaces it with two scalar
    columns ``__x`` / ``__y`` computed by a 2-component PCA so the result
    can be rendered as a scatter chart.

    Args:
        df: pandas DataFrame, typically a SQL query result.

    Returns:
        Tuple ``(new_df, visualizable)``: the projected frame and ``True``
        when a vector column was found and reduced, otherwise the original
        frame unchanged and ``False``.

    Raises:
        ImportError: if scikit-learn is missing *and* a vector column was
            actually found (the import is lazy on purpose).
    """

    def _as_vector(cell):
        # Normalize a cell to a list of numbers; None if it is not a vector.
        if isinstance(cell, list):
            return cell
        if isinstance(cell, (bytes, bytearray)):
            try:
                decoded = json.loads(cell.decode())
            except (UnicodeDecodeError, ValueError):
                # Non-UTF8 or non-JSON bytes: not a vector column.
                return None
            return decoded if isinstance(decoded, list) else None
        return None

    nrow, ncol = df.shape
    if nrow == 0 or ncol == 0:
        return df, False

    # Find the first column whose first cell parses as a vector.
    # Use purely positional indexing so non-RangeIndex frames work too.
    vec_col = -1
    for i_col in range(ncol):
        if _as_vector(df.iloc[0, i_col]) is not None:
            vec_col = i_col
            break
    if vec_col == -1:
        return df, False

    # PCA down to 2 components needs at least 2 samples and 2 dimensions.
    vec_dim = len(_as_vector(df.iloc[0, vec_col]))
    if min(nrow, vec_dim) < 2:
        return df, False

    # Import lazily so results without vector columns never require sklearn.
    try:
        from sklearn.decomposition import PCA
    except ImportError:
        raise ImportError(
            "Could not import scikit-learn package. "
            "Please install it with `pip install scikit-learn`."
        )

    # Decode every cell (handles both list and JSON-bytes columns) without
    # mutating the caller's frame.
    X = np.array([_as_vector(cell) for cell in df.iloc[:, vec_col]])
    X_pca = PCA(n_components=2).fit_transform(X)

    # Keep every non-vector column (original order), append the projection.
    keep = [i for i in range(ncol) if i != vec_col]
    new_df = df.iloc[:, keep].copy()
    new_df["__x"] = X_pca[:, 0]
    new_df["__y"] = X_pca[:, 1]
    return new_df, True
def parse_view_response(self, speak, data, prompt_response) -> str: def parse_view_response(self, speak, data, prompt_response) -> str:
param = {} param = {}
api_call_element = ET.Element("chart-view") api_call_element = ET.Element("chart-view")
@ -83,6 +131,11 @@ class DbChatOutputParser(BaseOutputParser):
if prompt_response.sql: if prompt_response.sql:
df = data(prompt_response.sql) df = data(prompt_response.sql)
param["type"] = prompt_response.display param["type"] = prompt_response.display
if param["type"] == "response_vector_chart":
df, visualizable = self.parse_vector_data_with_pca(df)
param["type"] = "response_scatter_chart" if visualizable else "response_table"
param["sql"] = prompt_response.sql param["sql"] = prompt_response.sql
param["data"] = json.loads( param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s") df.to_json(orient="records", date_format="iso", date_unit="s")

View File

@ -1,12 +1,111 @@
"""OB Dialect support.""" """OB Dialect support."""
import re
from sqlalchemy import util
from sqlalchemy.dialects import registry from sqlalchemy.dialects import registry
from sqlalchemy.dialects.mysql import pymysql from sqlalchemy.dialects.mysql import pymysql
from sqlalchemy.dialects.mysql.reflection import MySQLTableDefinitionParser, _re_compile
class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
    """OceanBase table definition parser.

    Extends SQLAlchemy's stock MySQL ``SHOW CREATE TABLE`` parser so that
    OceanBase-specific DDL output — ``VECTOR`` keys, ``(KEY_)BLOCK_SIZE``
    options, schema-qualified foreign-key targets — reflects correctly.
    """

    def __init__(self, dialect, preparer, *, default_schema=None):
        """Initialize OceanBaseTableDefinitionParser.

        Args:
            dialect: owning SQLAlchemy dialect instance.
            preparer: identifier preparer supplying quoting rules.
            default_schema: schema name to strip from FK targets so that
                same-schema references reflect as unqualified names.
        """
        MySQLTableDefinitionParser.__init__(self, dialect, preparer)
        self.default_schema = default_schema

    def _prep_regexes(self):
        """Build parsing regexes, overriding KEY/FK patterns for OceanBase."""
        super()._prep_regexes()
        _final = self.preparer.final_quote
        quotes = dict(
            zip(
                ("iq", "fq", "esc_fq"),
                [
                    re.escape(s)
                    for s in (
                        self.preparer.initial_quote,
                        _final,
                        self.preparer._escape_identifier(_final),
                    )
                ],
            )
        )
        # Differs from upstream MySQL: also matches SPATIAL/VECTOR keys and
        # an optional (KEY_)BLOCK_SIZE option with a trailing LOCAL marker.
        self._re_key = _re_compile(
            r" "
            r"(?:(SPATIAL|VECTOR|(?P<type>\S+)) )?KEY"
            # r"(?:(?P<type>\S+) )?KEY"
            r"(?: +{iq}(?P<name>(?:{esc_fq}|[^{fq}])+){fq})?"
            r"(?: +USING +(?P<using_pre>\S+))?"
            r" +\((?P<columns>.+?)\)"
            r"(?: +USING +(?P<using_post>\S+))?"
            r"(?: +(KEY_)?BLOCK_SIZE *[ =]? *(?P<keyblock>\S+) *(LOCAL)?)?"
            r"(?: +WITH PARSER +(?P<parser>\S+))?"
            r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?"
            r"(?: +/\*(?P<version_sql>.+)\*/ *)?"
            r",?$".format(iq=quotes["iq"], esc_fq=quotes["esc_fq"], fq=quotes["fq"])
        )
        kw = quotes.copy()
        kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
        self._re_fk_constraint = _re_compile(
            r" "
            r"CONSTRAINT +"
            r"{iq}(?P<name>(?:{esc_fq}|[^{fq}])+){fq} +"
            r"FOREIGN KEY +"
            r"\((?P<local>[^\)]+?)\) REFERENCES +"
            r"(?P<table>{iq}[^{fq}]+{fq}"
            r"(?:\.{iq}[^{fq}]+{fq})?) *"
            r"\((?P<foreign>(?:{iq}[^{fq}]+{fq}(?: *, *)?)+)\)"
            r"(?: +(?P<match>MATCH \w+))?"
            r"(?: +ON UPDATE (?P<onupdate>{on}))?"
            r"(?: +ON DELETE (?P<ondelete>{on}))?".format(
                iq=quotes["iq"], esc_fq=quotes["esc_fq"], fq=quotes["fq"], on=kw["on"]
            )
        )

    def _parse_constraints(self, line):
        """Parse a CONSTRAINT line, normalizing OceanBase FK quirks."""
        ret = super()._parse_constraints(line)
        if ret:
            tp, spec = ret
            if tp == "partition":
                # do not handle partition
                return ret
            if (
                tp == "fk_constraint"
                and len(spec["table"]) == 2
                and spec["table"][0] == self.default_schema
            ):
                # Drop the default schema so same-schema FK targets reflect
                # as unqualified table names.
                spec["table"] = spec["table"][1:]
            # RESTRICT is MySQL's default action; normalize it away.  Guard
            # with ``or ""`` — unmatched regex groups appear in the spec as
            # None, and ``spec.get(key, "")`` still returns None for a key
            # that is *present* with a None value, crashing ``.lower()``.
            if (spec.get("onupdate") or "").lower() == "restrict":
                spec["onupdate"] = None
            if (spec.get("ondelete") or "").lower() == "restrict":
                spec["ondelete"] = None
        return ret
class OBDialect(pymysql.MySQLDialect_pymysql): class OBDialect(pymysql.MySQLDialect_pymysql):
"""OBDialect expend.""" """OBDialect expend."""
supports_statement_cache = True
def __init__(self, **kwargs):
    """Initialize OBDialect.

    Registers the OceanBase ``VECTOR`` column type with *this* dialect
    instance so column reflection can map vector columns.

    Raises:
        ImportError: if the required ``pyobvector`` package is missing.
    """
    try:
        from pyobvector import VECTOR  # type: ignore
    except ImportError:
        raise ImportError(
            "Could not import pyobvector package. "
            "Please install it with `pip install pyobvector`."
        )
    super().__init__(**kwargs)
    # Rebind a per-instance copy instead of mutating in place:
    # ``ischema_names`` is a class-level dict shared by every MySQL
    # dialect, and updating it directly would leak the VECTOR mapping
    # into unrelated (plain MySQL) dialect instances.
    self.ischema_names = dict(self.ischema_names, VECTOR=VECTOR)
def initialize(self, connection): def initialize(self, connection):
"""Ob dialect initialize.""" """Ob dialect initialize."""
super(OBDialect, self).initialize(connection) super(OBDialect, self).initialize(connection)
@ -22,5 +121,18 @@ class OBDialect(pymysql.MySQLDialect_pymysql):
self.server_version_info = (5, 7, 19) self.server_version_info = (5, 7, 19)
return super(OBDialect, self).get_isolation_level(dbapi_connection) return super(OBDialect, self).get_isolation_level(dbapi_connection)
@util.memoized_property
def _tabledef_parser(self):
    """Return the MySQLTableDefinitionParser, generate if needed.

    Construction is deferred (and memoized) so that the dialect has
    already retrieved server version information by the time the
    parser is first needed.
    """
    return OceanBaseTableDefinitionParser(
        self,
        self.identifier_preparer,
        default_schema=self.default_schema_name,
    )
registry.register("mysql.ob", __name__, "OBDialect") registry.register("mysql.ob", __name__, "OBDialect")

View File

@ -24,7 +24,7 @@ def default_chart_type_prompt() -> str:
"non-numeric columns" "non-numeric columns"
}, },
{ {
"response_scatter_plot": "Suitable for exploring relationships between " "response_scatter_chart": "Suitable for exploring relationships between "
"variables, detecting outliers, etc." "variables, detecting outliers, etc."
}, },
{ {