mirror of
				https://github.com/nomic-ai/gpt4all.git
				synced 2025-10-31 13:51:43 +00:00 
			
		
		
		
	Add save/restore to chatgpt chats and allow serialize/deseralize from disk.
This commit is contained in:
		| @@ -258,6 +258,7 @@ bool Chat::deserialize(QDataStream &stream, int version) | ||||
|     // unfortunately, we cannot deserialize these | ||||
|     if (version < 2 && m_savedModelName.contains("gpt4all-j")) | ||||
|         return false; | ||||
|     m_llmodel->setModelName(m_savedModelName); | ||||
|     if (!m_llmodel->deserialize(stream, version)) | ||||
|         return false; | ||||
|     if (!m_chatModel->deserialize(stream, version)) | ||||
|   | ||||
| @@ -46,6 +46,7 @@ bool ChatGPT::isModelLoaded() const | ||||
|     return true; | ||||
| } | ||||
|  | ||||
| // All three of the state virtual functions are handled custom inside of chatllm save/restore | ||||
| size_t ChatGPT::stateSize() const | ||||
| { | ||||
|     return 0; | ||||
| @@ -53,11 +54,13 @@ size_t ChatGPT::stateSize() const | ||||
|  | ||||
| size_t ChatGPT::saveState(uint8_t *dest) const | ||||
| { | ||||
|     Q_UNUSED(dest); | ||||
|     return 0; | ||||
| } | ||||
|  | ||||
| size_t ChatGPT::restoreState(const uint8_t *src) | ||||
| { | ||||
|     Q_UNUSED(src); | ||||
|     return 0; | ||||
| } | ||||
|  | ||||
| @@ -141,8 +144,8 @@ void ChatGPT::handleFinished() | ||||
|     bool ok; | ||||
|     int code = response.toInt(&ok); | ||||
|     if (!ok || code != 200) { | ||||
|         qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n") | ||||
|             .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString(); | ||||
|         qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"") | ||||
|             .arg(code).arg(reply->errorString()).toStdString(); | ||||
|     } | ||||
|     reply->deleteLater(); | ||||
| } | ||||
| @@ -190,8 +193,11 @@ void ChatGPT::handleReadyRead() | ||||
|         const QString content = delta.value("content").toString(); | ||||
|         Q_ASSERT(m_ctx); | ||||
|         Q_ASSERT(m_responseCallback); | ||||
|         m_responseCallback(0, content.toStdString()); | ||||
|         m_currentResponse += content; | ||||
|         if (!m_responseCallback(0, content.toStdString())) { | ||||
|             reply->abort(); | ||||
|             return; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -201,6 +207,6 @@ void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code) | ||||
|     if (!reply) | ||||
|         return; | ||||
|  | ||||
|     qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n") | ||||
|                       .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString(); | ||||
|     qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"") | ||||
|                       .arg(code).arg(reply->errorString()).toStdString(); | ||||
| } | ||||
|   | ||||
| @@ -30,6 +30,9 @@ public: | ||||
|     void setModelName(const QString &modelName) { m_modelName = modelName; } | ||||
|     void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; } | ||||
|  | ||||
|     QList<QString> context() const { return m_context; } | ||||
|     void setContext(const QList<QString> &context) { m_context = context; } | ||||
|  | ||||
| protected: | ||||
|     void recalculateContext(PromptContext &promptCtx, | ||||
|         std::function<bool(bool)> recalculate) override {} | ||||
|   | ||||
| @@ -38,6 +38,19 @@ void ChatListModel::setShouldSaveChats(bool b) | ||||
|     emit shouldSaveChatsChanged(); | ||||
| } | ||||
|  | ||||
| bool ChatListModel::shouldSaveChatGPTChats() const | ||||
| { | ||||
|     return m_shouldSaveChatGPTChats; | ||||
| } | ||||
|  | ||||
| void ChatListModel::setShouldSaveChatGPTChats(bool b) | ||||
| { | ||||
|     if (m_shouldSaveChatGPTChats == b) | ||||
|         return; | ||||
|     m_shouldSaveChatGPTChats = b; | ||||
|     emit shouldSaveChatGPTChatsChanged(); | ||||
| } | ||||
|  | ||||
| void ChatListModel::removeChatFile(Chat *chat) const | ||||
| { | ||||
|     Q_ASSERT(chat != m_serverChat); | ||||
| @@ -52,15 +65,17 @@ void ChatListModel::removeChatFile(Chat *chat) const | ||||
|  | ||||
| void ChatListModel::saveChats() const | ||||
| { | ||||
|     if (!m_shouldSaveChats) | ||||
|         return; | ||||
|  | ||||
|     QElapsedTimer timer; | ||||
|     timer.start(); | ||||
|     const QString savePath = Download::globalInstance()->downloadLocalModelsPath(); | ||||
|     for (Chat *chat : m_chats) { | ||||
|         if (chat == m_serverChat) | ||||
|             continue; | ||||
|         const bool isChatGPT = chat->modelName().startsWith("chatgpt-"); | ||||
|         if (!isChatGPT && !m_shouldSaveChats) | ||||
|             continue; | ||||
|         if (isChatGPT && !m_shouldSaveChatGPTChats) | ||||
|             continue; | ||||
|         QString fileName = "gpt4all-" + chat->id() + ".chat"; | ||||
|         QFile file(savePath + "/" + fileName); | ||||
|         bool success = file.open(QIODevice::WriteOnly); | ||||
|   | ||||
| @@ -20,6 +20,7 @@ class ChatListModel : public QAbstractListModel | ||||
|     Q_PROPERTY(int count READ count NOTIFY countChanged) | ||||
|     Q_PROPERTY(Chat *currentChat READ currentChat WRITE setCurrentChat NOTIFY currentChatChanged) | ||||
|     Q_PROPERTY(bool shouldSaveChats READ shouldSaveChats WRITE setShouldSaveChats NOTIFY shouldSaveChatsChanged) | ||||
|     Q_PROPERTY(bool shouldSaveChatGPTChats READ shouldSaveChatGPTChats WRITE setShouldSaveChatGPTChats NOTIFY shouldSaveChatGPTChatsChanged) | ||||
|  | ||||
| public: | ||||
|     explicit ChatListModel(QObject *parent = nullptr); | ||||
| @@ -62,6 +63,9 @@ public: | ||||
|     bool shouldSaveChats() const; | ||||
|     void setShouldSaveChats(bool b); | ||||
|  | ||||
|     bool shouldSaveChatGPTChats() const; | ||||
|     void setShouldSaveChatGPTChats(bool b); | ||||
|  | ||||
|     Q_INVOKABLE void addChat() | ||||
|     { | ||||
|         // Don't add a new chat if we already have one | ||||
| @@ -199,6 +203,7 @@ Q_SIGNALS: | ||||
|     void countChanged(); | ||||
|     void currentChatChanged(); | ||||
|     void shouldSaveChatsChanged(); | ||||
|     void shouldSaveChatGPTChatsChanged(); | ||||
|  | ||||
| private Q_SLOTS: | ||||
|     void newChatCountChanged() | ||||
| @@ -240,6 +245,7 @@ private Q_SLOTS: | ||||
|  | ||||
| private: | ||||
|     bool m_shouldSaveChats; | ||||
|     bool m_shouldSaveChatGPTChats; | ||||
|     Chat* m_newChat; | ||||
|     Chat* m_dummyChat; | ||||
|     Chat* m_serverChat; | ||||
|   | ||||
| @@ -611,6 +611,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version) | ||||
|         stream >> compressed; | ||||
|         m_state = qUncompress(compressed); | ||||
|     } else { | ||||
|  | ||||
|         stream >> m_state; | ||||
|     } | ||||
| #if defined(DEBUG) | ||||
| @@ -624,6 +625,15 @@ void ChatLLM::saveState() | ||||
|     if (!isModelLoaded()) | ||||
|         return; | ||||
|  | ||||
|     if (m_isChatGPT) { | ||||
|         m_state.clear(); | ||||
|         QDataStream stream(&m_state, QIODeviceBase::WriteOnly); | ||||
|         stream.setVersion(QDataStream::Qt_6_5); | ||||
|         ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model); | ||||
|         stream << chatGPT->context(); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     const size_t stateSize = m_modelInfo.model->stateSize(); | ||||
|     m_state.resize(stateSize); | ||||
| #if defined(DEBUG) | ||||
| @@ -637,6 +647,18 @@ void ChatLLM::restoreState() | ||||
|     if (!isModelLoaded() || m_state.isEmpty()) | ||||
|         return; | ||||
|  | ||||
|     if (m_isChatGPT) { | ||||
|         QDataStream stream(&m_state, QIODeviceBase::ReadOnly); | ||||
|         stream.setVersion(QDataStream::Qt_6_5); | ||||
|         ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model); | ||||
|         QList<QString> context; | ||||
|         stream >> context; | ||||
|         chatGPT->setContext(context); | ||||
|         m_state.clear(); | ||||
|         m_state.resize(0); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
| #if defined(DEBUG) | ||||
|     qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size(); | ||||
| #endif | ||||
|   | ||||
| @@ -40,6 +40,7 @@ Dialog { | ||||
|     property int defaultRepeatPenaltyTokens: 64 | ||||
|     property int defaultThreadCount: 0 | ||||
|     property bool defaultSaveChats: false | ||||
|     property bool defaultSaveChatGPTChats: true | ||||
|     property bool defaultServerChat: false | ||||
|     property string defaultPromptTemplate: "### Human: | ||||
| %1 | ||||
| @@ -57,6 +58,7 @@ Dialog { | ||||
|     property alias repeatPenaltyTokens: settings.repeatPenaltyTokens | ||||
|     property alias threadCount: settings.threadCount | ||||
|     property alias saveChats: settings.saveChats | ||||
|     property alias saveChatGPTChats: settings.saveChatGPTChats | ||||
|     property alias serverChat: settings.serverChat | ||||
|     property alias modelPath: settings.modelPath | ||||
|     property alias userDefaultModel: settings.userDefaultModel | ||||
| @@ -70,6 +72,7 @@ Dialog { | ||||
|         property int promptBatchSize: settingsDialog.defaultPromptBatchSize | ||||
|         property int threadCount: settingsDialog.defaultThreadCount | ||||
|         property bool saveChats: settingsDialog.defaultSaveChats | ||||
|         property bool saveChatGPTChats: settingsDialog.defaultSaveChatGPTChats | ||||
|         property bool serverChat: settingsDialog.defaultServerChat | ||||
|         property real repeatPenalty: settingsDialog.defaultRepeatPenalty | ||||
|         property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens | ||||
| @@ -94,12 +97,14 @@ Dialog { | ||||
|         settings.modelPath = settingsDialog.defaultModelPath | ||||
|         settings.threadCount = defaultThreadCount | ||||
|         settings.saveChats = defaultSaveChats | ||||
|         settings.saveChatGPTChats = defaultSaveChatGPTChats | ||||
|         settings.serverChat = defaultServerChat | ||||
|         settings.userDefaultModel = defaultUserDefaultModel | ||||
|         Download.downloadLocalModelsPath = settings.modelPath | ||||
|         LLM.threadCount = settings.threadCount | ||||
|         LLM.serverEnabled = settings.serverChat | ||||
|         LLM.chatListModel.shouldSaveChats = settings.saveChats | ||||
|         LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats | ||||
|         settings.sync() | ||||
|     } | ||||
|  | ||||
| @@ -107,6 +112,7 @@ Dialog { | ||||
|         LLM.threadCount = settings.threadCount | ||||
|         LLM.serverEnabled = settings.serverChat | ||||
|         LLM.chatListModel.shouldSaveChats = settings.saveChats | ||||
|         LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats | ||||
|         Download.downloadLocalModelsPath = settings.modelPath | ||||
|     } | ||||
|  | ||||
| @@ -803,16 +809,65 @@ Dialog { | ||||
|                         } | ||||
|                     } | ||||
|                     Label { | ||||
|                         id: serverChatLabel | ||||
|                         text: qsTr("Enable web server:") | ||||
|                         id: saveChatGPTChatsLabel | ||||
|                         text: qsTr("Save ChatGPT chats to disk:") | ||||
|                         color: theme.textColor | ||||
|                         Layout.row: 5 | ||||
|                         Layout.column: 0 | ||||
|                     } | ||||
|                     CheckBox { | ||||
|                         id: serverChatBox | ||||
|                         id: saveChatGPTChatsBox | ||||
|                         Layout.row: 5 | ||||
|                         Layout.column: 1 | ||||
|                         checked: settingsDialog.saveChatGPTChats | ||||
|                         onClicked: { | ||||
|                             settingsDialog.saveChatGPTChats = saveChatGPTChatsBox.checked | ||||
|                             LLM.chatListModel.shouldSaveChatGPTChats = saveChatGPTChatsBox.checked | ||||
|                             settings.sync() | ||||
|                         } | ||||
|  | ||||
|                         background: Rectangle { | ||||
|                             color: "transparent" | ||||
|                         } | ||||
|  | ||||
|                         indicator: Rectangle { | ||||
|                             implicitWidth: 26 | ||||
|                             implicitHeight: 26 | ||||
|                             x: saveChatGPTChatsBox.leftPadding | ||||
|                             y: parent.height / 2 - height / 2 | ||||
|                             border.color: theme.dialogBorder | ||||
|                             color: "transparent" | ||||
|  | ||||
|                             Rectangle { | ||||
|                                 width: 14 | ||||
|                                 height: 14 | ||||
|                                 x: 6 | ||||
|                                 y: 6 | ||||
|                                 color: theme.textColor | ||||
|                                 visible: saveChatGPTChatsBox.checked | ||||
|                             } | ||||
|                         } | ||||
|  | ||||
|                         contentItem: Text { | ||||
|                             text: saveChatGPTChatsBox.text | ||||
|                             font: saveChatGPTChatsBox.font | ||||
|                             opacity: enabled ? 1.0 : 0.3 | ||||
|                             color: theme.textColor | ||||
|                             verticalAlignment: Text.AlignVCenter | ||||
|                             leftPadding: saveChatGPTChatsBox.indicator.width + saveChatGPTChatsBox.spacing | ||||
|                         } | ||||
|                     } | ||||
|                     Label { | ||||
|                         id: serverChatLabel | ||||
|                         text: qsTr("Enable web server:") | ||||
|                         color: theme.textColor | ||||
|                         Layout.row: 6 | ||||
|                         Layout.column: 0 | ||||
|                     } | ||||
|                     CheckBox { | ||||
|                         id: serverChatBox | ||||
|                         Layout.row: 6 | ||||
|                         Layout.column: 1 | ||||
|                         checked: settings.serverChat | ||||
|                         onClicked: { | ||||
|                             settingsDialog.serverChat = serverChatBox.checked | ||||
| @@ -855,7 +910,7 @@ Dialog { | ||||
|                         } | ||||
|                     } | ||||
|                     Button { | ||||
|                         Layout.row: 6 | ||||
|                         Layout.row: 7 | ||||
|                         Layout.column: 1 | ||||
|                         Layout.fillWidth: true | ||||
|                         padding: 10 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user