diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt
index bcb96b60..9c68c823 100644
--- a/gpt4all-chat/CMakeLists.txt
+++ b/gpt4all-chat/CMakeLists.txt
@@ -61,6 +61,7 @@ qt_add_executable(chat
     chat.h chat.cpp
     chatllm.h chatllm.cpp
     chatmodel.h chatlistmodel.h chatlistmodel.cpp
+    chatgpt.h chatgpt.cpp
     download.h download.cpp
     network.h network.cpp
     llm.h llm.cpp
diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 6cf637d2..0a4154be 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -304,12 +304,13 @@ QList<QString> Chat::modelList() const
     if (localPath != exePath) {
         QDir dir(localPath);
-        dir.setNameFilters(QStringList() << "ggml-*.bin");
+        dir.setNameFilters(QStringList() << "ggml-*.bin" << "chatgpt-*.txt");
         QStringList fileNames = dir.entryList();
         for (QString f : fileNames) {
             QString filePath = localPath + f;
             QFileInfo info(filePath);
-            QString name = info.completeBaseName().remove(0, 5);
+            QString basename = info.completeBaseName();
+            QString name = basename.startsWith("ggml-") ? basename.remove(0, 5) : basename;
             if (info.exists() && !list.contains(name)) { // don't allow duplicates
                 if (name == currentModelName)
                     list.prepend(name);
                 else
diff --git a/gpt4all-chat/chatgpt.cpp b/gpt4all-chat/chatgpt.cpp
new file mode 100644
index 00000000..0e2f3ab2
--- /dev/null
+++ b/gpt4all-chat/chatgpt.cpp
@@ -0,0 +1,206 @@
+#include "chatgpt.h"
+
+#include <functional>
+#include <iostream>
+#include <string>
+
+#include <QDebug>
+#include <QEventLoop>
+#include <QJsonArray>
+#include <QJsonDocument>
+#include <QJsonObject>
+
+//#define DEBUG
+
+ChatGPT::ChatGPT()
+    : QObject(nullptr)
+    , m_modelName("gpt-3.5-turbo")
+    , m_ctx(nullptr)
+    , m_responseCallback(nullptr)
+{
+}
+
+bool ChatGPT::loadModel(const std::string &modelPath)
+{
+    Q_UNUSED(modelPath);
+    return true;
+}
+
+void ChatGPT::setThreadCount(int32_t n_threads)
+{
+    Q_UNUSED(n_threads);
+    qt_noop();
+}
+
+int32_t ChatGPT::threadCount()
+{
+    return 1;
+}
+
+ChatGPT::~ChatGPT()
+{
+}
+
+bool ChatGPT::isModelLoaded() const
+{
+    return true;
+}
+
+size_t ChatGPT::stateSize() const
+{
+    return 0;
+}
+
+size_t ChatGPT::saveState(uint8_t *dest) const
+{
+    Q_UNUSED(dest);
+    return 0;
+}
+
+size_t ChatGPT::restoreState(const uint8_t *src)
+{
+    Q_UNUSED(src);
+    return 0;
+}
+
+void ChatGPT::prompt(const std::string &prompt,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
+        PromptContext &promptCtx)
+{
+    Q_UNUSED(promptCallback);
+    Q_UNUSED(recalculateCallback);
+
+    if (!isModelLoaded()) {
+        std::cerr << "ChatGPT ERROR: prompt won't work with an unloaded model!\n";
+        return;
+    }
+
+    m_ctx = &promptCtx;
+    m_responseCallback = responseCallback;
+
+    QJsonObject root;
+    root.insert("model", m_modelName);
+    root.insert("stream", true);
+    root.insert("temperature", promptCtx.temp);
+    root.insert("top_p", promptCtx.top_p);
+    root.insert("max_tokens", 200);
+
+    QJsonArray messages;
+    for (int i = 0; i < m_context.count() && i < promptCtx.n_past; ++i) {
+        QJsonObject message;
+        // m_context alternates prompt/response, so even indices are the user's turns
+        message.insert("role", i % 2 == 0 ? "user" : "assistant");
"assistant" : "user"); + message.insert("content", m_context.at(i)); + messages.append(message); + } + + QJsonObject promptObject; + promptObject.insert("role", "user"); + promptObject.insert("content", QString::fromStdString(prompt)); + messages.append(promptObject); + root.insert("messages", messages); + + QJsonDocument doc(root); + +#if defined(DEBUG) + qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson()); +#endif + + QEventLoop loop; + QUrl openaiUrl("https://api.openai.com/v1/chat/completions"); + const QString authorization = QString("Bearer %1").arg(m_apiKey); + QNetworkRequest request(openaiUrl); + request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + request.setRawHeader("Authorization", authorization.toUtf8()); + QNetworkReply *reply = m_networkManager.post(request, doc.toJson(QJsonDocument::Compact)); + connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit); + connect(reply, &QNetworkReply::finished, this, &ChatGPT::handleFinished); + connect(reply, &QNetworkReply::readyRead, this, &ChatGPT::handleReadyRead); + connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPT::handleErrorOccurred); + loop.exec(); +#if defined(DEBUG) + qDebug() << "ChatGPT::prompt end network request"; +#endif + + m_ctx->n_past += 1; + m_context.append(QString::fromStdString(prompt)); + m_context.append(m_currentResponse); + + m_ctx = nullptr; + m_responseCallback = nullptr; + m_currentResponse = QString(); +} + +void ChatGPT::handleFinished() +{ + QNetworkReply *reply = qobject_cast(sender()); + if (!reply) + return; + + QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute); + Q_ASSERT(response.isValid()); + bool ok; + int code = response.toInt(&ok); + if (!ok || code != 200) { + qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n") + .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString(); + } + reply->deleteLater(); +} + +void ChatGPT::handleReadyRead() +{ + QNetworkReply *reply = qobject_cast(sender()); + if (!reply) + return; + + QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute); + Q_ASSERT(response.isValid()); + bool ok; + int code = response.toInt(&ok); + if (!ok || code != 200) { + m_responseCallback(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n") + .arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString()); + return; + } + + while (reply->canReadLine()) { + QString jsonData = reply->readLine().trimmed(); + if (jsonData.startsWith("data:")) + jsonData.remove(0, 5); + jsonData = jsonData.trimmed(); + if (jsonData.isEmpty()) + continue; + if (jsonData == "[DONE]") + continue; +#if defined(DEBUG) + qDebug() << "line" << qPrintable(jsonData); +#endif + QJsonParseError err; + const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err); + if (err.error != QJsonParseError::NoError) { + m_responseCallback(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n") + .arg(err.errorString()).toStdString()); + continue; + } + + const QJsonObject root = document.object(); + const QJsonArray choices = root.value("choices").toArray(); + const QJsonObject choice = choices.first().toObject(); + const QJsonObject delta = choice.value("delta").toObject(); + const QString content = delta.value("content").toString(); + Q_ASSERT(m_ctx); + Q_ASSERT(m_responseCallback); + m_responseCallback(0, content.toStdString()); + m_currentResponse += content; + } +} + +void 
+void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code)
+{
+    QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
+    if (!reply)
+        return;
+
+    qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2\" %3\n")
+        .arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString();
+}
diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h
new file mode 100644
index 00000000..0a1e0d52
--- /dev/null
+++ b/gpt4all-chat/chatgpt.h
@@ -0,0 +1,52 @@
+#ifndef CHATGPT_H
+#define CHATGPT_H
+
+#include <QObject>
+#include <QString>
+#include <QNetworkAccessManager>
+#include <QNetworkReply>
+#include "../gpt4all-backend/llmodel.h"
+
+class ChatGPTPrivate;
+class ChatGPT : public QObject, public LLModel {
+    Q_OBJECT
+public:
+    ChatGPT();
+    virtual ~ChatGPT();
+
+    bool loadModel(const std::string &modelPath) override;
+    bool isModelLoaded() const override;
+    size_t stateSize() const override;
+    size_t saveState(uint8_t *dest) const override;
+    size_t restoreState(const uint8_t *src) override;
+    void prompt(const std::string &prompt,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
+        PromptContext &ctx) override;
+    void setThreadCount(int32_t n_threads) override;
+    int32_t threadCount() override;
+
+    void setModelName(const QString &modelName) { m_modelName = modelName; }
+    void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }
+
+protected:
+    void recalculateContext(PromptContext &promptCtx,
+        std::function<bool(bool)> recalculate) override {}
+
+private Q_SLOTS:
+    void handleFinished();
+    void handleReadyRead();
+    void handleErrorOccurred(QNetworkReply::NetworkError code);
+
+private:
+    PromptContext *m_ctx;
+    std::function<bool(int32_t, const std::string&)> m_responseCallback;
+    QString m_modelName;
+    QString m_apiKey;
+    QList<QString> m_context;
+    QString m_currentResponse;
+    QNetworkAccessManager m_networkManager;
+};
+
+#endif // CHATGPT_H
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index e3268189..298010d8 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -5,6 +5,7 @@
 #include "../gpt4all-backend/gptj.h"
 #include "../gpt4all-backend/llamamodel.h"
 #include "../gpt4all-backend/mpt.h"
+#include "chatgpt.h"
 
 #include <QCoreApplication>
 #include <QDir>
@@ -21,17 +22,15 @@
 #define GPTJ_INTERNAL_STATE_VERSION 0
 #define LLAMA_INTERNAL_STATE_VERSION 0
 
-static QString modelFilePath(const QString &modelName)
+static QString modelFilePath(const QString &modelName, bool isChatGPT)
 {
-    QString appPath = QCoreApplication::applicationDirPath()
-        + "/ggml-" + modelName + ".bin";
modelName + ".txt" : "/ggml-" + modelName + ".bin"; + QString appPath = QCoreApplication::applicationDirPath() + modelFilename; QFileInfo infoAppPath(appPath); if (infoAppPath.exists()) return appPath; - QString downloadPath = Download::globalInstance()->downloadLocalModelsPath() - + "/ggml-" + modelName + ".bin"; - + QString downloadPath = Download::globalInstance()->downloadLocalModelsPath() + modelFilename; QFileInfo infoLocalPath(downloadPath); if (infoLocalPath.exists()) return downloadPath; @@ -139,7 +138,8 @@ bool ChatLLM::loadModel(const QString &modelName) if (isModelLoaded() && m_modelName == modelName) return true; - QString filePath = modelFilePath(modelName); + const bool isChatGPT = modelName.startsWith("chatgpt-"); + QString filePath = modelFilePath(modelName, isChatGPT); QFileInfo fileInfo(filePath); // We have a live model, but it isn't the one we want @@ -198,25 +198,42 @@ bool ChatLLM::loadModel(const QString &modelName) m_modelInfo.fileInfo = fileInfo; if (fileInfo.exists()) { - auto fin = std::ifstream(filePath.toStdString(), std::ios::binary); - uint32_t magic; - fin.read((char *) &magic, sizeof(magic)); - fin.seekg(0); - fin.close(); - const bool isGPTJ = magic == 0x67676d6c; - const bool isMPT = magic == 0x67676d6d; - if (isGPTJ) { - m_modelType = LLModelType::GPTJ_; - m_modelInfo.model = new GPTJ; - m_modelInfo.model->loadModel(filePath.toStdString()); - } else if (isMPT) { - m_modelType = LLModelType::MPT_; - m_modelInfo.model = new MPT; - m_modelInfo.model->loadModel(filePath.toStdString()); + if (isChatGPT) { + QString apiKey; + QString chatGPTModel = fileInfo.completeBaseName().remove(0, 8); // remove the chatgpt- prefix + { + QFile file(filePath); + file.open(QIODeviceBase::ReadOnly | QIODeviceBase::Text); + QTextStream stream(&file); + apiKey = stream.readAll(); + file.close(); + } + m_modelType = LLModelType::CHATGPT_; + ChatGPT *model = new ChatGPT(); + model->setModelName(chatGPTModel); + model->setAPIKey(apiKey); + m_modelInfo.model = model; } else { - m_modelType = LLModelType::LLAMA_; - m_modelInfo.model = new LLamaModel; - m_modelInfo.model->loadModel(filePath.toStdString()); + auto fin = std::ifstream(filePath.toStdString(), std::ios::binary); + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + fin.seekg(0); + fin.close(); + const bool isGPTJ = magic == 0x67676d6c; + const bool isMPT = magic == 0x67676d6d; + if (isGPTJ) { + m_modelType = LLModelType::GPTJ_; + m_modelInfo.model = new GPTJ; + m_modelInfo.model->loadModel(filePath.toStdString()); + } else if (isMPT) { + m_modelType = LLModelType::MPT_; + m_modelInfo.model = new MPT; + m_modelInfo.model->loadModel(filePath.toStdString()); + } else { + m_modelType = LLModelType::LLAMA_; + m_modelInfo.model = new LLamaModel; + m_modelInfo.model->loadModel(filePath.toStdString()); + } } #if defined(DEBUG_MODEL_LOADING) qDebug() << "new model" << m_chat->id() << m_modelInfo.model; @@ -241,8 +258,10 @@ bool ChatLLM::loadModel(const QString &modelName) emit modelLoadingError(error); } - if (m_modelInfo.model) - setModelName(fileInfo.completeBaseName().remove(0, 5)); // remove the ggml- prefix + if (m_modelInfo.model) { + QString basename = fileInfo.completeBaseName(); + setModelName(isChatGPT ? basename : basename.remove(0, 5)); // remove the ggml- prefix + } return m_modelInfo.model; } @@ -440,7 +459,7 @@ void ChatLLM::forceUnloadModel() void ChatLLM::unloadModel() { - if (!isModelLoaded() || m_isServer) + if (!isModelLoaded() || m_isServer) // FIXME: What if server switches models? 
         return;
 
     saveState();
@@ -454,7 +473,7 @@ void ChatLLM::unloadModel()
 
 void ChatLLM::reloadModel()
 {
-    if (isModelLoaded() || m_isServer)
+    if (isModelLoaded() || m_isServer) // FIXME: What if server switches models?
         return;
 
 #if defined(DEBUG_MODEL_LOADING)
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index 0b1eb343..fc99d6ce 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -10,7 +10,8 @@
 enum LLModelType {
     MPT_,
     GPTJ_,
-    LLAMA_
+    LLAMA_,
+    CHATGPT_,
 };
 
 struct LLModelInfo {
diff --git a/gpt4all-chat/download.cpp b/gpt4all-chat/download.cpp
index 736c8fa1..04dc5937 100644
--- a/gpt4all-chat/download.cpp
+++ b/gpt4all-chat/download.cpp
@@ -267,6 +267,26 @@ void Download::cancelDownload(const QString &modelFile)
         }
     }
 }
 
+void Download::installModel(const QString &modelFile, const QString &apiKey)
+{
+    Q_ASSERT(!apiKey.isEmpty());
+    if (apiKey.isEmpty())
+        return;
+
+    Network::globalInstance()->sendInstallModel(modelFile);
+    QString filePath = downloadLocalModelsPath() + modelFile + ".txt";
+    QFile file(filePath);
+    if (file.open(QIODeviceBase::WriteOnly | QIODeviceBase::Text)) {
+        QTextStream stream(&file);
+        stream << apiKey;
+        file.close();
+        ModelInfo info = m_modelMap.value(modelFile);
+        info.installed = true;
+        m_modelMap.insert(modelFile, info);
+        emit modelListChanged();
+    }
+}
+
 void Download::handleSslErrors(QNetworkReply *reply, const QList<QSslError> &errors)
 {
     QUrl url = reply->request().url();
@@ -372,6 +392,47 @@ void Download::parseModelsJsonFile(const QByteArray &jsonData)
         m_modelMap.insert(modelInfo.filename, modelInfo);
     }
 
+    const QString chatGPTDesc = tr("WARNING: requires a personal OpenAI API key; using this "
+        "model will send your chats over the network to OpenAI. Your API key will be stored on disk "
+        "and only used to interact with OpenAI models. If you don't have one, you can apply for "
+        "an API key <a href=\"https://platform.openai.com/account/api-keys\">here</a>.");
+
+    {
+        ModelInfo modelInfo;
+        modelInfo.isChatGPT = true;
+        modelInfo.filename = "chatgpt-gpt-3.5-turbo";
+        modelInfo.description = tr("OpenAI's ChatGPT model gpt-3.5-turbo. ") + chatGPTDesc;
+        modelInfo.requires = "2.4.2";
+        QString filePath = downloadLocalModelsPath() + modelInfo.filename + ".txt";
+        QFileInfo info(filePath);
+        modelInfo.installed = info.exists();
+        m_modelMap.insert(modelInfo.filename, modelInfo);
+    }
+
+    {
+        ModelInfo modelInfo;
+        modelInfo.isChatGPT = true;
+        modelInfo.filename = "chatgpt-gpt-4";
+        modelInfo.description = tr("OpenAI's ChatGPT model gpt-4. ") + chatGPTDesc;
+        modelInfo.requires = "2.4.2";
+        QString filePath = downloadLocalModelsPath() + modelInfo.filename + ".txt";
+        QFileInfo info(filePath);
+        modelInfo.installed = info.exists();
+        m_modelMap.insert(modelInfo.filename, modelInfo);
+    }
+
+    {
+        ModelInfo modelInfo;
+        modelInfo.isChatGPT = true;
+        modelInfo.filename = "chatgpt-text-davinci-003";
+        modelInfo.description = tr("OpenAI's ChatGPT model text-davinci-003. ") + chatGPTDesc;
") + chatGPTDesc; + modelInfo.requires = "2.4.2"; + QString filePath = downloadLocalModelsPath() + modelInfo.filename + ".txt"; + QFileInfo info(filePath); + modelInfo.installed = info.exists(); + m_modelMap.insert(modelInfo.filename, modelInfo); + } + // remove ggml- prefix and .bin suffix Q_ASSERT(defaultModel.startsWith("ggml-")); defaultModel = defaultModel.remove(0, 5); diff --git a/gpt4all-chat/download.h b/gpt4all-chat/download.h index 638bae43..1310d99c 100644 --- a/gpt4all-chat/download.h +++ b/gpt4all-chat/download.h @@ -20,6 +20,7 @@ struct ModelInfo { Q_PROPERTY(bool bestGPTJ MEMBER bestGPTJ) Q_PROPERTY(bool bestLlama MEMBER bestLlama) Q_PROPERTY(bool bestMPT MEMBER bestMPT) + Q_PROPERTY(bool isChatGPT MEMBER isChatGPT) Q_PROPERTY(QString description MEMBER description) Q_PROPERTY(QString requires MEMBER requires) @@ -33,6 +34,7 @@ public: bool bestGPTJ = false; bool bestLlama = false; bool bestMPT = false; + bool isChatGPT = false; QString description; QString requires; }; @@ -88,6 +90,7 @@ public: Q_INVOKABLE void updateReleaseNotes(); Q_INVOKABLE void downloadModel(const QString &modelFile); Q_INVOKABLE void cancelDownload(const QString &modelFile); + Q_INVOKABLE void installModel(const QString &modelFile, const QString &apiKey); Q_INVOKABLE QString defaultLocalModelsPath() const; Q_INVOKABLE QString downloadLocalModelsPath() const; Q_INVOKABLE void setDownloadLocalModelsPath(const QString &modelPath); diff --git a/gpt4all-chat/network.cpp b/gpt4all-chat/network.cpp index 9a0eda33..3471e53d 100644 --- a/gpt4all-chat/network.cpp +++ b/gpt4all-chat/network.cpp @@ -240,6 +240,16 @@ void Network::sendModelDownloaderDialog() sendMixpanelEvent("download_dialog"); } +void Network::sendInstallModel(const QString &model) +{ + if (!m_usageStatsActive) + return; + KeyValue kv; + kv.key = QString("model"); + kv.value = QJsonValue(model); + sendMixpanelEvent("install_model", QVector{kv}); +} + void Network::sendDownloadStarted(const QString &model) { if (!m_usageStatsActive) diff --git a/gpt4all-chat/network.h b/gpt4all-chat/network.h index 1c9de2df..50e32fd4 100644 --- a/gpt4all-chat/network.h +++ b/gpt4all-chat/network.h @@ -41,6 +41,7 @@ public Q_SLOTS: void sendCheckForUpdates(); Q_INVOKABLE void sendModelDownloaderDialog(); Q_INVOKABLE void sendResetContext(int conversationLength); + void sendInstallModel(const QString &model); void sendDownloadStarted(const QString &model); void sendDownloadCanceled(const QString &model); void sendDownloadError(const QString &model, int code, const QString &errorString); diff --git a/gpt4all-chat/qml/ModelDownloaderDialog.qml b/gpt4all-chat/qml/ModelDownloaderDialog.qml index 0c2a58a5..1de3f1dd 100644 --- a/gpt4all-chat/qml/ModelDownloaderDialog.qml +++ b/gpt4all-chat/qml/ModelDownloaderDialog.qml @@ -82,7 +82,7 @@ Dialog { id: modelName objectName: "modelName" property string filename: modelData.filename - text: filename.slice(5, filename.length - 4) + text: !modelData.isChatGPT ? 
                 padding: 20
                 anchors.top: parent.top
                 anchors.left: parent.left
@@ -102,10 +102,13 @@ Dialog {
                 anchors.left: modelName.left
                 anchors.right: parent.right
                 wrapMode: Text.WordWrap
+                textFormat: Text.StyledText
                 color: theme.textColor
+                linkColor: theme.textColor
                 Accessible.role: Accessible.Paragraph
                 Accessible.name: qsTr("Description")
                 Accessible.description: qsTr("The description of the file")
+                onLinkActivated: Qt.openUrlExternally(link)
             }
 
             Text {
@@ -220,6 +223,53 @@ Dialog {
                 Accessible.description: qsTr("Whether the file is already installed on your system")
             }
 
+            Item {
+                visible: modelData.isChatGPT && !modelData.installed
+                anchors.top: modelName.top
+                anchors.topMargin: 15
+                anchors.right: parent.right
+
+                TextField {
+                    id: openaiKey
+                    anchors.right: installButton.left
+                    anchors.rightMargin: 15
+                    color: theme.textColor
+                    background: Rectangle {
+                        color: theme.backgroundLighter
+                        radius: 10
+                    }
+                    placeholderText: qsTr("enter $OPENAI_API_KEY")
+                    placeholderTextColor: theme.backgroundLightest
+                    Accessible.role: Accessible.EditableText
+                    Accessible.name: placeholderText
+                    Accessible.description: qsTr("Field for entering the OpenAI API key")
+                }
+
+                Button {
+                    id: installButton
+                    contentItem: Text {
+                        color: openaiKey.text === "" ? theme.backgroundLightest : theme.textColor
+                        text: qsTr("Install")
+                    }
+                    enabled: openaiKey.text !== ""
+                    anchors.right: parent.right
+                    anchors.rightMargin: 20
+                    background: Rectangle {
+                        opacity: .5
+                        border.color: theme.backgroundLightest
+                        border.width: 1
+                        radius: 10
+                        color: theme.backgroundLight
+                    }
+                    onClicked: {
+                        Download.installModel(modelData.filename, openaiKey.text);
+                    }
+                    Accessible.role: Accessible.Button
+                    Accessible.name: qsTr("Install button")
+                    Accessible.description: qsTr("Install button to install a chatgpt model")
+                }
+            }
+
             Button {
                 id: downloadButton
                 contentItem: Text {
@@ -230,7 +280,7 @@ Dialog {
                 anchors.right: parent.right
                 anchors.topMargin: 15
                 anchors.rightMargin: 20
-                visible: !modelData.installed && !modelData.calcHash
+                visible: !modelData.isChatGPT && !modelData.installed && !modelData.calcHash
                 onClicked: {
                     if (!downloading) {
                         downloading = true;
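
Note on the stream format that ChatGPT::handleReadyRead above consumes (an illustration, not part of the patch): because the request body sets "stream": true, OpenAI replies with server-sent events. Each event line carries a "data:" prefix and one JSON chunk whose choices[0].delta.content holds the next token fragment, and a literal "data: [DONE]" line ends the stream. A representative stream (made-up id and token values) looks like:

    data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}

    data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" there"},"finish_reason":null}]}

    data: [DONE]

This is why the parser strips the "data:" prefix, skips the blank separator lines, and treats "[DONE]" as an end-of-stream sentinel rather than JSON.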