Vector Data visualize for Chat Data (#2172)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
This commit is contained in:
GITHUBear
2024-12-17 00:51:19 -08:00
committed by GitHub
parent 433550b71f
commit ed96b95efc
4 changed files with 170 additions and 2 deletions

View File

@@ -513,7 +513,7 @@ class BaseChat(ABC):
},
# {"response_data_text":" the default display method, suitable for single-line or simple content display"},
{
"response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc."
"response_scatter_chart": "Suitable for exploring relationships between variables, detecting outliers, etc."
},
{
"response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
@@ -527,6 +527,9 @@ class BaseChat(ABC):
{
"response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
},
{
"response_vector_chart": "Suitable for projecting high-dimensional vector data onto a two-dimensional plot through the PCA algorithm."
},
]
return "\n".join(

View File

@@ -3,6 +3,8 @@ import logging
import xml.etree.ElementTree as ET
from typing import Dict, NamedTuple
import numpy as np
import pandas as pd
import sqlparse
from dbgpt._private.config import Config
@@ -68,6 +70,52 @@ class DbChatOutputParser(BaseOutputParser):
logger.error(f"json load failed:{clean_str}")
return SqlAction("", clean_str, "", "")
def parse_vector_data_with_pca(self, df):
try:
from sklearn.decomposition import PCA
except ImportError:
raise ImportError(
"Could not import scikit-learn package. "
"Please install it with `pip install scikit-learn`."
)
nrow, ncol = df.shape
if nrow == 0 or ncol == 0:
return df, False
vec_col = -1
for i_col in range(ncol):
if isinstance(df.iloc[:, i_col][0], list):
vec_col = i_col
break
elif isinstance(df.iloc[:, i_col][0], bytes):
sample = df.iloc[:, i_col][0]
if isinstance(json.loads(sample.decode()), list):
vec_col = i_col
break
if vec_col == -1:
return df, False
vec_dim = len(json.loads(df.iloc[:, vec_col][0].decode()))
if min(nrow, vec_dim) < 2:
return df, False
df.iloc[:, vec_col] = df.iloc[:, vec_col].apply(
lambda x: json.loads(x.decode())
)
X = np.array(df.iloc[:, vec_col].tolist())
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
new_df = pd.DataFrame()
for i_col in range(ncol):
if i_col == vec_col:
continue
col_name = df.columns[i_col]
new_df[col_name] = df[col_name]
new_df["__x"] = [pos[0] for pos in X_pca]
new_df["__y"] = [pos[1] for pos in X_pca]
return new_df, True
def parse_view_response(self, speak, data, prompt_response) -> str:
param = {}
api_call_element = ET.Element("chart-view")
@@ -83,6 +131,11 @@ class DbChatOutputParser(BaseOutputParser):
if prompt_response.sql:
df = data(prompt_response.sql)
param["type"] = prompt_response.display
if param["type"] == "response_vector_chart":
df, visualizable = self.parse_vector_data_with_pca(df)
param["type"] = "response_scatter_chart" if visualizable else "response_table"
param["sql"] = prompt_response.sql
param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s")