diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index cccfff65..666ea17b 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -56,6 +56,7 @@ void Chat::connectLLM()
     connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::reportDevice, this, &Chat::handleDeviceChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
 
@@ -345,6 +346,12 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
     emit tokenSpeedChanged();
 }
 
+void Chat::handleDeviceChanged(const QString &device)
+{
+    m_device = device;
+    emit deviceChanged();
+}
+
 void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
 {
     m_databaseResults = results;
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index 2751e957..ad7e12b7 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -25,6 +25,7 @@ class Chat : public QObject
     Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
     Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
     Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
+    Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")
 
@@ -88,6 +89,7 @@ public:
     QString modelLoadingError() const { return m_modelLoadingError; }
 
     QString tokenSpeed() const { return m_tokenSpeed; }
+    QString device() const { return m_device; }
 
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt);
@@ -115,6 +117,7 @@ Q_SIGNALS:
     void isServerChanged();
     void collectionListChanged(const QList<QString> &collectionList);
     void tokenSpeedChanged();
+    void deviceChanged();
 
 private Q_SLOTS:
     void handleResponseChanged(const QString &response);
@@ -125,6 +128,7 @@ private Q_SLOTS:
     void handleRecalculating();
     void handleModelLoadingError(const QString &error);
     void handleTokenSpeedChanged(const QString &tokenSpeed);
+    void handleDeviceChanged(const QString &device);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
     void handleModelInfoChanged(const ModelInfo &modelInfo);
     void handleModelInstalled();
@@ -137,6 +141,7 @@ private:
     ModelInfo m_modelInfo;
     QString m_modelLoadingError;
     QString m_tokenSpeed;
+    QString m_device;
     QString m_response;
     QList<QString> m_collections;
     ChatModel *m_chatModel;
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index b10eb6ec..0efc0c71 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -271,22 +271,28 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             MySettings::globalInstance()->setDeviceList(deviceList);
 
             // Pick the best match for the device
+            QString actualDevice = m_llModelInfo.model->implementation().buildVariant() == "metal" ? "Metal" : "CPU";
             const QString requestedDevice = MySettings::globalInstance()->device();
             if (requestedDevice != "CPU") {
                 const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString());
                 std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
                 if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) {
                     m_llModelInfo.model->initializeGPUDevice(availableDevices.front());
+                    actualDevice = QString::fromStdString(availableDevices.front().name);
                 } else {
                     for (LLModel::GPUDevice &d : availableDevices) {
                         if (QString::fromStdString(d.name) == requestedDevice) {
                             m_llModelInfo.model->initializeGPUDevice(d);
+                            actualDevice = QString::fromStdString(d.name);
                             break;
                         }
                     }
                 }
             }
 
+            // Report which device we're actually using
+            emit reportDevice(actualDevice);
+
             bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
             MySettings::globalInstance()->setAttemptModelLoad(QString());
             if (!success) {
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index b15bfaa8..724ccefa 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -129,6 +129,7 @@ Q_SIGNALS:
     void shouldBeLoadedChanged();
     void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
     void reportSpeed(const QString &speed);
+    void reportDevice(const QString &device);
     void databaseResultsChanged(const QList<ResultInfo>&);
     void modelInfoChanged(const ModelInfo &modelInfo);
 
diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml
index 6be65542..158ddc9a 100644
--- a/gpt4all-chat/main.qml
+++ b/gpt4all-chat/main.qml
@@ -1013,7 +1013,7 @@ Window {
                 anchors.rightMargin: 30
                 color: theme.mutedTextColor
                 visible: currentChat.tokenSpeed !== ""
-                text: qsTr("Speed: ") + currentChat.tokenSpeed + "<br>" + qsTr("Device: ") + MySettings.device
+                text: qsTr("Speed: ") + currentChat.tokenSpeed + "<br>" + qsTr("Device: ") + currentChat.device
                 font.pixelSize: theme.fontSizeLarge
             }
 