Support attaching an Excel spreadsheet to a chat message (#3007)

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
AT
2024-10-01 21:17:49 -04:00
committed by GitHub
parent c11b67dfcb
commit db443f2090
28 changed files with 855 additions and 205 deletions

View File

@@ -5,6 +5,7 @@
#include "network.h"
#include "server.h"
#include <QBuffer>
#include <QDataStream>
#include <QDebug>
#include <QLatin1String>
@@ -122,6 +123,42 @@ void Chat::resetResponseState()
emit responseStateChanged();
}
// Starts a new prompt/response exchange for a prompt typed by the user, optionally
// accompanied by attached files.
//
// Each URL in attachedUrls is read from disk into a PromptAttachment. Attachments that
// cannot be opened are logged and skipped. The chat model receives the bare prompt plus
// the attachments, while the text handed on to prompt() is the prompt prefixed with the
// processed content of every attachment, separated by blank lines.
//
// The Q_ASSERTs state the caller's contract (local files with an "xlsx" suffix); they are
// only checked in builds where Q_ASSERT is enabled.
void Chat::newPromptResponsePair(const QString &prompt, const QList<QUrl> &attachedUrls)
{
    QList<PromptAttachment> attachments;
    QStringList attachedContexts;

    for (const QUrl &url : attachedUrls) {
        Q_ASSERT(url.isLocalFile());
        const QString path = url.toLocalFile();
        Q_ASSERT(QFileInfo(path).suffix() == "xlsx"); // We only support excel right now

        QFile file(path);
        if (!file.open(QIODevice::ReadOnly)) {
            qWarning() << "ERROR: Failed to open the attachment:" << path;
            continue;
        }

        PromptAttachment attachment;
        attachment.url = url;
        attachment.content = file.readAll();
        file.close();

        attachments << attachment;
        attachedContexts << attachment.processedContent();
    }

    // Attachment text goes first, then the user's prompt.
    QString promptPlusAttached;
    if (!attachedContexts.isEmpty())
        promptPlusAttached = attachedContexts.join("\n\n") + "\n\n";
    promptPlusAttached += prompt;

    newPromptResponsePairInternal(prompt, attachments);
    emit resetResponseRequested();

    this->prompt(promptPlusAttached);
}
void Chat::prompt(const QString &prompt)
{
resetResponseState();
@@ -232,23 +269,17 @@ void Chat::setModelInfo(const ModelInfo &modelInfo)
emit modelChangeRequested(modelInfo);
}
void Chat::newPromptResponsePair(const QString &prompt)
// the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread
void Chat::serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments)
{
resetResponseState();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
// the prompt is passed as the prompt item's value and the response item's prompt
m_chatModel->appendPrompt("Prompt: ", prompt);
m_chatModel->appendResponse("Response: ");
emit resetResponseRequested();
newPromptResponsePairInternal(prompt, attachments);
}
// the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread
void Chat::serverNewPromptResponsePair(const QString &prompt)
void Chat::newPromptResponsePairInternal(const QString &prompt, const QList<PromptAttachment> &attachments)
{
resetResponseState();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
// the prompt is passed as the prompt item's value and the response item's prompt
m_chatModel->appendPrompt("Prompt: ", prompt);
m_chatModel->appendPrompt("Prompt: ", prompt, attachments);
m_chatModel->appendResponse("Response: ");
}

View File

@@ -77,10 +77,10 @@ public:
bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; }
bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; }
float modelLoadingPercentage() const { return m_modelLoadingPercentage; }
Q_INVOKABLE void newPromptResponsePair(const QString &prompt, const QList<QUrl> &attachedUrls = {});
Q_INVOKABLE void prompt(const QString &prompt);
Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void stopGenerating();
Q_INVOKABLE void newPromptResponsePair(const QString &prompt);
QList<ResultInfo> databaseResults() const { return m_databaseResults; }
@@ -125,7 +125,7 @@ public:
QList<QString> generatedQuestions() const { return m_generatedQuestions; }
public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt);
void serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
Q_SIGNALS:
void idChanged(const QString &id);
@@ -174,6 +174,9 @@ private Q_SLOTS:
void handleModelInfoChanged(const ModelInfo &modelInfo);
void handleTrySwitchContextOfLoadedModelCompleted(int value);
private:
void newPromptResponsePairInternal(const QString &prompt, const QList<PromptAttachment> &attachments);
private:
QString m_id;
QString m_name;

View File

@@ -1333,7 +1333,7 @@ void ChatLLM::processRestoreStateFromText()
// FIXME(jared): this doesn't work well with the "regenerate" button since we are not incrementing
// m_promptTokens or m_promptResponseTokens
m_llModelInfo.model->prompt(
prompt.value.toStdString(), promptTemplate.toStdString(),
prompt.promptPlusAttachments().toStdString(), promptTemplate.toStdString(),
promptFunc, /*responseFunc*/ [](auto &&...) { return true; },
/*allowContextShift*/ true,
m_ctx,

View File

@@ -2,8 +2,10 @@
#define CHATMODEL_H
#include "database.h"
#include "xlsxtomd.h"
#include <QAbstractListModel>
#include <QBuffer>
#include <QByteArray>
#include <QDataStream>
#include <QHash>
@@ -16,6 +18,40 @@
#include <Qt>
#include <QtGlobal>
// A file attached to a prompt: the URL it came from plus its raw bytes.
// Exposed to the Qt meta-object system as a gadget so its properties are readable
// through the property system.
struct PromptAttachment {
    Q_GADGET
    Q_PROPERTY(QUrl url MEMBER url)
    Q_PROPERTY(QByteArray content MEMBER content)
    Q_PROPERTY(QString file READ file)
    Q_PROPERTY(QString processedContent READ processedContent)

public:
    QUrl url;           // where the attachment was loaded from
    QByteArray content; // raw bytes of the attached file

    // File name component of `url`, or a null string when the URL is not a local file.
    QString file() const
    {
        return url.isLocalFile() ? QFileInfo(url.toLocalFile()).fileName() : QString();
    }

    // Converts `content` to Markdown by feeding it to XLSXToMD through an in-memory
    // buffer. The conversion is redone on every call; nothing is cached.
    QString processedContent() const
    {
        QBuffer source;
        source.setData(content);
        source.open(QIODevice::ReadOnly);
        const QString markdown = XLSXToMD::toMarkdown(&source);
        source.close();
        return markdown;
    }

    // Two attachments are considered equal when they refer to the same URL;
    // the content bytes are not compared.
    bool operator==(const PromptAttachment &other) const { return url == other.url; }
};
Q_DECLARE_METATYPE(PromptAttachment)
struct ChatItem
{
Q_GADGET
@@ -29,8 +65,22 @@ struct ChatItem
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
Q_PROPERTY(QList<ResultInfo> sources MEMBER sources)
Q_PROPERTY(QList<ResultInfo> consolidatedSources MEMBER consolidatedSources)
Q_PROPERTY(QList<PromptAttachment> promptAttachments MEMBER promptAttachments);
Q_PROPERTY(QString promptPlusAttachments READ promptPlusAttachments);
public:
// Returns the text sent to the model for this item: the processed content of every
// prompt attachment (separated by blank lines) followed by the prompt value itself.
// With no attachments this is just `value`.
QString promptPlusAttachments() const
{
    QStringList attachedContexts;
    // Iterate by const reference: each PromptAttachment holds a QUrl and a QByteArray,
    // so iterating by value would copy both on every pass.
    for (const auto &attached : promptAttachments)
        attachedContexts << attached.processedContent();

    QString promptPlus = value;
    if (!attachedContexts.isEmpty())
        promptPlus = attachedContexts.join("\n\n") + "\n\n" + value;
    return promptPlus;
}
// TODO: Maybe we should include the model name here as well as timestamp?
int id = 0;
QString name;
@@ -38,6 +88,7 @@ public:
QString newResponse;
QList<ResultInfo> sources;
QList<ResultInfo> consolidatedSources;
QList<PromptAttachment> promptAttachments;
bool currentResponse = false;
bool stopped = false;
bool thumbsUpState = false;
@@ -65,7 +116,8 @@ public:
ThumbsUpStateRole,
ThumbsDownStateRole,
SourcesRole,
ConsolidatedSourcesRole
ConsolidatedSourcesRole,
PromptAttachmentsRole
};
int rowCount(const QModelIndex &parent = QModelIndex()) const override
@@ -103,6 +155,8 @@ public:
return QVariant::fromValue(item.sources);
case ConsolidatedSourcesRole:
return QVariant::fromValue(item.consolidatedSources);
case PromptAttachmentsRole:
return QVariant::fromValue(item.promptAttachments);
}
return QVariant();
@@ -121,14 +175,17 @@ public:
roles[ThumbsDownStateRole] = "thumbsDownState";
roles[SourcesRole] = "sources";
roles[ConsolidatedSourcesRole] = "consolidatedSources";
roles[PromptAttachmentsRole] = "promptAttachments";
return roles;
}
void appendPrompt(const QString &name, const QString &value)
void appendPrompt(const QString &name, const QString &value, const QList<PromptAttachment> &attachments)
{
ChatItem item;
item.name = name;
item.value = value;
item.promptAttachments << attachments;
m_mutex.lock();
const int count = m_chatItems.count();
m_mutex.unlock();
@@ -380,6 +437,14 @@ public:
stream << references.join("\n");
stream << referencesContext;
}
if (version >= 10) {
stream << c.promptAttachments.size();
for (const PromptAttachment &a : c.promptAttachments) {
Q_ASSERT(!a.url.isEmpty());
stream << a.url;
stream << a.content;
}
}
}
return stream.status() == QDataStream::Ok;
}
@@ -423,7 +488,7 @@ public:
}
c.sources = sources;
c.consolidatedSources = consolidateSources(sources);
}else if (version > 2) {
} else if (version > 2) {
QString references;
QList<QString> referencesContext;
stream >> references;
@@ -507,6 +572,18 @@ public:
c.consolidatedSources = consolidateSources(sources);
}
}
if (version >= 10) {
qsizetype count;
stream >> count;
QList<PromptAttachment> attachments;
for (int i = 0; i < count; ++i) {
PromptAttachment a;
stream >> a.url;
stream >> a.content;
attachments.append(a);
}
c.promptAttachments = attachments;
}
m_mutex.lock();
const int count = m_chatItems.size();
m_mutex.unlock();

View File

@@ -2,6 +2,7 @@
#define SERVER_H
#include "chatllm.h"
#include "chatmodel.h"
#include "database.h"
#include <QHttpServer>
@@ -32,7 +33,7 @@ public Q_SLOTS:
void start();
Q_SIGNALS:
void requestServerNewPromptResponsePair(const QString &prompt);
void requestServerNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
private:
auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;

View File

@@ -0,0 +1,167 @@
#include "xlsxtomd.h"
#include <xlsxabstractsheet.h>
#include <xlsxcell.h>
#include <xlsxcellrange.h>
#include <xlsxdocument.h>
#include <xlsxformat.h>
#include <xlsxworksheet.h>
#include <QDateTime>
#include <QDebug>
#include <QList>
#include <QString>
#include <QStringList>
#include <QVariant>
#include <QtGlobal>
#include <QtLogging>
#include <memory>
using namespace Qt::Literals::StringLiterals;
// Renders a single cell as text that is safe to place inside a Markdown table cell.
//
// Returns a null string for a null cell or a cell whose text is empty. Otherwise the
// text is wrapped in Markdown emphasis markers according to the cell's font (bold,
// italic, strike-out), pipe characters are escaped, and embedded line breaks are
// replaced with "<br>" so that a multi-line cell cannot split the table row.
static QString formatCellText(const QXlsx::Cell *cell)
{
    if (!cell) return QString();

    const QVariant value = cell->value();
    const QXlsx::Format format = cell->format();
    QString cellText;

    // Determine the cell type based on format
    if (format.isDateTimeFormat()) {
        // Handle DateTime; fall back to the plain string form when the value does not
        // convert to a valid QDateTime.
        // NOTE(review): this relies on cell->value() being convertible to QDateTime for
        // date-formatted cells — confirm against the QXlsx version in use.
        const QDateTime dateTime = value.toDateTime();
        cellText = dateTime.isValid() ? dateTime.toString("yyyy-MM-dd") : value.toString();
    } else {
        cellText = value.toString();
    }

    if (cellText.isEmpty())
        return QString();

    // Apply Markdown formatting based on font styles
    QString formattedText = cellText;

    if (format.fontBold() && format.fontItalic())
        formattedText = "***" + formattedText + "***";
    else if (format.fontBold())
        formattedText = "**" + formattedText + "**";
    else if (format.fontItalic())
        formattedText = "*" + formattedText + "*";

    if (format.fontStrikeOut())
        formattedText = "~~" + formattedText + "~~";

    // Escape pipe characters to prevent Markdown table issues
    formattedText.replace("|", "\\|");

    // A raw line break inside a cell would end the table row early; keep the break
    // visible but on a single line. Handle CRLF first so it becomes one <br>, not two.
    formattedText.replace("\r\n", "<br>");
    formattedText.replace("\n", "<br>");
    formattedText.replace("\r", "<br>");

    return formattedText;
}
// Returns the formatted text of the cell at (row, col) in `sheet`.
//
// When no cell exists at that position, the sheet's merged ranges are searched; if one
// contains the position, the range's top-left cell is used instead. Returns a null
// string when `sheet` is null or no cell can be found.
static QString getCellValue(QXlsx::Worksheet *sheet, int row, int col)
{
    if (!sheet)
        return QString();

    std::shared_ptr<QXlsx::Cell> cell = sheet->cellAt(row, col);

    if (!cell) {
        // Possibly a non-anchor position inside a merged range: look up the anchor cell.
        for (const QXlsx::CellRange &merged : sheet->mergedCells()) {
            const bool outsideRows = row < merged.firstRow() || row > merged.lastRow();
            const bool outsideCols = col < merged.firstColumn() || col > merged.lastColumn();
            if (outsideRows || outsideCols)
                continue;

            cell = sheet->cellAt(merged.firstRow(), merged.firstColumn());
            break;
        }
    }

    return cell ? formatCellText(cell.get()) : QString();
}
// Converts the Excel document readable from `xlsxDevice` into Markdown.
//
// Every worksheet becomes a "## <sheet name>" heading followed by a Markdown table
// covering the sheet's reported dimension; the first row of that range is treated as
// the table header. Sheets whose dimension is empty get a "*No data available.*" note
// instead of a table, and sheets that are not worksheets are skipped with a warning.
//
// Returns a null string if the document fails to load or contains no sheets.
QString XLSXToMD::toMarkdown(QIODevice *xlsxDevice)
{
    QXlsx::Document xlsx(xlsxDevice);
    if (!xlsx.load()) {
        qCritical() << "Failed to load the Excel from device";
        return QString();
    }

    const QStringList sheetNames = xlsx.sheetNames();
    if (sheetNames.isEmpty()) {
        qWarning() << "No sheets found in the Excel document.";
        return QString();
    }

    QString markdown;

    for (const QString &sheetName : sheetNames) {
        auto *sheet = dynamic_cast<QXlsx::Worksheet *>(xlsx.sheet(sheetName));
        if (!sheet) {
            qWarning() << "Failed to load sheet:" << sheetName;
            continue;
        }

        markdown += u"## %1\n\n"_s.arg(sheetName);

        // Bounds of the used area of the sheet.
        const QXlsx::CellRange used = sheet->dimension();
        const int firstRow = used.firstRow();
        const int lastRow  = used.lastRow();
        const int firstCol = used.firstColumn();
        const int lastCol  = used.lastColumn();

        if (lastRow < firstRow || lastCol < firstCol) {
            qWarning() << "Sheet" << sheetName << "is empty.";
            markdown += "*No data available.*\n\n";
            continue;
        }

        // Appends one sheet row to the output as a Markdown table line.
        const auto appendTableRow = [&](int row) {
            QStringList cells;
            for (int col = firstCol; col <= lastCol; ++col)
                cells << getCellValue(sheet, row, col);
            markdown += "|" + cells.join("|") + "|";
            markdown += "\n";
        };

        // The first used row is assumed to be the header.
        appendTableRow(firstRow);

        // Separator line: one "---" per column.
        QStringList separators;
        for (int col = firstCol; col <= lastCol; ++col)
            separators << "---";
        markdown += "|" + separators.join("|") + "|";
        markdown += "\n";

        // Remaining rows are data.
        for (int row = firstRow + 1; row <= lastRow; ++row)
            appendTableRow(row);

        markdown += "\n"; // Add an empty line between sheets
    }

    return markdown;
}

View File

@@ -0,0 +1,13 @@
#ifndef XLSXTOMD_H
#define XLSXTOMD_H
class QIODevice;
class QString;
// Converter from Excel (.xlsx) documents to Markdown text.
class XLSXToMD
{
public:
// Reads an Excel document from xlsxDevice and returns it as Markdown: one "## <name>"
// heading per worksheet followed by a table whose first row is used as the header.
// Returns a null QString when the document cannot be loaded or has no sheets.
static QString toMarkdown(QIODevice *xlsxDevice);
};
#endif // XLSXTOMD_H