Add save/restore to ChatGPT chats and allow serialize/deserialize from disk.

Adam Treat 2023-05-15 18:36:41 -04:00 committed by AT
parent 0cd509d530
commit f931de21c5
7 changed files with 120 additions and 12 deletions

gpt4all-chat/chat.cpp

@@ -258,6 +258,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
     // unfortunately, we cannot deserialize these
     if (version < 2 && m_savedModelName.contains("gpt4all-j"))
         return false;
+    m_llmodel->setModelName(m_savedModelName);
     if (!m_llmodel->deserialize(stream, version))
         return false;
     if (!m_chatModel->deserialize(stream, version))

gpt4all-chat/chatgpt.cpp

@@ -46,6 +46,7 @@ bool ChatGPT::isModelLoaded() const
     return true;
 }

+// All three of the state virtual functions are handled custom inside of chatllm save/restore
 size_t ChatGPT::stateSize() const
 {
     return 0;
@@ -53,11 +54,13 @@ size_t ChatGPT::stateSize() const

 size_t ChatGPT::saveState(uint8_t *dest) const
 {
+    Q_UNUSED(dest);
     return 0;
 }

 size_t ChatGPT::restoreState(const uint8_t *src)
 {
+    Q_UNUSED(src);
     return 0;
 }
@@ -141,8 +144,8 @@ void ChatGPT::handleFinished()
     bool ok;
     int code = response.toInt(&ok);
     if (!ok || code != 200) {
-        qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n")
-            .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString();
+        qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
+            .arg(code).arg(reply->errorString()).toStdString();
     }
     reply->deleteLater();
 }
@@ -190,8 +193,11 @@ void ChatGPT::handleReadyRead()
         const QString content = delta.value("content").toString();
         Q_ASSERT(m_ctx);
         Q_ASSERT(m_responseCallback);
-        m_responseCallback(0, content.toStdString());
         m_currentResponse += content;
+        if (!m_responseCallback(0, content.toStdString())) {
+            reply->abort();
+            return;
+        }
     }
 }
@@ -201,6 +207,6 @@ void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code)
     if (!reply)
         return;

-    qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n")
-        .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString();
+    qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
+        .arg(code).arg(reply->errorString()).toStdString();
 }
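
The handleReadyRead change above turns the response callback's return value into a cancellation signal: a false return aborts the in-flight reply so no further readyRead events arrive for it. A minimal, self-contained sketch of that pattern with hypothetical names (forwardDelta is not gpt4all API):

// Sketch of the cancellation pattern: each streamed delta goes to a
// bool-returning callback, and a false return aborts the HTTP stream.
#include <QNetworkReply>
#include <QString>
#include <cstdint>
#include <functional>
#include <string>

using ResponseCallback = std::function<bool(int32_t token, const std::string &piece)>;

void forwardDelta(QNetworkReply *reply, const QString &content,
                  QString &currentResponse, const ResponseCallback &callback)
{
    currentResponse += content;           // accumulate the full response text
    if (!callback(0, content.toStdString()))
        reply->abort();                   // consumer declined: cancel the request
}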

gpt4all-chat/chatgpt.h

@@ -30,6 +30,9 @@ public:
     void setModelName(const QString &modelName) { m_modelName = modelName; }
     void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }

+    QList<QString> context() const { return m_context; }
+    void setContext(const QList<QString> &context) { m_context = context; }
+
 protected:
     void recalculateContext(PromptContext &promptCtx,
         std::function<bool(bool)> recalculate) override {}

gpt4all-chat/chatlistmodel.cpp

@@ -38,6 +38,19 @@ void ChatListModel::setShouldSaveChats(bool b)
     emit shouldSaveChatsChanged();
 }

+bool ChatListModel::shouldSaveChatGPTChats() const
+{
+    return m_shouldSaveChatGPTChats;
+}
+
+void ChatListModel::setShouldSaveChatGPTChats(bool b)
+{
+    if (m_shouldSaveChatGPTChats == b)
+        return;
+    m_shouldSaveChatGPTChats = b;
+    emit shouldSaveChatGPTChatsChanged();
+}
+
 void ChatListModel::removeChatFile(Chat *chat) const
 {
     Q_ASSERT(chat != m_serverChat);
@@ -52,15 +65,17 @@ void ChatListModel::removeChatFile(Chat *chat) const

 void ChatListModel::saveChats() const
 {
-    if (!m_shouldSaveChats)
-        return;
-
     QElapsedTimer timer;
     timer.start();
     const QString savePath = Download::globalInstance()->downloadLocalModelsPath();
     for (Chat *chat : m_chats) {
         if (chat == m_serverChat)
             continue;
+        const bool isChatGPT = chat->modelName().startsWith("chatgpt-");
+        if (!isChatGPT && !m_shouldSaveChats)
+            continue;
+        if (isChatGPT && !m_shouldSaveChatGPTChats)
+            continue;
         QString fileName = "gpt4all-" + chat->id() + ".chat";
         QFile file(savePath + "/" + fileName);
         bool success = file.open(QIODevice::WriteOnly);
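
saveChats now decides persistence per chat instead of returning early on a single global flag. A compact sketch of that predicate, assuming the "chatgpt-" model-name convention the loop above relies on (the Chat struct here is a stand-in, since only the model name matters):

#include <QString>

struct Chat { QString modelName; };  // stand-in for the app's Chat class

// True when this chat should be written to disk under the current settings.
bool shouldPersist(const Chat &chat, bool shouldSaveChats, bool shouldSaveChatGPTChats)
{
    const bool isChatGPT = chat.modelName.startsWith(QStringLiteral("chatgpt-"));
    return isChatGPT ? shouldSaveChatGPTChats : shouldSaveChats;
}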

gpt4all-chat/chatlistmodel.h

@@ -20,6 +20,7 @@ class ChatListModel : public QAbstractListModel
     Q_PROPERTY(int count READ count NOTIFY countChanged)
     Q_PROPERTY(Chat *currentChat READ currentChat WRITE setCurrentChat NOTIFY currentChatChanged)
     Q_PROPERTY(bool shouldSaveChats READ shouldSaveChats WRITE setShouldSaveChats NOTIFY shouldSaveChatsChanged)
+    Q_PROPERTY(bool shouldSaveChatGPTChats READ shouldSaveChatGPTChats WRITE setShouldSaveChatGPTChats NOTIFY shouldSaveChatGPTChatsChanged)

 public:
     explicit ChatListModel(QObject *parent = nullptr);
@@ -62,6 +63,9 @@ public:
     bool shouldSaveChats() const;
     void setShouldSaveChats(bool b);

+    bool shouldSaveChatGPTChats() const;
+    void setShouldSaveChatGPTChats(bool b);
+
     Q_INVOKABLE void addChat()
     {
         // Don't add a new chat if we already have one
@@ -199,6 +203,7 @@ Q_SIGNALS:
     void countChanged();
     void currentChatChanged();
     void shouldSaveChatsChanged();
+    void shouldSaveChatGPTChatsChanged();

 private Q_SLOTS:
     void newChatCountChanged()
@@ -240,6 +245,7 @@ private Q_SLOTS:

 private:
     bool m_shouldSaveChats;
+    bool m_shouldSaveChatGPTChats;
     Chat* m_newChat;
     Chat* m_dummyChat;
     Chat* m_serverChat;

gpt4all-chat/chatllm.cpp

@@ -611,6 +611,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
         stream >> compressed;
         m_state = qUncompress(compressed);
     } else {
         stream >> m_state;
     }
 #if defined(DEBUG)
@@ -624,6 +625,15 @@ void ChatLLM::saveState()
     if (!isModelLoaded())
         return;

+    if (m_isChatGPT) {
+        m_state.clear();
+        QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
+        stream.setVersion(QDataStream::Qt_6_5);
+        ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model);
+        stream << chatGPT->context();
+        return;
+    }
+
     const size_t stateSize = m_modelInfo.model->stateSize();
     m_state.resize(stateSize);
 #if defined(DEBUG)
@@ -637,6 +647,18 @@ void ChatLLM::restoreState()
     if (!isModelLoaded() || m_state.isEmpty())
         return;

+    if (m_isChatGPT) {
+        QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
+        stream.setVersion(QDataStream::Qt_6_5);
+        ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model);
+        QList<QString> context;
+        stream >> context;
+        chatGPT->setContext(context);
+        m_state.clear();
+        m_state.resize(0);
+        return;
+    }
+
 #if defined(DEBUG)
     qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
 #endif
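
Because ChatGPT::stateSize() returns 0, the generic byte-buffer path would persist nothing for API models; saveState/restoreState instead serialize the conversation context directly into m_state. A self-contained sketch of that QDataStream round-trip (free functions for illustration, not the actual member functions):

#include <QByteArray>
#include <QDataStream>
#include <QList>
#include <QString>

QByteArray saveContext(const QList<QString> &context)
{
    QByteArray state;
    QDataStream out(&state, QIODeviceBase::WriteOnly);
    out.setVersion(QDataStream::Qt_6_5);  // pin the wire format
    out << context;                       // QList<QString> has built-in stream operators
    return state;
}

QList<QString> restoreContext(const QByteArray &state)
{
    QDataStream in(state);                // read-only stream over the byte array
    in.setVersion(QDataStream::Qt_6_5);   // must match the version used to save
    QList<QString> context;
    in >> context;
    return context;
}

Pinning the stream version matters: QDataStream's default encoding can change between Qt releases, so state written by one build could otherwise fail to decode in another.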

gpt4all-chat/qml/SettingsDialog.qml

@@ -40,6 +40,7 @@ Dialog {
     property int defaultRepeatPenaltyTokens: 64
     property int defaultThreadCount: 0
     property bool defaultSaveChats: false
+    property bool defaultSaveChatGPTChats: true
     property bool defaultServerChat: false
     property string defaultPromptTemplate: "### Human:
 %1
@@ -57,6 +58,7 @@ Dialog {
     property alias repeatPenaltyTokens: settings.repeatPenaltyTokens
     property alias threadCount: settings.threadCount
     property alias saveChats: settings.saveChats
+    property alias saveChatGPTChats: settings.saveChatGPTChats
     property alias serverChat: settings.serverChat
     property alias modelPath: settings.modelPath
     property alias userDefaultModel: settings.userDefaultModel
@@ -70,6 +72,7 @@ Dialog {
     property int promptBatchSize: settingsDialog.defaultPromptBatchSize
     property int threadCount: settingsDialog.defaultThreadCount
     property bool saveChats: settingsDialog.defaultSaveChats
+    property bool saveChatGPTChats: settingsDialog.defaultSaveChatGPTChats
     property bool serverChat: settingsDialog.defaultServerChat
     property real repeatPenalty: settingsDialog.defaultRepeatPenalty
     property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens
@@ -94,12 +97,14 @@ Dialog {
         settings.modelPath = settingsDialog.defaultModelPath
         settings.threadCount = defaultThreadCount
         settings.saveChats = defaultSaveChats
+        settings.saveChatGPTChats = defaultSaveChatGPTChats
         settings.serverChat = defaultServerChat
         settings.userDefaultModel = defaultUserDefaultModel
         Download.downloadLocalModelsPath = settings.modelPath
         LLM.threadCount = settings.threadCount
         LLM.serverEnabled = settings.serverChat
         LLM.chatListModel.shouldSaveChats = settings.saveChats
+        LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats
         settings.sync()
     }
@@ -107,6 +112,7 @@ Dialog {
         LLM.threadCount = settings.threadCount
         LLM.serverEnabled = settings.serverChat
         LLM.chatListModel.shouldSaveChats = settings.saveChats
+        LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats
         Download.downloadLocalModelsPath = settings.modelPath
     }
@@ -803,16 +809,65 @@ Dialog {
             }
         }
         Label {
-            id: serverChatLabel
-            text: qsTr("Enable web server:")
+            id: saveChatGPTChatsLabel
+            text: qsTr("Save ChatGPT chats to disk:")
             color: theme.textColor
             Layout.row: 5
             Layout.column: 0
         }
         CheckBox {
-            id: serverChatBox
+            id: saveChatGPTChatsBox
             Layout.row: 5
             Layout.column: 1
+            checked: settingsDialog.saveChatGPTChats
+            onClicked: {
+                settingsDialog.saveChatGPTChats = saveChatGPTChatsBox.checked
+                LLM.chatListModel.shouldSaveChatGPTChats = saveChatGPTChatsBox.checked
+                settings.sync()
+            }
+            background: Rectangle {
+                color: "transparent"
+            }
+            indicator: Rectangle {
+                implicitWidth: 26
+                implicitHeight: 26
+                x: saveChatGPTChatsBox.leftPadding
+                y: parent.height / 2 - height / 2
+                border.color: theme.dialogBorder
+                color: "transparent"
+                Rectangle {
+                    width: 14
+                    height: 14
+                    x: 6
+                    y: 6
+                    color: theme.textColor
+                    visible: saveChatGPTChatsBox.checked
+                }
+            }
+            contentItem: Text {
+                text: saveChatGPTChatsBox.text
+                font: saveChatGPTChatsBox.font
+                opacity: enabled ? 1.0 : 0.3
+                color: theme.textColor
+                verticalAlignment: Text.AlignVCenter
+                leftPadding: saveChatGPTChatsBox.indicator.width + saveChatGPTChatsBox.spacing
+            }
+        }
+        Label {
+            id: serverChatLabel
+            text: qsTr("Enable web server:")
+            color: theme.textColor
+            Layout.row: 6
+            Layout.column: 0
+        }
+        CheckBox {
+            id: serverChatBox
+            Layout.row: 6
+            Layout.column: 1
             checked: settings.serverChat
             onClicked: {
                 settingsDialog.serverChat = serverChatBox.checked
@@ -855,7 +910,7 @@ Dialog {
             }
         }
         Button {
-            Layout.row: 6
+            Layout.row: 7
             Layout.column: 1
             Layout.fillWidth: true
             padding: 10
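
Each persisted flag in SettingsDialog.qml follows the same wiring: a default property, a Settings entry seeded from it, an alias on the dialog, and write-through to LLM.chatListModel plus settings.sync() on change. A trimmed sketch of that pattern for one flag, assuming the Qt.labs.settings module the dialog's Settings block already comes from:

import QtQuick
import QtQuick.Controls
import Qt.labs.settings 1.0

Dialog {
    id: settingsDialog
    property bool defaultSaveChatGPTChats: true
    // Alias so the rest of the UI reads and writes the persisted value directly.
    property alias saveChatGPTChats: settings.saveChatGPTChats

    Settings {
        id: settings
        property bool saveChatGPTChats: settingsDialog.defaultSaveChatGPTChats
    }

    function restoreDefaults() {
        settings.saveChatGPTChats = settingsDialog.defaultSaveChatGPTChats
        settings.sync()  // flush to the platform settings store immediately
    }
}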