Enable the force metal setting.

This commit is contained in:
Adam Treat
2023-06-27 11:54:34 -04:00
committed by AT
parent 2565f6a94a
commit 267601d670
9 changed files with 109 additions and 7 deletions

View File

@@ -3,6 +3,7 @@
#include "chatgpt.h"
#include "modellist.h"
#include "network.h"
#include "mysettings.h"
#include "../gpt4all-backend/llmodel.h"
#include <QCoreApplication>
@@ -73,6 +74,8 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_stopGenerating(false)
, m_timer(nullptr)
, m_isServer(isServer)
, m_forceMetal(MySettings::globalInstance()->forceMetal())
, m_reloadingToChangeVariant(false)
{
moveToThread(&m_llmThread);
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
@@ -81,6 +84,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
// The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
@@ -110,6 +114,19 @@ void ChatLLM::handleThreadStarted()
emit threadStarted();
}
// Slot invoked when MySettings::forceMetalChanged fires (connected in the
// ChatLLM constructor). Records the new "force metal" preference and, if a
// model is currently active, reloads it so the backend variant choice
// ("metal" vs. "auto" in LLModel::construct) takes effect immediately.
void ChatLLM::handleForceMetalChanged(bool forceMetal)
{
    // Metal is only relevant on Apple Silicon Macs. NOTE(review): Apple's
    // arm64 toolchain defines __aarch64__/__arm64__ but NOT __arm__ (which
    // denotes 32-bit ARM), so guarding on __arm__ alone compiles this body
    // out on the very machines it targets. Accept either spelling.
#if defined(Q_OS_MAC) && (defined(__arm__) || defined(__aarch64__))
    m_forceMetal = forceMetal;
    if (isModelLoaded() && m_shouldBeLoaded) {
        // Mark that this unload/reload cycle is a variant switch so
        // loadModel() bypasses the "store already has our model" shortcut
        // and constructs a fresh backend with the new variant.
        m_reloadingToChangeVariant = true;
        unloadModel();
        reloadModel();
        m_reloadingToChangeVariant = false;
    }
#endif
}
bool ChatLLM::loadDefaultModel()
{
ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
@@ -154,7 +171,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo3.model;
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
// At this point it is possible that while we were blocked waiting to acquire the model from the
// store, that our state was changed to not be loaded. If this is the case, release the model
@@ -170,7 +187,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
}
// Check if the store just gave us exactly the model we were looking for
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo) {
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
@@ -210,7 +227,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
model->setAPIKey(apiKey);
m_llModelInfo.model = model;
} else {
m_llModelInfo.model = LLModel::construct(filePath.toStdString());
#if defined(Q_OS_MAC) && defined(__arm__)
if (m_forceMetal)
m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "metal");
else
m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "auto");
#else
m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "auto");
#endif
if (m_llModelInfo.model) {
bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
if (!success) {