diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 666ea17b..163b1dd3 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -57,6 +57,7 @@ void Chat::connectLLM()
     connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::reportDevice, this, &Chat::handleDeviceChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
 
@@ -352,6 +353,12 @@ void Chat::handleDeviceChanged(const QString &device)
     emit deviceChanged();
 }
 
+void Chat::handleFallbackReasonChanged(const QString &fallbackReason)
+{
+    m_fallbackReason = fallbackReason;
+    emit fallbackReasonChanged();
+}
+
 void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
 {
     m_databaseResults = results;
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index ad7e12b7..fb27fcf4 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -26,6 +26,7 @@ class Chat : public QObject
     Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
     Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
     Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
+    Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged);
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")
 
@@ -90,6 +91,7 @@ public:
 
     QString tokenSpeed() const { return m_tokenSpeed; }
     QString device() const { return m_device; }
+    QString fallbackReason() const { return m_fallbackReason; }
 
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt);
@@ -118,6 +120,7 @@ Q_SIGNALS:
     void collectionListChanged(const QList<QString> &collectionList);
     void tokenSpeedChanged();
     void deviceChanged();
+    void fallbackReasonChanged();
 
 private Q_SLOTS:
     void handleResponseChanged(const QString &response);
@@ -129,6 +132,7 @@ private Q_SLOTS:
     void handleModelLoadingError(const QString &error);
     void handleTokenSpeedChanged(const QString &tokenSpeed);
     void handleDeviceChanged(const QString &device);
+    void handleFallbackReasonChanged(const QString &fallbackReason);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
     void handleModelInfoChanged(const ModelInfo &modelInfo);
     void handleModelInstalled();
@@ -142,6 +146,7 @@ private:
     QString m_modelLoadingError;
     QString m_tokenSpeed;
     QString m_device;
+    QString m_fallbackReason;
     QString m_response;
     QList<QString> m_collections;
     ChatModel *m_chatModel;
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index b0d1b6f1..a1bbb604 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -267,27 +267,46 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         if (requestedDevice != "CPU") {
             const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString());
             std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
+            LLModel::GPUDevice *device = nullptr;
+
             if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) {
-                m_llModelInfo.model->initializeGPUDevice(availableDevices.front());
-                actualDevice = QString::fromStdString(availableDevices.front().name);
+                device = &availableDevices.front();
             } else {
                 for (LLModel::GPUDevice &d : availableDevices) {
                     if (QString::fromStdString(d.name) == requestedDevice) {
-                        m_llModelInfo.model->initializeGPUDevice(d);
-                        actualDevice = QString::fromStdString(d.name);
+                        device = &d;
                         break;
                     }
                 }
             }
+
+            if (!device) {
+                emit reportFallbackReason("<br>Using CPU: device not found");
+            } else if (!m_llModelInfo.model->initializeGPUDevice(*device)) {
+                emit reportFallbackReason("<br>Using CPU: failed to init device");
+            } else {
+                actualDevice = QString::fromStdString(device->name);
+            }
         }
 
         // Report which device we're actually using
         emit reportDevice(actualDevice);
 
         bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
-        if (!success && actualDevice != "CPU") {
+        if (actualDevice == "CPU") {
+            // we asked llama.cpp to use the CPU
+        } else if (!success) {
+            // llama_init_from_file returned nullptr
+            // this may happen because ggml_metal_add_buffer failed
             emit reportDevice("CPU");
+            emit reportFallbackReason("<br>Using CPU: llama_init_from_file failed");
             success = m_llModelInfo.model->loadModel(filePath.toStdString());
+        } else if (!m_llModelInfo.model->usingGPUDevice()) {
+            // ggml_vk_init was not called in llama.cpp
+            // We might have had to fall back to CPU after load if the model is not possible to accelerate,
+            // for instance if the quantization method is not supported on Vulkan yet
+            emit reportDevice("CPU");
+            emit reportFallbackReason("<br>Using CPU: unsupported quantization type");
         }
 
         MySettings::globalInstance()->setAttemptModelLoad(QString());
@@ -299,11 +318,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             m_llModelInfo = LLModelInfo();
             emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
         } else {
-            // We might have had to fallback to CPU after load if the model is not possible to accelerate
-            // for instance if the quantization method is not supported on Vulkan yet
-            if (actualDevice != "CPU" && !m_llModelInfo.model->usingGPUDevice())
-                emit reportDevice("CPU");
-
             switch (m_llModelInfo.model->implementation().modelType()[0]) {
             case 'L': m_llModelType = LLModelType::LLAMA_; break;
             case 'G': m_llModelType = LLModelType::GPTJ_; break;
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index 7e0b51eb..82954041 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -127,6 +127,7 @@ Q_SIGNALS:
     void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
     void reportSpeed(const QString &speed);
    void reportDevice(const QString &device);
+    void reportFallbackReason(const QString &fallbackReason);
     void databaseResultsChanged(const QList<ResultInfo>&);
     void modelInfoChanged(const ModelInfo &modelInfo);
 
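Note: the device-selection rule in ChatLLM::loadModel() above reduces to a small pure function: with "Auto", the first enumerated device is taken only when it is a discrete GPU (type == 2); otherwise the list is searched for an exact name match, and a null result triggers the "device not found" fallback. A minimal standalone sketch of that rule follows; GPUDevice here is a simplified stand-in for LLModel::GPUDevice, not the real struct.

    // select_device.cpp -- hypothetical distillation, not part of the patch.
    #include <string>
    #include <vector>

    struct GPUDevice { int type; std::string name; };  // type == 2: discrete GPU

    // Returns nullptr when nothing matches, mirroring the patch's
    // "Using CPU: device not found" path in ChatLLM::loadModel().
    const GPUDevice *selectDevice(const std::vector<GPUDevice> &devices,
                                  const std::string &requested)
    {
        if (!devices.empty() && requested == "Auto" && devices.front().type == 2)
            return &devices.front();
        for (const GPUDevice &d : devices)
            if (d.name == requested)
                return &d;
        return nullptr;
    }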
" + qsTr("Device: ") + currentChat.device + text: qsTr("Speed: ") + currentChat.tokenSpeed + "
" + qsTr("Device: ") + currentChat.device + currentChat.fallbackReason font.pixelSize: theme.fontSizeLarge }