diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt
index 0ec305f2..bcb96b60 100644
--- a/gpt4all-chat/CMakeLists.txt
+++ b/gpt4all-chat/CMakeLists.txt
@@ -36,7 +36,7 @@ configure_file(
   "${CMAKE_CURRENT_BINARY_DIR}/config.h"
 )
 
-find_package(Qt6 6.2 COMPONENTS Core Quick QuickDialogs2 Svg REQUIRED)
+find_package(Qt6 6.5 COMPONENTS Core Quick QuickDialogs2 Svg HttpServer REQUIRED)
 
 # Get the Qt6Core target properties
 get_target_property(Qt6Core_INCLUDE_DIRS Qt6::Core INTERFACE_INCLUDE_DIRECTORIES)
@@ -64,6 +64,7 @@ qt_add_executable(chat
     download.h download.cpp
     network.h network.cpp
     llm.h llm.cpp
+    server.h server.cpp
     sysinfo.h
 )
@@ -118,7 +119,7 @@ endif()
 target_compile_definitions(chat
     PRIVATE $<$<OR:$<CONFIG:Debug>,$<CONFIG:RelWithDebInfo>>:QT_QML_DEBUG>)
 target_link_libraries(chat
-    PRIVATE Qt6::Quick Qt6::Svg)
+    PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer)
 target_link_libraries(chat
     PRIVATE llmodel)
diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 824307eb..6e99c9e2 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -11,6 +11,20 @@ Chat::Chat(QObject *parent)
     , m_responseInProgress(false)
     , m_creationDate(QDateTime::currentSecsSinceEpoch())
     , m_llmodel(new ChatLLM(this))
+    , m_isServer(false)
+{
+    connectLLM();
+}
+
+Chat::Chat(bool isServer, QObject *parent)
+    : QObject(parent)
+    , m_id(Network::globalInstance()->generateUniqueId())
+    , m_name(tr("Server Chat"))
+    , m_chatModel(new ChatModel(this))
+    , m_responseInProgress(false)
+    , m_creationDate(QDateTime::currentSecsSinceEpoch())
+    , m_llmodel(new Server(this))
+    , m_isServer(true)
 {
     connectLLM();
 }
@@ -138,11 +152,19 @@ void Chat::setModelName(const QString &modelName)
 
 void Chat::newPromptResponsePair(const QString &prompt)
 {
+    m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt(tr("Prompt: "), prompt);
     m_chatModel->appendResponse(tr("Response: "), prompt);
     emit resetResponseRequested(); // blocking queued connection
 }
 
+void Chat::serverNewPromptResponsePair(const QString &prompt)
+{
+    m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
+    m_chatModel->appendPrompt(tr("Prompt: "), prompt);
+    m_chatModel->appendResponse(tr("Response: "), prompt);
+}
+
 bool Chat::isRecalc() const
 {
     return m_llmodel->isRecalc();
@@ -236,6 +258,17 @@ QList<QString> Chat::modelList() const
     QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
     QString localPath = Download::globalInstance()->downloadLocalModelsPath();
 
+    QSettings settings;
+    settings.sync();
+    // The user default model can be set by the user in the settings dialog. The "default" user
+    // default model is "Application default" which signals we should use the default model that was
+    // specified by the models.json file.
+    QString defaultModel = settings.value("userDefaultModel").toString();
+    if (defaultModel.isEmpty() || defaultModel == "Application default")
+        defaultModel = settings.value("defaultModel").toString();
+
+    QString currentModelName = modelName().isEmpty() ? defaultModel : modelName();
+
     {
         QDir dir(exePath);
         dir.setNameFilters(QStringList() << "ggml-*.bin");
@@ -245,7 +278,7 @@ QList<QString> Chat::modelList() const
             QFileInfo info(filePath);
             QString name = info.completeBaseName().remove(0, 5);
             if (info.exists()) {
-                if (name == modelName())
+                if (name == currentModelName)
                     list.prepend(name);
                 else
                     list.append(name);
@@ -262,7 +295,7 @@ QList<QString> Chat::modelList() const
             QFileInfo info(filePath);
             QString name = info.completeBaseName().remove(0, 5);
             if (info.exists() && !list.contains(name)) { // don't allow duplicates
-                if (name == modelName())
+                if (name == currentModelName)
                     list.prepend(name);
                 else
                     list.append(name);
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index 2970ad6c..b3275caf 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -7,6 +7,7 @@
 
 #include "chatllm.h"
 #include "chatmodel.h"
+#include "server.h"
 
 class Chat : public QObject
 {
@@ -20,11 +21,13 @@ class Chat : public QObject
     Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
     Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
     Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
+    Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged)
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")
 
 public:
     explicit Chat(QObject *parent = nullptr);
+    explicit Chat(bool isServer, QObject *parent = nullptr);
     virtual ~Chat();
     void connectLLM();
 
@@ -61,6 +64,10 @@ public:
     bool deserialize(QDataStream &stream, int version);
 
     QList<QString> modelList() const;
+    bool isServer() const { return m_isServer; }
+
+public Q_SLOTS:
+    void serverNewPromptResponsePair(const QString &prompt);
 
 Q_SIGNALS:
     void idChanged();
@@ -85,6 +92,7 @@ Q_SIGNALS:
     void generateNameRequested();
     void modelListChanged();
     void modelLoadingError(const QString &error);
+    void isServerChanged();
 
 private Q_SLOTS:
     void handleResponseChanged();
@@ -103,6 +111,7 @@ private:
     bool m_responseInProgress;
     qint64 m_creationDate;
     ChatLLM *m_llmodel;
+    bool m_isServer;
 };
 
 #endif // CHAT_H
diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp
index 3fd2246f..a0fe17f6 100644
--- a/gpt4all-chat/chatlistmodel.cpp
+++ b/gpt4all-chat/chatlistmodel.cpp
@@ -1,5 +1,6 @@
 #include "chatlistmodel.h"
 #include "download.h"
+#include "llm.h"
 
 #include
 #include
@@ -11,6 +12,7 @@ ChatListModel::ChatListModel(QObject *parent)
     : QAbstractListModel(parent)
     , m_newChat(nullptr)
     , m_dummyChat(nullptr)
+    , m_serverChat(nullptr)
     , m_currentChat(nullptr)
     , m_shouldSaveChats(false)
 {
@@ -243,4 +245,16 @@ void ChatListModel::chatsRestoredFinished()
 
     if (m_chats.isEmpty())
         addChat();
+
+    addServerChat();
 }
+
+void ChatListModel::handleServerEnabledChanged()
+{
+    if (LLM::globalInstance()->serverEnabled() || m_serverChat != m_currentChat)
+        return;
+
+    Chat *nextChat = get(0);
+    Q_ASSERT(nextChat && nextChat != m_serverChat);
+    setCurrentChat(nextChat);
+}
\ No newline at end of file
diff --git a/gpt4all-chat/chatlistmodel.h b/gpt4all-chat/chatlistmodel.h
index c695e05d..e8858b60 100644
--- a/gpt4all-chat/chatlistmodel.h
+++ b/gpt4all-chat/chatlistmodel.h
@@ -94,6 +94,19 @@ public:
         emit currentChatChanged();
     }
 
+    Q_INVOKABLE void addServerChat()
+    {
+        // Create a new dummy chat pointer and don't connect it
+        if (m_serverChat)
+            return;
+
+        m_serverChat = new Chat(true /*isServer*/, this);
+        beginInsertRows(QModelIndex(), m_chats.size(), m_chats.size());
+        m_chats.append(m_serverChat);
+        endInsertRows();
+        emit countChanged();
+    }
+
     void setNewChat(Chat* chat)
     {
         // Don't add a new chat if we already have one
@@ -161,7 +174,7 @@ public:
             m_currentChat->unloadModel();
 
         m_currentChat = chat;
-        if (!m_currentChat->isModelLoaded())
+        if (!m_currentChat->isModelLoaded() && m_currentChat != m_serverChat)
             m_currentChat->reloadModel();
         emit currentChatChanged();
     }
@@ -179,6 +192,9 @@ public:
     void restoreChat(Chat *chat);
     void chatsRestoredFinished();
 
+public Q_SLOTS:
+    void handleServerEnabledChanged();
+
 Q_SIGNALS:
     void countChanged();
     void currentChatChanged();
@@ -226,6 +242,7 @@ private:
     bool m_shouldSaveChats;
     Chat* m_newChat;
    Chat* m_dummyChat;
+    Chat* m_serverChat;
     Chat* m_currentChat;
     QList<Chat*> m_chats;
 };
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index bdb843eb..34c604fb 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -41,6 +41,7 @@ ChatLLM::ChatLLM(Chat *parent)
     : QObject{nullptr}
     , m_llmodel(nullptr)
     , m_promptResponseTokens(0)
+    , m_promptTokens(0)
     , m_responseLogits(0)
     , m_isRecalc(false)
     , m_chat(parent)
@@ -49,6 +50,7 @@ ChatLLM::ChatLLM(Chat *parent)
     connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
     connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
     connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
+    connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
     m_llmThread.setObjectName(m_chat->id());
     m_llmThread.start();
 }
@@ -69,18 +71,7 @@ bool ChatLLM::loadDefaultModel()
             &ChatLLM::loadDefaultModel, Qt::SingleShotConnection);
         return false;
     }
-
-    QSettings settings;
-    settings.sync();
-    // The user default model can be set by the user in the settings dialog. The "default" user
-    // default model is "Application default" which signals we should use the default model that was
-    // specified by the models.json file.
-    QString defaultModel = settings.value("userDefaultModel").toString();
-    if (defaultModel.isEmpty() || !models.contains(defaultModel) || defaultModel == "Application default")
-        defaultModel = settings.value("defaultModel").toString();
-    if (defaultModel.isEmpty() || !models.contains(defaultModel))
-        defaultModel = models.first();
-    return loadModel(defaultModel);
+    return loadModel(models.first());
 }
 
 bool ChatLLM::loadModel(const QString &modelName)
@@ -89,7 +80,7 @@ bool ChatLLM::loadModel(const QString &modelName)
         return true;
 
     if (isModelLoaded()) {
-        resetContextPrivate();
+        resetContextProtected();
         delete m_llmodel;
         m_llmodel = nullptr;
         emit isModelLoadedChanged();
@@ -161,6 +152,7 @@ void ChatLLM::regenerateResponse()
     m_ctx.logits.erase(m_ctx.logits.end() -= m_responseLogits, m_ctx.logits.end());
     m_ctx.tokens.erase(m_ctx.tokens.end() -= m_promptResponseTokens, m_ctx.tokens.end());
     m_promptResponseTokens = 0;
+    m_promptTokens = 0;
     m_responseLogits = 0;
     m_response = std::string();
     emit responseChanged();
@@ -168,6 +160,7 @@ void ChatLLM::regenerateResponse()
 
 void ChatLLM::resetResponse()
 {
+    m_promptTokens = 0;
     m_promptResponseTokens = 0;
     m_responseLogits = 0;
     m_response = std::string();
@@ -176,11 +169,11 @@ void ChatLLM::resetResponse()
 
 void ChatLLM::resetContext()
 {
-    resetContextPrivate();
+    resetContextProtected();
     emit sendResetContext();
 }
 
-void ChatLLM::resetContextPrivate()
+void ChatLLM::resetContextProtected()
 {
     regenerateResponse();
     m_ctx = LLModel::PromptContext();
@@ -235,6 +228,7 @@ bool ChatLLM::handlePrompt(int32_t token)
 #if defined(DEBUG)
     qDebug() << "chatllm prompt process" << m_chat->id() << token;
 #endif
+    ++m_promptTokens;
     ++m_promptResponseTokens;
     return !m_stopGenerating;
 }
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index ef2c3bd3..59d480f3 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -70,9 +70,15 @@ Q_SIGNALS:
     void sendResetContext();
     void generatedNameChanged();
     void stateChanged();
+    void threadStarted();
+
+protected:
+    LLModel::PromptContext m_ctx;
+    quint32 m_promptTokens;
+    quint32 m_promptResponseTokens;
+    void resetContextProtected();
 
 private:
-    void resetContextPrivate();
     bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
     bool handleRecalculate(bool isRecalc);
@@ -83,11 +89,9 @@ private:
     void restoreState();
 
 private:
-    LLModel::PromptContext m_ctx;
     LLModel *m_llmodel;
     std::string m_response;
     std::string m_nameResponse;
-    quint32 m_promptResponseTokens;
     quint32 m_responseLogits;
     QString m_modelName;
     ModelType m_modelType;
diff --git a/gpt4all-chat/llm.cpp b/gpt4all-chat/llm.cpp
index e94c461b..19fc9f77 100644
--- a/gpt4all-chat/llm.cpp
+++ b/gpt4all-chat/llm.cpp
@@ -22,10 +22,13 @@ LLM::LLM()
     : QObject{nullptr}
     , m_chatListModel(new ChatListModel(this))
     , m_threadCount(std::min(4, (int32_t) std::thread::hardware_concurrency()))
+    , m_serverEnabled(false)
     , m_compatHardware(true)
 {
     connect(QCoreApplication::instance(), &QCoreApplication::aboutToQuit,
         this, &LLM::aboutToQuit);
+    connect(this, &LLM::serverEnabledChanged,
+        m_chatListModel, &ChatListModel::handleServerEnabledChanged);
 
 #if defined(__x86_64__) || defined(__i386__)
     if (QString(GPT4ALL_AVX_ONLY) == "OFF") {
@@ -73,6 +76,19 @@ void LLM::setThreadCount(int32_t n_threads)
     emit threadCountChanged();
 }
 
+bool LLM::serverEnabled() const
+{
+    return m_serverEnabled;
+}
+
+void LLM::setServerEnabled(bool enabled)
+{
+    if (m_serverEnabled == enabled)
+        return;
+    m_serverEnabled = enabled;
+    emit serverEnabledChanged();
+}
+
 void LLM::aboutToQuit()
 {
     m_chatListModel->saveChats();
diff --git a/gpt4all-chat/llm.h b/gpt4all-chat/llm.h
index ac12981d..3674406b 100644
--- a/gpt4all-chat/llm.h
+++ b/gpt4all-chat/llm.h
@@ -10,6 +10,7 @@ class LLM : public QObject
     Q_OBJECT
     Q_PROPERTY(ChatListModel *chatListModel READ chatListModel NOTIFY chatListModelChanged)
     Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
+    Q_PROPERTY(bool serverEnabled READ serverEnabled WRITE setServerEnabled NOTIFY serverEnabledChanged)
     Q_PROPERTY(bool compatHardware READ compatHardware NOTIFY compatHardwareChanged)
 
 public:
@@ -18,6 +19,9 @@ public:
     ChatListModel *chatListModel() const { return m_chatListModel; }
     int32_t threadCount() const;
     void setThreadCount(int32_t n_threads);
+    bool serverEnabled() const;
+    void setServerEnabled(bool enabled);
+
     bool compatHardware() const { return m_compatHardware; }
 
     Q_INVOKABLE bool checkForUpdates() const;
@@ -25,6 +29,7 @@ public:
 Q_SIGNALS:
     void chatListModelChanged();
     void threadCountChanged();
+    void serverEnabledChanged();
     void compatHardwareChanged();
 
 private Q_SLOTS:
@@ -33,6 +38,7 @@ private Q_SLOTS:
 private:
     ChatListModel *m_chatListModel;
     int32_t m_threadCount;
+    bool m_serverEnabled;
     bool m_compatHardware;
 
 private:
diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml
index 6ab92df0..7e81a6c7 100644
--- a/gpt4all-chat/main.qml
+++ b/gpt4all-chat/main.qml
@@ -122,7 +122,7 @@ Window {
         Item {
             anchors.centerIn: parent
             height: childrenRect.height
-            visible: currentChat.isModelLoaded
+            visible: currentChat.isModelLoaded || currentChat.isServer
 
             Label {
                 id: modelLabel
@@ -142,6 +142,7 @@ Window {
                 anchors.top: modelLabel.top
                 anchors.bottom: modelLabel.bottom
                 anchors.horizontalCenter: parent.horizontalCenter
+                enabled: !currentChat.isServer
                 font.pixelSize: theme.fontSizeLarge
                 spacing: 0
                 model: currentChat.modelList
@@ -206,8 +207,8 @@ Window {
 
         BusyIndicator {
             anchors.centerIn: parent
-            visible: !currentChat.isModelLoaded
-            running: !currentChat.isModelLoaded
+            visible: !currentChat.isModelLoaded && !currentChat.isServer
+            running: !currentChat.isModelLoaded && !currentChat.isServer
             Accessible.role: Accessible.Animation
             Accessible.name: qsTr("Busy indicator")
             Accessible.description: qsTr("Displayed when the model is loading")
@@ -570,13 +571,13 @@ Window {
         anchors.left: parent.left
         anchors.right: parent.right
         anchors.top: parent.top
-        anchors.bottom: textInputView.top
-        anchors.bottomMargin: 30
+        anchors.bottom: !currentChat.isServer ? textInputView.top : parent.bottom
+        anchors.bottomMargin: !currentChat.isServer ? 30 : 0
         ScrollBar.vertical.policy: ScrollBar.AlwaysOn
 
         Rectangle {
             anchors.fill: parent
-            color: theme.backgroundLighter
+            color: currentChat.isServer ? theme.backgroundDark : theme.backgroundLighter
 
             ListView {
                 id: listView
@@ -598,7 +599,9 @@ Window {
                     cursorVisible: currentResponse ? currentChat.responseInProgress : false
                     cursorPosition: text.length
                     background: Rectangle {
-                        color: name === qsTr("Response: ") ? theme.backgroundLighter : theme.backgroundLight
+                        color: name === qsTr("Response: ")
+                            ? (currentChat.isServer ? theme.backgroundDarkest : theme.backgroundLighter)
+                            : (currentChat.isServer ? theme.backgroundDark : theme.backgroundLight)
                     }
 
                     Accessible.role: Accessible.Paragraph
@@ -757,7 +760,7 @@ Window {
             }
 
             Button {
-                visible: chatModel.count
+                visible: chatModel.count && !currentChat.isServer
                 Image {
                     anchors.verticalCenter: parent.verticalCenter
                     anchors.left: parent.left
@@ -819,13 +822,14 @@ Window {
             anchors.bottom: parent.bottom
             anchors.margins: 30
             height: Math.min(contentHeight, 200)
+            visible: !currentChat.isServer
 
             TextArea {
                 id: textInput
                 color: theme.textColor
                 padding: 20
                 rightPadding: 40
-                enabled: currentChat.isModelLoaded
+                enabled: currentChat.isModelLoaded && !currentChat.isServer
                 wrapMode: Text.WordWrap
                 font.pixelSize: theme.fontSizeLarge
                 placeholderText: qsTr("Send a message...")
@@ -850,12 +854,6 @@ Window {
                             return
 
                         currentChat.stopGenerating()
-
-                        if (chatModel.count) {
-                            var index = Math.max(0, chatModel.count - 1);
-                            var listElement = chatModel.get(index);
-                            chatModel.updateCurrentResponse(index, false);
-                        }
                         currentChat.newPromptResponsePair(textInput.text);
                         currentChat.prompt(textInput.text, settingsDialog.promptTemplate,
                                    settingsDialog.maxLength,
@@ -876,6 +874,7 @@ Window {
             anchors.rightMargin: 15
             width: 30
             height: 30
+            visible: !currentChat.isServer
             background: Image {
                 anchors.centerIn: parent
diff --git a/gpt4all-chat/qml/ChatDrawer.qml b/gpt4all-chat/qml/ChatDrawer.qml
index d3298f1a..8db645e6 100644
--- a/gpt4all-chat/qml/ChatDrawer.qml
+++ b/gpt4all-chat/qml/ChatDrawer.qml
@@ -83,9 +83,11 @@ Drawer {
                 height: chatName.height
                 opacity: 0.9
                 property bool isCurrent: LLM.chatListModel.currentChat === LLM.chatListModel.get(index)
+                property bool isServer: LLM.chatListModel.get(index) && LLM.chatListModel.get(index).isServer
                 property bool trashQuestionDisplayed: false
+                visible: !isServer || LLM.serverEnabled
                 z: isCurrent ? 199 : 1
-                color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter
+                color: isServer ? theme.backgroundDarkest : (index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter)
                 border.width: isCurrent
                 border.color: chatName.readOnly ? theme.assistantColor : theme.userColor
                 TextField {
@@ -149,7 +151,7 @@ Drawer {
                     id: editButton
                     width: 30
                     height: 30
-                    visible: isCurrent
+                    visible: isCurrent && !isServer
                     opacity: trashQuestionDisplayed ? 0.5 : 1.0
                     background: Image {
                         width: 30
@@ -166,10 +168,10 @@ Drawer {
                     Accessible.description: qsTr("Provides a button to edit the chat name")
                 }
                 Button {
-                    id: c
+                    id: trashButton
                     width: 30
                     height: 30
-                    visible: isCurrent
+                    visible: isCurrent && !isServer
                     background: Image {
                         width: 30
                         height: 30
diff --git a/gpt4all-chat/qml/SettingsDialog.qml b/gpt4all-chat/qml/SettingsDialog.qml
index c9f3557f..7c2f6d59 100644
--- a/gpt4all-chat/qml/SettingsDialog.qml
+++ b/gpt4all-chat/qml/SettingsDialog.qml
@@ -40,6 +40,7 @@ Dialog {
     property int defaultRepeatPenaltyTokens: 64
     property int defaultThreadCount: 0
     property bool defaultSaveChats: false
+    property bool defaultServerChat: false
     property string defaultPromptTemplate: "### Human:
 %1
 ### Assistant:\n"
@@ -56,6 +57,7 @@ Dialog {
     property alias repeatPenaltyTokens: settings.repeatPenaltyTokens
     property alias threadCount: settings.threadCount
     property alias saveChats: settings.saveChats
+    property alias serverChat: settings.serverChat
     property alias modelPath: settings.modelPath
     property alias userDefaultModel: settings.userDefaultModel
 
@@ -68,6 +70,7 @@ Dialog {
         property int promptBatchSize: settingsDialog.defaultPromptBatchSize
         property int threadCount: settingsDialog.defaultThreadCount
         property bool saveChats: settingsDialog.defaultSaveChats
+        property bool serverChat: settingsDialog.defaultServerChat
         property real repeatPenalty: settingsDialog.defaultRepeatPenalty
         property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens
         property string promptTemplate: settingsDialog.defaultPromptTemplate
@@ -91,15 +94,18 @@ Dialog {
         settings.modelPath = settingsDialog.defaultModelPath
         settings.threadCount = defaultThreadCount
         settings.saveChats = defaultSaveChats
+        settings.serverChat = defaultServerChat
         settings.userDefaultModel = defaultUserDefaultModel
         Download.downloadLocalModelsPath = settings.modelPath
         LLM.threadCount = settings.threadCount
+        LLM.serverEnabled = settings.serverChat
         LLM.chatListModel.shouldSaveChats = settings.saveChats
         settings.sync()
     }
 
     Component.onCompleted: {
         LLM.threadCount = settings.threadCount
+        LLM.serverEnabled = settings.serverChat
         LLM.chatListModel.shouldSaveChats = settings.saveChats
         Download.downloadLocalModelsPath = settings.modelPath
     }
@@ -796,8 +802,60 @@ Dialog {
                         leftPadding: saveChatsBox.indicator.width + saveChatsBox.spacing
                     }
                 }
-                Button {
+                Label {
+                    id: serverChatLabel
+                    text: qsTr("Enable web server:")
+                    color: theme.textColor
                     Layout.row: 5
+                    Layout.column: 0
+                }
+                CheckBox {
+                    id: serverChatBox
+                    Layout.row: 5
+                    Layout.column: 1
+                    checked: settings.serverChat
+                    onClicked: {
+                        settingsDialog.serverChat = serverChatBox.checked
+                        LLM.serverEnabled = serverChatBox.checked
+                        settings.sync()
+                    }
+
+                    ToolTip.text: qsTr("WARNING: This enables the gui to act as a local web server for AI API requests")
+                    ToolTip.visible: hovered
+
+                    background: Rectangle {
+                        color: "transparent"
+                    }
+
+                    indicator: Rectangle {
+                        implicitWidth: 26
+                        implicitHeight: 26
+                        x: serverChatBox.leftPadding
+                        y: parent.height / 2 - height / 2
+                        border.color: theme.dialogBorder
+                        color: "transparent"
+
+                        Rectangle {
+                            width: 14
+                            height: 14
+                            x: 6
+                            y: 6
+                            color: theme.textColor
+                            visible: serverChatBox.checked
+                        }
+                    }
+
+                    contentItem: Text {
+                        text: serverChatBox.text
+                        font: serverChatBox.font
+                        opacity: enabled ? 1.0 : 0.3
+                        color: theme.textColor
+                        verticalAlignment: Text.AlignVCenter
+                        leftPadding: serverChatBox.indicator.width + serverChatBox.spacing
+                    }
+                }
+                Button {
+                    Layout.row: 6
                     Layout.column: 1
                     Layout.fillWidth: true
                     padding: 10
diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp
new file mode 100644
index 00000000..1c2484df
--- /dev/null
+++ b/gpt4all-chat/server.cpp
@@ -0,0 +1,356 @@
+#include "server.h"
+#include "llm.h"
+#include "download.h"
+
+#include
+#include
+#include
+#include
+#include
+
+//#define DEBUG
+
+static inline QString modelToName(const ModelInfo &info)
+{
+    QString modelName = info.filename;
+    Q_ASSERT(modelName.startsWith("ggml-"));
+    modelName = modelName.remove(0, 5);
+    Q_ASSERT(modelName.endsWith(".bin"));
+    modelName.chop(4);
+    return modelName;
+}
+
+static inline QJsonObject modelToJson(const ModelInfo &info)
+{
+    QString modelName = modelToName(info);
+
+    QJsonObject model;
+    model.insert("id", modelName);
+    model.insert("object", "model");
+    model.insert("created", "who can keep track?");
+    model.insert("owned_by", "humanity");
+    model.insert("root", modelName);
+    model.insert("parent", QJsonValue::Null);
+
+    QJsonArray permissions;
+    QJsonObject permissionObj;
+    permissionObj.insert("id", "foobarbaz");
+    permissionObj.insert("object", "model_permission");
+    permissionObj.insert("created", "does it really matter?");
+    permissionObj.insert("allow_create_engine", false);
+    permissionObj.insert("allow_sampling", false);
+    permissionObj.insert("allow_logprobs", false);
+    permissionObj.insert("allow_search_indices", false);
+    permissionObj.insert("allow_view", true);
+    permissionObj.insert("allow_fine_tuning", false);
+    permissionObj.insert("organization", "*");
+    permissionObj.insert("group", QJsonValue::Null);
+    permissionObj.insert("is_blocking", false);
+    permissions.append(permissionObj);
+    model.insert("permissions", permissions);
+    return model;
+}
+
+Server::Server(Chat *chat)
+    : ChatLLM(chat)
+    , m_chat(chat)
+    , m_server(nullptr)
+{
+    connect(this, &Server::threadStarted, this, &Server::start);
+}
+
+Server::~Server()
+{
+}
+
+void Server::start()
+{
+    m_server = new QHttpServer(this);
+    if (!m_server->listen(QHostAddress::LocalHost, 4891)) {
+        qWarning() << "ERROR: Unable to start the server";
+        return;
+    }
+
+    m_server->route("/v1/models", QHttpServerRequest::Method::Get,
+        [](const QHttpServerRequest &request) {
+            if (!LLM::globalInstance()->serverEnabled())
+                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
+
+            const QList<ModelInfo> modelList = Download::globalInstance()->modelList();
+            QJsonObject root;
+            root.insert("object", "list");
+            QJsonArray data;
+            for (const ModelInfo &info : modelList) {
+                if (!info.installed)
+                    continue;
+                data.append(modelToJson(info));
+            }
+            root.insert("data", data);
+            return QHttpServerResponse(root);
+        }
+    );
+
+    m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Get,
+        [](const QString &model, const QHttpServerRequest &request) {
+            if (!LLM::globalInstance()->serverEnabled())
+                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
+
+            const QList<ModelInfo> modelList = Download::globalInstance()->modelList();
+            QJsonObject object;
+            for (const ModelInfo &info : modelList) {
+                if (!info.installed)
+                    continue;
+
+                QString modelName = modelToName(info);
+                if (model == modelName) {
+                    object = modelToJson(info);
+                    break;
+                }
+            }
+            return QHttpServerResponse(object);
+        }
+    );
+
+    m_server->route("/v1/completions", QHttpServerRequest::Method::Post,
+        [=](const QHttpServerRequest &request) {
+            if (!LLM::globalInstance()->serverEnabled())
+                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
+            return handleCompletionRequest(request, false);
+        }
+    );
+
+    m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Post,
+        [=](const QHttpServerRequest &request) {
+            if (!LLM::globalInstance()->serverEnabled())
+                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
+            return handleCompletionRequest(request, true);
+        }
+    );
+
+    connect(this, &Server::requestServerNewPromptResponsePair, m_chat,
+        &Chat::serverNewPromptResponsePair, Qt::BlockingQueuedConnection);
+}
+
+QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &request, bool isChat)
+{
+    // We've been asked to do a completion...
+    QJsonParseError err;
+    const QJsonDocument document = QJsonDocument::fromJson(request.body(), &err);
+    if (err.error || !document.isObject()) {
+        std::cerr << "ERROR: invalid json in completions body" << std::endl;
+        return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
+    }
+#if defined(DEBUG)
+    printf("/v1/completions %s\n", qPrintable(document.toJson(QJsonDocument::Indented)));
+    fflush(stdout);
+#endif
+    const QJsonObject body = document.object();
+    if (!body.contains("model")) { // required
+        std::cerr << "ERROR: completions contains no model" << std::endl;
+        return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
+    }
+    QJsonArray messages;
+    if (isChat) {
+        if (!body.contains("messages")) {
+            std::cerr << "ERROR: chat completions contains no messages" << std::endl;
+            return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
+        }
+        messages = body["messages"].toArray();
+    }
+
+    const QString model = body["model"].toString();
+    bool foundModel = false;
+    const QList<ModelInfo> modelList = Download::globalInstance()->modelList();
+    for (const ModelInfo &info : modelList) {
+        if (!info.installed)
+            continue;
+        if (model == modelToName(info)) {
+            foundModel = true;
+            break;
+        }
+    }
+
+    if (!foundModel) {
+        if (!loadDefaultModel()) {
+            std::cerr << "ERROR: couldn't load default model" << model.toStdString() << std::endl;
+            return QHttpServerResponse(QHttpServerResponder::StatusCode::BadRequest);
+        }
+    } else if (!loadModel(model)) {
+        std::cerr << "ERROR: couldn't load model" << model.toStdString() << std::endl;
+        return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
+    }
+
+    // We only support one prompt for now
+    QList<QString> prompts;
+    if (body.contains("prompt")) {
+        QJsonValue promptValue = body["prompt"];
+        if (promptValue.isString())
+            prompts.append(promptValue.toString());
+        else {
+            QJsonArray array = promptValue.toArray();
+            for (QJsonValue v : array)
+                prompts.append(v.toString());
+        }
+    } else
+        prompts.append(" ");
+
+    int max_tokens = 16;
+    if (body.contains("max_tokens"))
+        max_tokens = body["max_tokens"].toInt();
+
+    float temperature = 1.f;
+    if (body.contains("temperature"))
+        temperature = body["temperature"].toDouble();
+
+    float top_p = 1.f;
+    if (body.contains("top_p"))
+        top_p = body["top_p"].toDouble();
+
+    int n = 1;
+    if (body.contains("n"))
+        n = body["n"].toInt();
+
+    int logprobs = -1; // supposed to be null by default??
+    if (body.contains("logprobs"))
+        logprobs = body["logprobs"].toInt();
+
+    bool echo = false;
+    if (body.contains("echo"))
+        echo = body["echo"].toBool();
+
+    // We currently don't support any of the following...
+#if 0
+    // FIXME: Need configurable reverse prompts
+    QList<QString> stop;
+    if (body.contains("stop")) {
+        QJsonValue stopValue = body["stop"];
+        if (stopValue.isString())
+            stop.append(stopValue.toString());
+        else {
+            QJsonArray array = stopValue.toArray();
+            for (QJsonValue v : array)
+                stop.append(v.toString());
+        }
+    }
+
+    // FIXME: QHttpServer doesn't support server-sent events
+    bool stream = false;
+    if (body.contains("stream"))
+        stream = body["stream"].toBool();
+
+    // FIXME: What does this do?
+    QString suffix;
+    if (body.contains("suffix"))
+        suffix = body["suffix"].toString();
+
+    // FIXME: We don't support
+    float presence_penalty = 0.f;
+    if (body.contains("presence_penalty"))
+        top_p = body["presence_penalty"].toDouble();
+
+    // FIXME: We don't support
+    float frequency_penalty = 0.f;
+    if (body.contains("frequency_penalty"))
+        top_p = body["frequency_penalty"].toDouble();
+
+    // FIXME: We don't support
+    int best_of = 1;
+    if (body.contains("best_of"))
+        logprobs = body["best_of"].toInt();
+
+    // FIXME: We don't need
+    QString user;
+    if (body.contains("user"))
+        suffix = body["user"].toString();
+#endif
+
+    QString actualPrompt = prompts.first();
+
+    // if we're a chat completion we have messages which means we need to prepend these to the prompt
+    if (!messages.isEmpty()) {
+        QList<QString> chats;
+        for (int i = 0; i < messages.count(); ++i) {
+            QJsonValue v = messages.at(i);
+            QString content = v.toObject()["content"].toString();
+            if (!content.endsWith("\n") && i < messages.count() - 1)
+                content += "\n";
+            chats.append(content);
+        }
+        actualPrompt.prepend(chats.join("\n"));
+    }
+
+    // adds prompt/response items to GUI
+    emit requestServerNewPromptResponsePair(actualPrompt); // blocks
+
+    // don't remember any context
+    resetContextProtected();
+
+    QSettings settings;
+    settings.sync();
+    const QString promptTemplate = settings.value("promptTemplate", "%1").toString();
+    const float top_k = settings.value("topK", m_ctx.top_k).toDouble();
+    const int n_batch = settings.value("promptBatchSize", m_ctx.n_batch).toInt();
+    const float repeat_penalty = settings.value("repeatPenalty", m_ctx.repeat_penalty).toDouble();
+    const int repeat_last_n = settings.value("repeatPenaltyTokens", m_ctx.repeat_last_n).toInt();
+
+    int promptTokens = 0;
+    int responseTokens = 0;
+    QList<QString> responses;
+    for (int i = 0; i < n; ++i) {
+        if (!prompt(actualPrompt,
+                    promptTemplate,
+                    max_tokens /*n_predict*/,
+                    top_k,
+                    top_p,
+                    temperature,
+                    n_batch,
+                    repeat_penalty,
+                    repeat_last_n,
+                    LLM::globalInstance()->threadCount())) {
+
+            std::cerr << "ERROR: couldn't prompt model" << model.toStdString() << std::endl;
+            return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
+        }
+        QString echoedPrompt = actualPrompt;
+        if (!echoedPrompt.endsWith("\n"))
+            echoedPrompt += "\n";
+        responses.append((echo ? QString("%1\n").arg(actualPrompt) : QString()) + response());
+        if (!promptTokens)
+            promptTokens += m_promptTokens;
+        responseTokens += m_promptResponseTokens - m_promptTokens;
+        if (i != n - 1)
+            resetResponse();
+    }
+
+    QJsonObject responseObject;
+    responseObject.insert("id", "foobarbaz");
+    responseObject.insert("object", "text_completion");
+    responseObject.insert("created", QDateTime::currentSecsSinceEpoch());
+    responseObject.insert("model", modelName());
+
+    QJsonArray choices;
+    int index = 0;
+    for (QString r : responses) {
+        QJsonObject choice;
+        choice.insert("text", r);
+        choice.insert("index", index++);
+        choice.insert("logprobs", QJsonValue::Null); // We don't support
+        choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
+        choices.append(choice);
+    }
+    responseObject.insert("choices", choices);
+
+    QJsonObject usage;
+    usage.insert("prompt_tokens", int(promptTokens));
+    usage.insert("completion_tokens", int(responseTokens));
+    usage.insert("total_tokens", int(promptTokens + responseTokens));
+    responseObject.insert("usage", usage);
+
+#if defined(DEBUG)
+    QJsonDocument newDoc(responseObject);
+    printf("/v1/completions %s\n", qPrintable(newDoc.toJson(QJsonDocument::Indented)));
+    fflush(stdout);
+#endif
+
+    return QHttpServerResponse(responseObject);
+}
diff --git a/gpt4all-chat/server.h b/gpt4all-chat/server.h
new file mode 100644
index 00000000..90a89cfb
--- /dev/null
+++ b/gpt4all-chat/server.h
@@ -0,0 +1,31 @@
+#ifndef SERVER_H
+#define SERVER_H
+
+#include "chatllm.h"
+
+#include
+#include
+
+class Server : public ChatLLM
+{
+    Q_OBJECT
+
+public:
+    Server(Chat *parent);
+    virtual ~Server();
+
+public Q_SLOTS:
+    void start();
+
+Q_SIGNALS:
+    void requestServerNewPromptResponsePair(const QString &prompt);
+
+private Q_SLOTS:
+    QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
+
+private:
+    Chat *m_chat;
+    QHttpServer *m_server;
+};
+
+#endif // SERVER_H