mirror of
				https://github.com/nomic-ai/gpt4all.git
				synced 2025-10-30 21:30:42 +00:00 
			
		
		
		
	Add save/restore to chatgpt chats and allow serialize/deseralize from disk.
This commit is contained in:
		| @@ -258,6 +258,7 @@ bool Chat::deserialize(QDataStream &stream, int version) | |||||||
|     // unfortunately, we cannot deserialize these |     // unfortunately, we cannot deserialize these | ||||||
|     if (version < 2 && m_savedModelName.contains("gpt4all-j")) |     if (version < 2 && m_savedModelName.contains("gpt4all-j")) | ||||||
|         return false; |         return false; | ||||||
|  |     m_llmodel->setModelName(m_savedModelName); | ||||||
|     if (!m_llmodel->deserialize(stream, version)) |     if (!m_llmodel->deserialize(stream, version)) | ||||||
|         return false; |         return false; | ||||||
|     if (!m_chatModel->deserialize(stream, version)) |     if (!m_chatModel->deserialize(stream, version)) | ||||||
|   | |||||||
| @@ -46,6 +46,7 @@ bool ChatGPT::isModelLoaded() const | |||||||
|     return true; |     return true; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // All three of the state virtual functions are handled custom inside of chatllm save/restore | ||||||
| size_t ChatGPT::stateSize() const | size_t ChatGPT::stateSize() const | ||||||
| { | { | ||||||
|     return 0; |     return 0; | ||||||
| @@ -53,11 +54,13 @@ size_t ChatGPT::stateSize() const | |||||||
|  |  | ||||||
| size_t ChatGPT::saveState(uint8_t *dest) const | size_t ChatGPT::saveState(uint8_t *dest) const | ||||||
| { | { | ||||||
|  |     Q_UNUSED(dest); | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|  |  | ||||||
| size_t ChatGPT::restoreState(const uint8_t *src) | size_t ChatGPT::restoreState(const uint8_t *src) | ||||||
| { | { | ||||||
|  |     Q_UNUSED(src); | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -141,8 +144,8 @@ void ChatGPT::handleFinished() | |||||||
|     bool ok; |     bool ok; | ||||||
|     int code = response.toInt(&ok); |     int code = response.toInt(&ok); | ||||||
|     if (!ok || code != 200) { |     if (!ok || code != 200) { | ||||||
|         qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n") |         qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"") | ||||||
|             .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString(); |             .arg(code).arg(reply->errorString()).toStdString(); | ||||||
|     } |     } | ||||||
|     reply->deleteLater(); |     reply->deleteLater(); | ||||||
| } | } | ||||||
| @@ -190,8 +193,11 @@ void ChatGPT::handleReadyRead() | |||||||
|         const QString content = delta.value("content").toString(); |         const QString content = delta.value("content").toString(); | ||||||
|         Q_ASSERT(m_ctx); |         Q_ASSERT(m_ctx); | ||||||
|         Q_ASSERT(m_responseCallback); |         Q_ASSERT(m_responseCallback); | ||||||
|         m_responseCallback(0, content.toStdString()); |  | ||||||
|         m_currentResponse += content; |         m_currentResponse += content; | ||||||
|  |         if (!m_responseCallback(0, content.toStdString())) { | ||||||
|  |             reply->abort(); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -201,6 +207,6 @@ void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code) | |||||||
|     if (!reply) |     if (!reply) | ||||||
|         return; |         return; | ||||||
|  |  | ||||||
|     qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n") |     qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"") | ||||||
|                       .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString(); |                       .arg(code).arg(reply->errorString()).toStdString(); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -30,6 +30,9 @@ public: | |||||||
|     void setModelName(const QString &modelName) { m_modelName = modelName; } |     void setModelName(const QString &modelName) { m_modelName = modelName; } | ||||||
|     void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; } |     void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; } | ||||||
|  |  | ||||||
|  |     QList<QString> context() const { return m_context; } | ||||||
|  |     void setContext(const QList<QString> &context) { m_context = context; } | ||||||
|  |  | ||||||
| protected: | protected: | ||||||
|     void recalculateContext(PromptContext &promptCtx, |     void recalculateContext(PromptContext &promptCtx, | ||||||
|         std::function<bool(bool)> recalculate) override {} |         std::function<bool(bool)> recalculate) override {} | ||||||
|   | |||||||
| @@ -38,6 +38,19 @@ void ChatListModel::setShouldSaveChats(bool b) | |||||||
|     emit shouldSaveChatsChanged(); |     emit shouldSaveChatsChanged(); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | bool ChatListModel::shouldSaveChatGPTChats() const | ||||||
|  | { | ||||||
|  |     return m_shouldSaveChatGPTChats; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void ChatListModel::setShouldSaveChatGPTChats(bool b) | ||||||
|  | { | ||||||
|  |     if (m_shouldSaveChatGPTChats == b) | ||||||
|  |         return; | ||||||
|  |     m_shouldSaveChatGPTChats = b; | ||||||
|  |     emit shouldSaveChatGPTChatsChanged(); | ||||||
|  | } | ||||||
|  |  | ||||||
| void ChatListModel::removeChatFile(Chat *chat) const | void ChatListModel::removeChatFile(Chat *chat) const | ||||||
| { | { | ||||||
|     Q_ASSERT(chat != m_serverChat); |     Q_ASSERT(chat != m_serverChat); | ||||||
| @@ -52,15 +65,17 @@ void ChatListModel::removeChatFile(Chat *chat) const | |||||||
|  |  | ||||||
| void ChatListModel::saveChats() const | void ChatListModel::saveChats() const | ||||||
| { | { | ||||||
|     if (!m_shouldSaveChats) |  | ||||||
|         return; |  | ||||||
|  |  | ||||||
|     QElapsedTimer timer; |     QElapsedTimer timer; | ||||||
|     timer.start(); |     timer.start(); | ||||||
|     const QString savePath = Download::globalInstance()->downloadLocalModelsPath(); |     const QString savePath = Download::globalInstance()->downloadLocalModelsPath(); | ||||||
|     for (Chat *chat : m_chats) { |     for (Chat *chat : m_chats) { | ||||||
|         if (chat == m_serverChat) |         if (chat == m_serverChat) | ||||||
|             continue; |             continue; | ||||||
|  |         const bool isChatGPT = chat->modelName().startsWith("chatgpt-"); | ||||||
|  |         if (!isChatGPT && !m_shouldSaveChats) | ||||||
|  |             continue; | ||||||
|  |         if (isChatGPT && !m_shouldSaveChatGPTChats) | ||||||
|  |             continue; | ||||||
|         QString fileName = "gpt4all-" + chat->id() + ".chat"; |         QString fileName = "gpt4all-" + chat->id() + ".chat"; | ||||||
|         QFile file(savePath + "/" + fileName); |         QFile file(savePath + "/" + fileName); | ||||||
|         bool success = file.open(QIODevice::WriteOnly); |         bool success = file.open(QIODevice::WriteOnly); | ||||||
|   | |||||||
| @@ -20,6 +20,7 @@ class ChatListModel : public QAbstractListModel | |||||||
|     Q_PROPERTY(int count READ count NOTIFY countChanged) |     Q_PROPERTY(int count READ count NOTIFY countChanged) | ||||||
|     Q_PROPERTY(Chat *currentChat READ currentChat WRITE setCurrentChat NOTIFY currentChatChanged) |     Q_PROPERTY(Chat *currentChat READ currentChat WRITE setCurrentChat NOTIFY currentChatChanged) | ||||||
|     Q_PROPERTY(bool shouldSaveChats READ shouldSaveChats WRITE setShouldSaveChats NOTIFY shouldSaveChatsChanged) |     Q_PROPERTY(bool shouldSaveChats READ shouldSaveChats WRITE setShouldSaveChats NOTIFY shouldSaveChatsChanged) | ||||||
|  |     Q_PROPERTY(bool shouldSaveChatGPTChats READ shouldSaveChatGPTChats WRITE setShouldSaveChatGPTChats NOTIFY shouldSaveChatGPTChatsChanged) | ||||||
|  |  | ||||||
| public: | public: | ||||||
|     explicit ChatListModel(QObject *parent = nullptr); |     explicit ChatListModel(QObject *parent = nullptr); | ||||||
| @@ -62,6 +63,9 @@ public: | |||||||
|     bool shouldSaveChats() const; |     bool shouldSaveChats() const; | ||||||
|     void setShouldSaveChats(bool b); |     void setShouldSaveChats(bool b); | ||||||
|  |  | ||||||
|  |     bool shouldSaveChatGPTChats() const; | ||||||
|  |     void setShouldSaveChatGPTChats(bool b); | ||||||
|  |  | ||||||
|     Q_INVOKABLE void addChat() |     Q_INVOKABLE void addChat() | ||||||
|     { |     { | ||||||
|         // Don't add a new chat if we already have one |         // Don't add a new chat if we already have one | ||||||
| @@ -199,6 +203,7 @@ Q_SIGNALS: | |||||||
|     void countChanged(); |     void countChanged(); | ||||||
|     void currentChatChanged(); |     void currentChatChanged(); | ||||||
|     void shouldSaveChatsChanged(); |     void shouldSaveChatsChanged(); | ||||||
|  |     void shouldSaveChatGPTChatsChanged(); | ||||||
|  |  | ||||||
| private Q_SLOTS: | private Q_SLOTS: | ||||||
|     void newChatCountChanged() |     void newChatCountChanged() | ||||||
| @@ -240,6 +245,7 @@ private Q_SLOTS: | |||||||
|  |  | ||||||
| private: | private: | ||||||
|     bool m_shouldSaveChats; |     bool m_shouldSaveChats; | ||||||
|  |     bool m_shouldSaveChatGPTChats; | ||||||
|     Chat* m_newChat; |     Chat* m_newChat; | ||||||
|     Chat* m_dummyChat; |     Chat* m_dummyChat; | ||||||
|     Chat* m_serverChat; |     Chat* m_serverChat; | ||||||
|   | |||||||
| @@ -611,6 +611,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version) | |||||||
|         stream >> compressed; |         stream >> compressed; | ||||||
|         m_state = qUncompress(compressed); |         m_state = qUncompress(compressed); | ||||||
|     } else { |     } else { | ||||||
|  |  | ||||||
|         stream >> m_state; |         stream >> m_state; | ||||||
|     } |     } | ||||||
| #if defined(DEBUG) | #if defined(DEBUG) | ||||||
| @@ -624,6 +625,15 @@ void ChatLLM::saveState() | |||||||
|     if (!isModelLoaded()) |     if (!isModelLoaded()) | ||||||
|         return; |         return; | ||||||
|  |  | ||||||
|  |     if (m_isChatGPT) { | ||||||
|  |         m_state.clear(); | ||||||
|  |         QDataStream stream(&m_state, QIODeviceBase::WriteOnly); | ||||||
|  |         stream.setVersion(QDataStream::Qt_6_5); | ||||||
|  |         ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model); | ||||||
|  |         stream << chatGPT->context(); | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     const size_t stateSize = m_modelInfo.model->stateSize(); |     const size_t stateSize = m_modelInfo.model->stateSize(); | ||||||
|     m_state.resize(stateSize); |     m_state.resize(stateSize); | ||||||
| #if defined(DEBUG) | #if defined(DEBUG) | ||||||
| @@ -637,6 +647,18 @@ void ChatLLM::restoreState() | |||||||
|     if (!isModelLoaded() || m_state.isEmpty()) |     if (!isModelLoaded() || m_state.isEmpty()) | ||||||
|         return; |         return; | ||||||
|  |  | ||||||
|  |     if (m_isChatGPT) { | ||||||
|  |         QDataStream stream(&m_state, QIODeviceBase::ReadOnly); | ||||||
|  |         stream.setVersion(QDataStream::Qt_6_5); | ||||||
|  |         ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model); | ||||||
|  |         QList<QString> context; | ||||||
|  |         stream >> context; | ||||||
|  |         chatGPT->setContext(context); | ||||||
|  |         m_state.clear(); | ||||||
|  |         m_state.resize(0); | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
| #if defined(DEBUG) | #if defined(DEBUG) | ||||||
|     qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size(); |     qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size(); | ||||||
| #endif | #endif | ||||||
|   | |||||||
| @@ -40,6 +40,7 @@ Dialog { | |||||||
|     property int defaultRepeatPenaltyTokens: 64 |     property int defaultRepeatPenaltyTokens: 64 | ||||||
|     property int defaultThreadCount: 0 |     property int defaultThreadCount: 0 | ||||||
|     property bool defaultSaveChats: false |     property bool defaultSaveChats: false | ||||||
|  |     property bool defaultSaveChatGPTChats: true | ||||||
|     property bool defaultServerChat: false |     property bool defaultServerChat: false | ||||||
|     property string defaultPromptTemplate: "### Human: |     property string defaultPromptTemplate: "### Human: | ||||||
| %1 | %1 | ||||||
| @@ -57,6 +58,7 @@ Dialog { | |||||||
|     property alias repeatPenaltyTokens: settings.repeatPenaltyTokens |     property alias repeatPenaltyTokens: settings.repeatPenaltyTokens | ||||||
|     property alias threadCount: settings.threadCount |     property alias threadCount: settings.threadCount | ||||||
|     property alias saveChats: settings.saveChats |     property alias saveChats: settings.saveChats | ||||||
|  |     property alias saveChatGPTChats: settings.saveChatGPTChats | ||||||
|     property alias serverChat: settings.serverChat |     property alias serverChat: settings.serverChat | ||||||
|     property alias modelPath: settings.modelPath |     property alias modelPath: settings.modelPath | ||||||
|     property alias userDefaultModel: settings.userDefaultModel |     property alias userDefaultModel: settings.userDefaultModel | ||||||
| @@ -70,6 +72,7 @@ Dialog { | |||||||
|         property int promptBatchSize: settingsDialog.defaultPromptBatchSize |         property int promptBatchSize: settingsDialog.defaultPromptBatchSize | ||||||
|         property int threadCount: settingsDialog.defaultThreadCount |         property int threadCount: settingsDialog.defaultThreadCount | ||||||
|         property bool saveChats: settingsDialog.defaultSaveChats |         property bool saveChats: settingsDialog.defaultSaveChats | ||||||
|  |         property bool saveChatGPTChats: settingsDialog.defaultSaveChatGPTChats | ||||||
|         property bool serverChat: settingsDialog.defaultServerChat |         property bool serverChat: settingsDialog.defaultServerChat | ||||||
|         property real repeatPenalty: settingsDialog.defaultRepeatPenalty |         property real repeatPenalty: settingsDialog.defaultRepeatPenalty | ||||||
|         property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens |         property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens | ||||||
| @@ -94,12 +97,14 @@ Dialog { | |||||||
|         settings.modelPath = settingsDialog.defaultModelPath |         settings.modelPath = settingsDialog.defaultModelPath | ||||||
|         settings.threadCount = defaultThreadCount |         settings.threadCount = defaultThreadCount | ||||||
|         settings.saveChats = defaultSaveChats |         settings.saveChats = defaultSaveChats | ||||||
|  |         settings.saveChatGPTChats = defaultSaveChatGPTChats | ||||||
|         settings.serverChat = defaultServerChat |         settings.serverChat = defaultServerChat | ||||||
|         settings.userDefaultModel = defaultUserDefaultModel |         settings.userDefaultModel = defaultUserDefaultModel | ||||||
|         Download.downloadLocalModelsPath = settings.modelPath |         Download.downloadLocalModelsPath = settings.modelPath | ||||||
|         LLM.threadCount = settings.threadCount |         LLM.threadCount = settings.threadCount | ||||||
|         LLM.serverEnabled = settings.serverChat |         LLM.serverEnabled = settings.serverChat | ||||||
|         LLM.chatListModel.shouldSaveChats = settings.saveChats |         LLM.chatListModel.shouldSaveChats = settings.saveChats | ||||||
|  |         LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats | ||||||
|         settings.sync() |         settings.sync() | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -107,6 +112,7 @@ Dialog { | |||||||
|         LLM.threadCount = settings.threadCount |         LLM.threadCount = settings.threadCount | ||||||
|         LLM.serverEnabled = settings.serverChat |         LLM.serverEnabled = settings.serverChat | ||||||
|         LLM.chatListModel.shouldSaveChats = settings.saveChats |         LLM.chatListModel.shouldSaveChats = settings.saveChats | ||||||
|  |         LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats | ||||||
|         Download.downloadLocalModelsPath = settings.modelPath |         Download.downloadLocalModelsPath = settings.modelPath | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -803,16 +809,65 @@ Dialog { | |||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                     Label { |                     Label { | ||||||
|                         id: serverChatLabel |                         id: saveChatGPTChatsLabel | ||||||
|                         text: qsTr("Enable web server:") |                         text: qsTr("Save ChatGPT chats to disk:") | ||||||
|                         color: theme.textColor |                         color: theme.textColor | ||||||
|                         Layout.row: 5 |                         Layout.row: 5 | ||||||
|                         Layout.column: 0 |                         Layout.column: 0 | ||||||
|                     } |                     } | ||||||
|                     CheckBox { |                     CheckBox { | ||||||
|                         id: serverChatBox |                         id: saveChatGPTChatsBox | ||||||
|                         Layout.row: 5 |                         Layout.row: 5 | ||||||
|                         Layout.column: 1 |                         Layout.column: 1 | ||||||
|  |                         checked: settingsDialog.saveChatGPTChats | ||||||
|  |                         onClicked: { | ||||||
|  |                             settingsDialog.saveChatGPTChats = saveChatGPTChatsBox.checked | ||||||
|  |                             LLM.chatListModel.shouldSaveChatGPTChats = saveChatGPTChatsBox.checked | ||||||
|  |                             settings.sync() | ||||||
|  |                         } | ||||||
|  |  | ||||||
|  |                         background: Rectangle { | ||||||
|  |                             color: "transparent" | ||||||
|  |                         } | ||||||
|  |  | ||||||
|  |                         indicator: Rectangle { | ||||||
|  |                             implicitWidth: 26 | ||||||
|  |                             implicitHeight: 26 | ||||||
|  |                             x: saveChatGPTChatsBox.leftPadding | ||||||
|  |                             y: parent.height / 2 - height / 2 | ||||||
|  |                             border.color: theme.dialogBorder | ||||||
|  |                             color: "transparent" | ||||||
|  |  | ||||||
|  |                             Rectangle { | ||||||
|  |                                 width: 14 | ||||||
|  |                                 height: 14 | ||||||
|  |                                 x: 6 | ||||||
|  |                                 y: 6 | ||||||
|  |                                 color: theme.textColor | ||||||
|  |                                 visible: saveChatGPTChatsBox.checked | ||||||
|  |                             } | ||||||
|  |                         } | ||||||
|  |  | ||||||
|  |                         contentItem: Text { | ||||||
|  |                             text: saveChatGPTChatsBox.text | ||||||
|  |                             font: saveChatGPTChatsBox.font | ||||||
|  |                             opacity: enabled ? 1.0 : 0.3 | ||||||
|  |                             color: theme.textColor | ||||||
|  |                             verticalAlignment: Text.AlignVCenter | ||||||
|  |                             leftPadding: saveChatGPTChatsBox.indicator.width + saveChatGPTChatsBox.spacing | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                     Label { | ||||||
|  |                         id: serverChatLabel | ||||||
|  |                         text: qsTr("Enable web server:") | ||||||
|  |                         color: theme.textColor | ||||||
|  |                         Layout.row: 6 | ||||||
|  |                         Layout.column: 0 | ||||||
|  |                     } | ||||||
|  |                     CheckBox { | ||||||
|  |                         id: serverChatBox | ||||||
|  |                         Layout.row: 6 | ||||||
|  |                         Layout.column: 1 | ||||||
|                         checked: settings.serverChat |                         checked: settings.serverChat | ||||||
|                         onClicked: { |                         onClicked: { | ||||||
|                             settingsDialog.serverChat = serverChatBox.checked |                             settingsDialog.serverChat = serverChatBox.checked | ||||||
| @@ -855,7 +910,7 @@ Dialog { | |||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                     Button { |                     Button { | ||||||
|                         Layout.row: 6 |                         Layout.row: 7 | ||||||
|                         Layout.column: 1 |                         Layout.column: 1 | ||||||
|                         Layout.fillWidth: true |                         Layout.fillWidth: true | ||||||
|                         padding: 10 |                         padding: 10 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user