First attempt at providing a persistent chat list experience.

Limitations:

1) Context is not restored for gpt-j models
2) When you switch between different model types in an existing chat,
   the context and the entire conversation are lost
3) The settings are not chat- or conversation-specific
4) The persisted chat files are very large because of how much data the
   llama.cpp backend saves. Need to investigate how we can shrink this.
Adam Treat 2023-05-04 15:31:41 -04:00
parent 02c9bb4ac7
commit 01e582f15b
19 changed files with 530 additions and 208 deletions
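For orientation before reading the diff: each chat is now written to its own gpt4all-<id>.chat file next to the application's QSettings file, with the chat metadata, the message list, and the backend's raw state blob streamed through QDataStream (see ChatListModel::saveChats, Chat::serialize, and ChatLLM::serialize below). A minimal standalone sketch of that write path, assuming Qt only; names and values marked hypothetical are illustrative and not taken from the diff:

// Standalone sketch of the per-chat save path introduced in this commit.
#include <QCoreApplication>
#include <QDataStream>
#include <QFile>
#include <QFileInfo>
#include <QSettings>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);
    QCoreApplication::setOrganizationName("nomic.ai");   // hypothetical; only used to scope QSettings
    QCoreApplication::setApplicationName("GPT4All");     // hypothetical

    // Chats live next to the QSettings file, one "gpt4all-<id>.chat" per chat.
    QSettings settings;
    const QString dir = QFileInfo(settings.fileName()).absolutePath();
    const QString chatId = "example-id";                  // hypothetical id
    QFile file(dir + "/gpt4all-" + chatId + ".chat");
    if (!file.open(QIODevice::WriteOnly))
        return 1;

    QDataStream out(&file);
    out << qint64(0)                           // creation date (secs since epoch)
        << chatId                              // unique chat id
        << QStringLiteral("New Chat");         // display name
    // The expensive part: the backend's full state blob (llama_copy_state_data in the
    // llama backend), which is why the persisted files are so large (limitation 4 above).
    const QByteArray modelState;               // filled from LLModel::saveState() in the real code
    out << modelState;
    return out.status() == QDataStream::Ok ? 0 : 1;
}

Restore is the mirror image in ChatListModel::restoreChats below: the gpt4all-*.chat files are enumerated, deserialized into new Chat objects, and sorted by creation date.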


@@ -60,7 +60,7 @@ qt_add_executable(chat
     main.cpp
     chat.h chat.cpp
     chatllm.h chatllm.cpp
-    chatmodel.h chatlistmodel.h
+    chatmodel.h chatlistmodel.h chatlistmodel.cpp
    download.h download.cpp
     network.h network.cpp
     llm.h llm.cpp

182
chat.cpp

@@ -1,32 +1,37 @@
 #include "chat.h"
+#include "llm.h"
 #include "network.h"
+#include "download.h"
 
 Chat::Chat(QObject *parent)
     : QObject(parent)
-    , m_llmodel(new ChatLLM)
     , m_id(Network::globalInstance()->generateUniqueId())
     , m_name(tr("New Chat"))
     , m_chatModel(new ChatModel(this))
     , m_responseInProgress(false)
-    , m_desiredThreadCount(std::min(4, (int32_t) std::thread::hardware_concurrency()))
+    , m_creationDate(QDateTime::currentSecsSinceEpoch())
+    , m_llmodel(new ChatLLM(this))
 {
+    // Should be in same thread
+    connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
+    connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
+
+    // Should be in different threads
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::responseChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::modelNameChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
     connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
     connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
-    connect(this, &Chat::unloadRequested, m_llmodel, &ChatLLM::unload, Qt::QueuedConnection);
-    connect(this, &Chat::reloadRequested, m_llmodel, &ChatLLM::reload, Qt::QueuedConnection);
-    connect(this, &Chat::setThreadCountRequested, m_llmodel, &ChatLLM::setThreadCount, Qt::QueuedConnection);
+    connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
+    connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection);
+    connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection);
+    connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection);
     connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
 
     // The following are blocking operations and will block the gui thread, therefore must be fast
     // to respond to
@@ -38,9 +43,21 @@ Chat::Chat(QObject *parent)
 void Chat::reset()
 {
     stopGenerating();
+    // Erase our current on disk representation as we're completely resetting the chat along with id
+    LLM::globalInstance()->chatListModel()->removeChatFile(this);
     emit resetContextRequested(); // blocking queued connection
     m_id = Network::globalInstance()->generateUniqueId();
     emit idChanged();
+    // NOTE: We deliberately do not reset the name or creation date to indicate that this was originally
+    // an older chat that was reset for another purpose. Resetting this data would lead to the chat
+    // name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
+    // further down in the list. This might surprise the user. In the future, we might get rid of
+    // the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown
+    // we effectively do a reset context. We *have* to do this right now when switching between different
+    // types of models. The only way to get rid of that would be a very long recalculate where we rebuild
+    // the context if we switch between different types of models. Probably the right way to fix this
+    // is to allow switching models but throw up a dialog warning users that switching between types
+    // of models will incur a long recalculation.
     m_chatModel->clear();
 }
@@ -49,10 +66,12 @@ bool Chat::isModelLoaded() const
     return m_llmodel->isModelLoaded();
 }
 
-void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
-                  float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
+void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
+    int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
+    int32_t repeat_penalty_tokens)
 {
-    emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
+    emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch,
+        repeat_penalty, repeat_penalty_tokens, LLM::globalInstance()->threadCount());
 }
 
 void Chat::regenerateResponse()
@@ -70,6 +89,13 @@ QString Chat::response() const
     return m_llmodel->response();
 }
 
+void Chat::handleResponseChanged()
+{
+    const int index = m_chatModel->count() - 1;
+    m_chatModel->updateValue(index, response());
+    emit responseChanged();
+}
+
 void Chat::responseStarted()
 {
     m_responseInProgress = true;
@@ -98,21 +124,6 @@ void Chat::setModelName(const QString &modelName)
     emit modelNameChangeRequested(modelName);
 }
 
-void Chat::syncThreadCount() {
-    emit setThreadCountRequested(m_desiredThreadCount);
-}
-
-void Chat::setThreadCount(int32_t n_threads) {
-    if (n_threads <= 0)
-        n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    m_desiredThreadCount = n_threads;
-    syncThreadCount();
-}
-
-int32_t Chat::threadCount() {
-    return m_llmodel->threadCount();
-}
-
 void Chat::newPromptResponsePair(const QString &prompt)
 {
     m_chatModel->appendPrompt(tr("Prompt: "), prompt);
@@ -125,16 +136,25 @@ bool Chat::isRecalc() const
     return m_llmodel->isRecalc();
 }
 
-void Chat::unload()
+void Chat::loadDefaultModel()
 {
-    m_savedModelName = m_llmodel->modelName();
-    stopGenerating();
-    emit unloadRequested();
+    emit loadDefaultModelRequested();
 }
 
-void Chat::reload()
+void Chat::loadModel(const QString &modelName)
 {
-    emit reloadRequested(m_savedModelName);
+    emit loadModelRequested(modelName);
+}
+
+void Chat::unloadModel()
+{
+    stopGenerating();
+    emit unloadModelRequested();
+}
+
+void Chat::reloadModel()
+{
+    emit reloadModelRequested(m_savedModelName);
 }
 
 void Chat::generatedNameChanged()
@@ -150,4 +170,98 @@ void Chat::generatedNameChanged()
 void Chat::handleRecalculating()
 {
     Network::globalInstance()->sendRecalculatingContext(m_chatModel->count());
+    emit recalcChanged();
+}
+
+void Chat::handleModelNameChanged()
+{
+    m_savedModelName = modelName();
+    emit modelNameChanged();
+}
+
+bool Chat::serialize(QDataStream &stream) const
+{
+    stream << m_creationDate;
+    stream << m_id;
+    stream << m_name;
+    stream << m_userName;
+    stream << m_savedModelName;
+    if (!m_llmodel->serialize(stream))
+        return false;
+    if (!m_chatModel->serialize(stream))
+        return false;
+    return stream.status() == QDataStream::Ok;
+}
+
+bool Chat::deserialize(QDataStream &stream)
+{
+    stream >> m_creationDate;
+    stream >> m_id;
+    emit idChanged();
+    stream >> m_name;
+    stream >> m_userName;
+    emit nameChanged();
+    stream >> m_savedModelName;
+    if (!m_llmodel->deserialize(stream))
+        return false;
+    if (!m_chatModel->deserialize(stream))
+        return false;
+    emit chatModelChanged();
+    return stream.status() == QDataStream::Ok;
+}
+
+QList<QString> Chat::modelList() const
+{
+    // Build a model list from exepath and from the localpath
+    QList<QString> list;
+
+    QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
+    QString localPath = Download::globalInstance()->downloadLocalModelsPath();
+
+    {
+        QDir dir(exePath);
+        dir.setNameFilters(QStringList() << "ggml-*.bin");
+        QStringList fileNames = dir.entryList();
+        for (QString f : fileNames) {
+            QString filePath = exePath + f;
+            QFileInfo info(filePath);
+            QString name = info.completeBaseName().remove(0, 5);
+            if (info.exists()) {
+                if (name == modelName())
+                    list.prepend(name);
+                else
+                    list.append(name);
+            }
+        }
+    }
+
+    if (localPath != exePath) {
+        QDir dir(localPath);
+        dir.setNameFilters(QStringList() << "ggml-*.bin");
+        QStringList fileNames = dir.entryList();
+        for (QString f : fileNames) {
+            QString filePath = localPath + f;
+            QFileInfo info(filePath);
+            QString name = info.completeBaseName().remove(0, 5);
+            if (info.exists() && !list.contains(name)) { // don't allow duplicates
+                if (name == modelName())
+                    list.prepend(name);
+                else
+                    list.append(name);
+            }
+        }
+    }
+
+    if (list.isEmpty()) {
+        if (exePath != localPath) {
+            qWarning() << "ERROR: Could not find any applicable models in"
+                       << exePath << "nor" << localPath;
+        } else {
+            qWarning() << "ERROR: Could not find any applicable models in"
+                       << exePath;
+        }
+        return QList<QString>();
+    }
+    return list;
 }

42
chat.h

@@ -3,6 +3,7 @@
 #include <QObject>
 #include <QtQml>
+#include <QDataStream>
 
 #include "chatllm.h"
 #include "chatmodel.h"
@@ -17,8 +18,8 @@ class Chat : public QObject
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
     Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
     Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
-    Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
     Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
+    Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")
@@ -36,13 +37,10 @@ public:
     Q_INVOKABLE void reset();
     Q_INVOKABLE bool isModelLoaded() const;
-    Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
-                            float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
+    Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
+        int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
     Q_INVOKABLE void regenerateResponse();
     Q_INVOKABLE void stopGenerating();
-    Q_INVOKABLE void syncThreadCount();
-    Q_INVOKABLE void setThreadCount(int32_t n_threads);
-    Q_INVOKABLE int32_t threadCount();
     Q_INVOKABLE void newPromptResponsePair(const QString &prompt);
 
     QString response() const;
@@ -51,8 +49,16 @@ public:
     void setModelName(const QString &modelName);
     bool isRecalc() const;
 
-    void unload();
-    void reload();
+    void loadDefaultModel();
+    void loadModel(const QString &modelName);
+    void unloadModel();
+    void reloadModel();
+
+    qint64 creationDate() const { return m_creationDate; }
+    bool serialize(QDataStream &stream) const;
+    bool deserialize(QDataStream &stream);
+
+    QList<QString> modelList() const;
 
 Q_SIGNALS:
     void idChanged();
@@ -61,35 +67,39 @@ Q_SIGNALS:
     void isModelLoadedChanged();
     void responseChanged();
     void responseInProgressChanged();
-    void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
-                         float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
+    void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict,
+        int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
+        int32_t n_threads);
     void regenerateResponseRequested();
     void resetResponseRequested();
     void resetContextRequested();
     void modelNameChangeRequested(const QString &modelName);
     void modelNameChanged();
-    void threadCountChanged();
-    void setThreadCountRequested(int32_t threadCount);
     void recalcChanged();
-    void unloadRequested();
-    void reloadRequested(const QString &modelName);
+    void loadDefaultModelRequested();
+    void loadModelRequested(const QString &modelName);
+    void unloadModelRequested();
+    void reloadModelRequested(const QString &modelName);
     void generateNameRequested();
+    void modelListChanged();
 
 private Q_SLOTS:
+    void handleResponseChanged();
     void responseStarted();
     void responseStopped();
     void generatedNameChanged();
     void handleRecalculating();
+    void handleModelNameChanged();
 
 private:
-    ChatLLM *m_llmodel;
     QString m_id;
     QString m_name;
     QString m_userName;
     QString m_savedModelName;
     ChatModel *m_chatModel;
     bool m_responseInProgress;
-    int32_t m_desiredThreadCount;
+    qint64 m_creationDate;
+    ChatLLM *m_llmodel;
 };
 
 #endif // CHAT_H

72
chatlistmodel.cpp (new file)

@@ -0,0 +1,72 @@
+#include "chatlistmodel.h"
+
+#include <QFile>
+#include <QDataStream>
+
+void ChatListModel::removeChatFile(Chat *chat) const
+{
+    QSettings settings;
+    QFileInfo settingsInfo(settings.fileName());
+    QString settingsPath = settingsInfo.absolutePath();
+    QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat");
+    if (!file.exists())
+        return;
+    bool success = file.remove();
+    if (!success)
+        qWarning() << "ERROR: Couldn't remove chat file:" << file.fileName();
+}
+
+void ChatListModel::saveChats() const
+{
+    QSettings settings;
+    QFileInfo settingsInfo(settings.fileName());
+    QString settingsPath = settingsInfo.absolutePath();
+    for (Chat *chat : m_chats) {
+        QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat");
+        bool success = file.open(QIODevice::WriteOnly);
+        if (!success) {
+            qWarning() << "ERROR: Couldn't save chat to file:" << file.fileName();
+            continue;
+        }
+        QDataStream out(&file);
+        if (!chat->serialize(out)) {
+            qWarning() << "ERROR: Couldn't serialize chat to file:" << file.fileName();
+            file.remove();
+        }
+        file.close();
+    }
+}
+
+void ChatListModel::restoreChats()
+{
+    QSettings settings;
+    QFileInfo settingsInfo(settings.fileName());
+    QString settingsPath = settingsInfo.absolutePath();
+    QDir dir(settingsPath);
+    dir.setNameFilters(QStringList() << "gpt4all-*.chat");
+    QStringList fileNames = dir.entryList();
+    beginResetModel();
+    for (QString f : fileNames) {
+        QString filePath = settingsPath + "/" + f;
+        QFile file(filePath);
+        bool success = file.open(QIODevice::ReadOnly);
+        if (!success) {
+            qWarning() << "ERROR: Couldn't restore chat from file:" << file.fileName();
+            continue;
+        }
+        QDataStream in(&file);
+        Chat *chat = new Chat(this);
+        if (!chat->deserialize(in)) {
+            qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
+            file.remove();
+        } else {
+            connect(chat, &Chat::nameChanged, this, &ChatListModel::nameChanged);
+            m_chats.append(chat);
+        }
+        file.close();
+    }
+    std::sort(m_chats.begin(), m_chats.end(), [](const Chat* a, const Chat* b) {
+        return a->creationDate() > b->creationDate();
+    });
+    endResetModel();
+}


@@ -55,7 +55,7 @@ public:
     Q_INVOKABLE void addChat()
     {
-        // Don't add a new chat if the current chat is empty
+        // Don't add a new chat if we already have one
         if (m_newChat)
             return;
@@ -73,13 +73,29 @@ public:
         setCurrentChat(m_newChat);
     }
 
+    void setNewChat(Chat* chat)
+    {
+        // Don't add a new chat if we already have one
+        if (m_newChat)
+            return;
+
+        m_newChat = chat;
+        connect(m_newChat->chatModel(), &ChatModel::countChanged,
+            this, &ChatListModel::newChatCountChanged);
+        connect(m_newChat, &Chat::nameChanged,
+            this, &ChatListModel::nameChanged);
+        setCurrentChat(m_newChat);
+    }
+
     Q_INVOKABLE void removeChat(Chat* chat)
     {
         if (!m_chats.contains(chat)) {
-            qDebug() << "WARNING: Removing chat failed with id" << chat->id();
+            qWarning() << "WARNING: Removing chat failed with id" << chat->id();
             return;
         }
 
+        removeChatFile(chat);
         emit disconnectChat(chat);
         if (chat == m_newChat) {
             m_newChat->disconnect(this);
@@ -115,20 +131,20 @@ public:
     void setCurrentChat(Chat *chat)
     {
         if (!m_chats.contains(chat)) {
-            qDebug() << "ERROR: Setting current chat failed with id" << chat->id();
+            qWarning() << "ERROR: Setting current chat failed with id" << chat->id();
             return;
         }
 
         if (m_currentChat) {
             if (m_currentChat->isModelLoaded())
-                m_currentChat->unload();
+                m_currentChat->unloadModel();
             emit disconnect(m_currentChat);
         }
 
         emit connectChat(chat);
         m_currentChat = chat;
         if (!m_currentChat->isModelLoaded())
-            m_currentChat->reload();
+            m_currentChat->reloadModel();
         emit currentChatChanged();
     }
@@ -138,9 +154,12 @@ public:
         return m_chats.at(index);
     }
 
     int count() const { return m_chats.size(); }
 
+    void removeChatFile(Chat *chat) const;
+    void saveChats() const;
+    void restoreChats();
+
 Q_SIGNALS:
     void countChanged();
     void connectChat(Chat*);


@@ -1,7 +1,7 @@
 #include "chatllm.h"
+#include "chat.h"
 #include "download.h"
 #include "network.h"
-#include "llm.h"
 #include "llmodel/gptj.h"
 #include "llmodel/llamamodel.h"
@@ -32,28 +32,29 @@ static QString modelFilePath(const QString &modelName)
     return QString();
 }
 
-ChatLLM::ChatLLM()
+ChatLLM::ChatLLM(Chat *parent)
     : QObject{nullptr}
     , m_llmodel(nullptr)
     , m_promptResponseTokens(0)
     , m_responseLogits(0)
     , m_isRecalc(false)
+    , m_chat(parent)
 {
     moveToThread(&m_llmThread);
-    connect(&m_llmThread, &QThread::started, this, &ChatLLM::loadModel);
     connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
     connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
-    m_llmThread.setObjectName("llm thread"); // FIXME: Should identify these with chat name
+    connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
+    m_llmThread.setObjectName(m_chat->id());
     m_llmThread.start();
 }
 
-bool ChatLLM::loadModel()
+bool ChatLLM::loadDefaultModel()
 {
-    const QList<QString> models = LLM::globalInstance()->modelList();
+    const QList<QString> models = m_chat->modelList();
     if (models.isEmpty()) {
         // try again when we get a list of models
         connect(Download::globalInstance(), &Download::modelListChanged, this,
-            &ChatLLM::loadModel, Qt::SingleShotConnection);
+            &ChatLLM::loadDefaultModel, Qt::SingleShotConnection);
         return false;
     }
@@ -62,10 +63,10 @@ bool ChatLLM::loadModel()
     QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString();
     if (defaultModel.isEmpty() || !models.contains(defaultModel))
         defaultModel = models.first();
-    return loadModelPrivate(defaultModel);
+    return loadModel(defaultModel);
 }
 
-bool ChatLLM::loadModelPrivate(const QString &modelName)
+bool ChatLLM::loadModel(const QString &modelName)
 {
     if (isModelLoaded() && m_modelName == modelName)
         return true;
@@ -100,12 +101,13 @@ bool ChatLLM::loadModelPrivate(const QString &modelName)
         }
 
         emit isModelLoadedChanged();
-        emit threadCountChanged();
 
         if (isFirstLoad)
             emit sendStartup();
         else
             emit sendModelLoaded();
+    } else {
+        qWarning() << "ERROR: Could not find model at" << filePath;
     }
 
     if (m_llmodel)
@@ -114,19 +116,6 @@ bool ChatLLM::loadModelPrivate(const QString &modelName)
     return m_llmodel;
 }
 
-void ChatLLM::setThreadCount(int32_t n_threads) {
-    if (m_llmodel && m_llmodel->threadCount() != n_threads) {
-        m_llmodel->setThreadCount(n_threads);
-        emit threadCountChanged();
-    }
-}
-
-int32_t ChatLLM::threadCount() {
-    if (!m_llmodel)
-        return 1;
-    return m_llmodel->threadCount();
-}
-
 bool ChatLLM::isModelLoaded() const
 {
     return m_llmodel && m_llmodel->isModelLoaded();
@@ -203,7 +192,7 @@ void ChatLLM::setModelName(const QString &modelName)
 
 void ChatLLM::modelNameChangeRequested(const QString &modelName)
 {
-    if (!loadModelPrivate(modelName))
+    if (!loadModel(modelName))
         qWarning() << "ERROR: Could not load model" << modelName;
 }
@@ -247,8 +236,8 @@ bool ChatLLM::handleRecalculate(bool isRecalc)
     return !m_stopGenerating;
 }
 
-bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
-                     float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
+bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
+    float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads)
 {
     if (!isModelLoaded())
         return false;
@@ -269,6 +258,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
     m_ctx.n_batch = n_batch;
     m_ctx.repeat_penalty = repeat_penalty;
     m_ctx.repeat_last_n = repeat_penalty_tokens;
+    m_llmodel->setThreadCount(n_threads);
 #if defined(DEBUG)
     printf("%s", qPrintable(instructPrompt));
     fflush(stdout);
@@ -288,19 +278,22 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
     return true;
 }
 
-void ChatLLM::unload()
+void ChatLLM::unloadModel()
 {
+    saveState();
     delete m_llmodel;
     m_llmodel = nullptr;
     emit isModelLoadedChanged();
 }
 
-void ChatLLM::reload(const QString &modelName)
+void ChatLLM::reloadModel(const QString &modelName)
 {
-    if (modelName.isEmpty())
-        loadModel();
-    else
-        loadModelPrivate(modelName);
+    if (modelName.isEmpty()) {
+        loadDefaultModel();
+    } else {
+        loadModel(modelName);
+    }
+    restoreState();
 }
 
 void ChatLLM::generateName()
@@ -333,6 +326,11 @@ void ChatLLM::generateName()
     }
 }
 
+void ChatLLM::handleChatIdChanged()
+{
+    m_llmThread.setObjectName(m_chat->id());
+}
+
 bool ChatLLM::handleNamePrompt(int32_t token)
 {
     Q_UNUSED(token);
@@ -354,3 +352,60 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc)
     Q_UNREACHABLE();
     return true;
 }
+
+bool ChatLLM::serialize(QDataStream &stream)
+{
+    stream << response();
+    stream << generatedName();
+    stream << m_promptResponseTokens;
+    stream << m_responseLogits;
+    stream << m_ctx.n_past;
+    stream << quint64(m_ctx.logits.size());
+    stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
+    stream << quint64(m_ctx.tokens.size());
+    stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
+    saveState();
+    stream << m_state;
+    return stream.status() == QDataStream::Ok;
+}
+
+bool ChatLLM::deserialize(QDataStream &stream)
+{
+    QString response;
+    stream >> response;
+    m_response = response.toStdString();
+    QString nameResponse;
+    stream >> nameResponse;
+    m_nameResponse = nameResponse.toStdString();
+    stream >> m_promptResponseTokens;
+    stream >> m_responseLogits;
+    stream >> m_ctx.n_past;
+    quint64 logitsSize;
+    stream >> logitsSize;
+    m_ctx.logits.resize(logitsSize);
+    stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
+    quint64 tokensSize;
+    stream >> tokensSize;
+    m_ctx.tokens.resize(tokensSize);
+    stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
+    stream >> m_state;
+    return stream.status() == QDataStream::Ok;
+}
+
+void ChatLLM::saveState()
+{
+    if (!isModelLoaded())
+        return;
+
+    const size_t stateSize = m_llmodel->stateSize();
+    m_state.resize(stateSize);
+    m_llmodel->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
+}
+
+void ChatLLM::restoreState()
+{
+    if (!isModelLoaded())
+        return;
+
+    m_llmodel->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
+}


@@ -6,18 +6,18 @@
 #include "llmodel/llmodel.h"
 
+class Chat;
 class ChatLLM : public QObject
 {
     Q_OBJECT
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
     Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
-    Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
     Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
     Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)
 
 public:
-    ChatLLM();
+    ChatLLM(Chat *parent);
 
     bool isModelLoaded() const;
     void regenerateResponse();
@@ -25,8 +25,6 @@ public:
     void resetContext();
     void stopGenerating() { m_stopGenerating = true; }
-    void setThreadCount(int32_t n_threads);
-    int32_t threadCount();
 
     QString response() const;
     QString modelName() const;
@@ -37,14 +35,20 @@ public:
     QString generatedName() const { return QString::fromStdString(m_nameResponse); }
 
+    bool serialize(QDataStream &stream);
+    bool deserialize(QDataStream &stream);
+
 public Q_SLOTS:
-    bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
-                float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
-    bool loadModel();
+    bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
+        int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
+        int32_t n_threads);
+    bool loadDefaultModel();
+    bool loadModel(const QString &modelName);
     void modelNameChangeRequested(const QString &modelName);
-    void unload();
-    void reload(const QString &modelName);
+    void unloadModel();
+    void reloadModel(const QString &modelName);
     void generateName();
+    void handleChatIdChanged();
 
 Q_SIGNALS:
     void isModelLoadedChanged();
@@ -52,22 +56,23 @@ Q_SIGNALS:
     void responseStarted();
     void responseStopped();
     void modelNameChanged();
-    void threadCountChanged();
     void recalcChanged();
     void sendStartup();
     void sendModelLoaded();
     void sendResetContext();
     void generatedNameChanged();
+    void stateChanged();
 
 private:
     void resetContextPrivate();
-    bool loadModelPrivate(const QString &modelName);
     bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
     bool handleRecalculate(bool isRecalc);
     bool handleNamePrompt(int32_t token);
     bool handleNameResponse(int32_t token, const std::string &response);
     bool handleNameRecalculate(bool isRecalc);
+    void saveState();
+    void restoreState();
 
 private:
     LLModel::PromptContext m_ctx;
@@ -77,6 +82,8 @@ private:
     quint32 m_promptResponseTokens;
     quint32 m_responseLogits;
     QString m_modelName;
+    Chat *m_chat;
+    QByteArray m_state;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
     bool m_isRecalc;


@@ -3,6 +3,7 @@
 #include <QAbstractListModel>
 #include <QtQml>
+#include <QDataStream>
 
 struct ChatItem
 {
@@ -209,6 +210,46 @@ public:
 
     int count() const { return m_chatItems.size(); }
 
+    bool serialize(QDataStream &stream) const
+    {
+        stream << count();
+        for (auto c : m_chatItems) {
+            stream << c.id;
+            stream << c.name;
+            stream << c.value;
+            stream << c.prompt;
+            stream << c.newResponse;
+            stream << c.currentResponse;
+            stream << c.stopped;
+            stream << c.thumbsUpState;
+            stream << c.thumbsDownState;
+        }
+        return stream.status() == QDataStream::Ok;
+    }
+
+    bool deserialize(QDataStream &stream)
+    {
+        int size;
+        stream >> size;
+        for (int i = 0; i < size; ++i) {
+            ChatItem c;
+            stream >> c.id;
+            stream >> c.name;
+            stream >> c.value;
+            stream >> c.prompt;
+            stream >> c.newResponse;
+            stream >> c.currentResponse;
+            stream >> c.stopped;
+            stream >> c.thumbsUpState;
+            stream >> c.thumbsDownState;
+            beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
+            m_chatItems.append(c);
+            endInsertRows();
+        }
+        emit countChanged();
+        return stream.status() == QDataStream::Ok;
+    }
+
 Q_SIGNALS:
     void countChanged();

96
llm.cpp

@@ -20,77 +20,22 @@ LLM *LLM::globalInstance()
 LLM::LLM()
     : QObject{nullptr}
     , m_chatListModel(new ChatListModel(this))
+    , m_threadCount(std::min(4, (int32_t) std::thread::hardware_concurrency()))
 {
-    // Should be in the same thread
-    connect(Download::globalInstance(), &Download::modelListChanged,
-        this, &LLM::modelListChanged, Qt::DirectConnection);
-    connect(m_chatListModel, &ChatListModel::connectChat,
-        this, &LLM::connectChat, Qt::DirectConnection);
-    connect(m_chatListModel, &ChatListModel::disconnectChat,
-        this, &LLM::disconnectChat, Qt::DirectConnection);
-
-    if (!m_chatListModel->count())
+    connect(QCoreApplication::instance(), &QCoreApplication::aboutToQuit,
+        this, &LLM::aboutToQuit);
+
+    m_chatListModel->restoreChats();
+    if (m_chatListModel->count()) {
+        Chat *firstChat = m_chatListModel->get(0);
+        if (firstChat->chatModel()->count() < 2)
+            m_chatListModel->setNewChat(firstChat);
+        else
+            m_chatListModel->setCurrentChat(firstChat);
+    } else
         m_chatListModel->addChat();
 }
 
-QList<QString> LLM::modelList() const
-{
-    Q_ASSERT(m_chatListModel->currentChat());
-    const Chat *currentChat = m_chatListModel->currentChat();
-    // Build a model list from exepath and from the localpath
-    QList<QString> list;
-
-    QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
-    QString localPath = Download::globalInstance()->downloadLocalModelsPath();
-
-    {
-        QDir dir(exePath);
-        dir.setNameFilters(QStringList() << "ggml-*.bin");
-        QStringList fileNames = dir.entryList();
-        for (QString f : fileNames) {
-            QString filePath = exePath + f;
-            QFileInfo info(filePath);
-            QString name = info.completeBaseName().remove(0, 5);
-            if (info.exists()) {
-                if (name == currentChat->modelName())
-                    list.prepend(name);
-                else
-                    list.append(name);
-            }
-        }
-    }
-
-    if (localPath != exePath) {
-        QDir dir(localPath);
-        dir.setNameFilters(QStringList() << "ggml-*.bin");
-        QStringList fileNames = dir.entryList();
-        for (QString f : fileNames) {
-            QString filePath = localPath + f;
-            QFileInfo info(filePath);
-            QString name = info.completeBaseName().remove(0, 5);
-            if (info.exists() && !list.contains(name)) { // don't allow duplicates
-                if (name == currentChat->modelName())
-                    list.prepend(name);
-                else
-                    list.append(name);
-            }
-        }
-    }
-
-    if (list.isEmpty()) {
-        if (exePath != localPath) {
-            qWarning() << "ERROR: Could not find any applicable models in"
-                       << exePath << "nor" << localPath;
-        } else {
-            qWarning() << "ERROR: Could not find any applicable models in"
-                       << exePath;
-        }
-        return QList<QString>();
-    }
-    return list;
-}
-
 bool LLM::checkForUpdates() const
 {
     Network::globalInstance()->sendCheckForUpdates();
@@ -113,21 +58,20 @@ bool LLM::checkForUpdates() const
     return QProcess::startDetached(fileName);
 }
 
-bool LLM::isRecalc() const
+int32_t LLM::threadCount() const
 {
-    Q_ASSERT(m_chatListModel->currentChat());
-    return m_chatListModel->currentChat()->isRecalc();
+    return m_threadCount;
 }
 
-void LLM::connectChat(Chat *chat)
+void LLM::setThreadCount(int32_t n_threads)
 {
-    // Should be in the same thread
-    connect(chat, &Chat::modelNameChanged, this, &LLM::modelListChanged, Qt::DirectConnection);
-    connect(chat, &Chat::recalcChanged, this, &LLM::recalcChanged, Qt::DirectConnection);
-    connect(chat, &Chat::responseChanged, this, &LLM::responseChanged, Qt::DirectConnection);
+    if (n_threads <= 0)
+        n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    m_threadCount = n_threads;
+    emit threadCountChanged();
 }
 
-void LLM::disconnectChat(Chat *chat)
+void LLM::aboutToQuit()
 {
-    chat->disconnect(this);
+    m_chatListModel->saveChats();
 }

16
llm.h

@@ -3,37 +3,33 @@
 #include <QObject>
 
-#include "chat.h"
 #include "chatlistmodel.h"
 
 class LLM : public QObject
 {
     Q_OBJECT
-    Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
-    Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
     Q_PROPERTY(ChatListModel *chatListModel READ chatListModel NOTIFY chatListModelChanged)
+    Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
 
 public:
     static LLM *globalInstance();
 
-    QList<QString> modelList() const;
-    bool isRecalc() const;
     ChatListModel *chatListModel() const { return m_chatListModel; }
+    int32_t threadCount() const;
+    void setThreadCount(int32_t n_threads);
 
     Q_INVOKABLE bool checkForUpdates() const;
 
 Q_SIGNALS:
-    void modelListChanged();
-    void recalcChanged();
-    void responseChanged();
     void chatListModelChanged();
+    void threadCountChanged();
 
 private Q_SLOTS:
-    void connectChat(Chat*);
-    void disconnectChat(Chat*);
+    void aboutToQuit();
 
 private:
     ChatListModel *m_chatListModel;
+    int32_t m_threadCount;
 
 private:
     explicit LLM();


@@ -67,6 +67,7 @@ int32_t LLamaModel::threadCount() {
 
 LLamaModel::~LLamaModel()
 {
+    llama_free(d_ptr->ctx);
 }
 
 bool LLamaModel::isModelLoaded() const
@@ -74,6 +75,21 @@ bool LLamaModel::isModelLoaded() const
     return d_ptr->modelLoaded;
 }
 
+size_t LLamaModel::stateSize() const
+{
+    return llama_get_state_size(d_ptr->ctx);
+}
+
+size_t LLamaModel::saveState(uint8_t *dest) const
+{
+    return llama_copy_state_data(d_ptr->ctx, dest);
+}
+
+size_t LLamaModel::restoreState(const uint8_t *src)
+{
+    return llama_set_state_data(d_ptr->ctx, src);
+}
+
 void LLamaModel::prompt(const std::string &prompt,
         std::function<bool(int32_t)> promptCallback,
         std::function<bool(int32_t, const std::string&)> responseCallback,


@@ -14,6 +14,9 @@ public:
     bool loadModel(const std::string &modelPath) override;
     bool isModelLoaded() const override;
+    size_t stateSize() const override;
+    size_t saveState(uint8_t *dest) const override;
+    size_t restoreState(const uint8_t *src) override;
     void prompt(const std::string &prompt,
         std::function<bool(int32_t)> promptCallback,
         std::function<bool(int32_t, const std::string&)> responseCallback,


@@ -12,6 +12,9 @@ public:
     virtual bool loadModel(const std::string &modelPath) = 0;
     virtual bool isModelLoaded() const = 0;
+    virtual size_t stateSize() const { return 0; }
+    virtual size_t saveState(uint8_t *dest) const { return 0; }
+    virtual size_t restoreState(const uint8_t *src) { return 0; }
     struct PromptContext {
         std::vector<float> logits;   // logits of current context
         std::vector<int32_t> tokens; // current tokens in the context window


@@ -48,6 +48,24 @@ bool llmodel_isModelLoaded(llmodel_model model)
     return wrapper->llModel->isModelLoaded();
 }
 
+uint64_t llmodel_get_state_size(llmodel_model model)
+{
+    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
+    return wrapper->llModel->stateSize();
+}
+
+uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
+{
+    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
+    return wrapper->llModel->saveState(dest);
+}
+
+uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
+{
+    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
+    return wrapper->llModel->restoreState(src);
+}
+
 // Wrapper functions for the C callbacks
 bool prompt_wrapper(int32_t token_id, void *user_data) {
     llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);


@@ -98,6 +98,32 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path);
  */
 bool llmodel_isModelLoaded(llmodel_model model);
 
+/**
+ * Get the size of the internal state of the model.
+ * NOTE: This state data is specific to the type of model you have created.
+ * @param model A pointer to the llmodel_model instance.
+ * @return the size in bytes of the internal state of the model
+ */
+uint64_t llmodel_get_state_size(llmodel_model model);
+
+/**
+ * Saves the internal state of the model to the specified destination address.
+ * NOTE: This state data is specific to the type of model you have created.
+ * @param model A pointer to the llmodel_model instance.
+ * @param dest A pointer to the destination.
+ * @return the number of bytes copied
+ */
+uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest);
+
+/**
+ * Restores the internal state of the model using data from the specified address.
+ * NOTE: This state data is specific to the type of model you have created.
+ * @param model A pointer to the llmodel_model instance.
+ * @param src A pointer to the src.
+ * @return the number of bytes read
+ */
+uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
+
 /**
  * Generate a response using the model.
  * @param model A pointer to the llmodel_model instance.
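Only the llama backend implements these hooks in this commit (the LLModel base-class defaults shown earlier return 0), which is why the commit message notes that context is not restored for gpt-j models. A rough usage sketch against the C API declared above; the helper names are hypothetical, and constructing and loading the model handle is assumed to happen elsewhere:

// Hypothetical helpers around the state save/restore C API declared above.
#include <cstdint>
#include <vector>
#include "llmodel_c.h"

// Capture the current context state of an already-loaded model.
std::vector<uint8_t> capture_state(llmodel_model model)
{
    std::vector<uint8_t> buffer(llmodel_get_state_size(model));
    const uint64_t written = llmodel_save_state_data(model, buffer.data());
    buffer.resize(written);    // the backend may write less than the reported maximum
    return buffer;
}

// Restore a previously captured state into a model of the same type.
void restore_state(llmodel_model model, const std::vector<uint8_t> &buffer)
{
    llmodel_restore_state_data(model, buffer.data());
}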


@@ -65,7 +65,7 @@ Window {
         }
 
         // check for any current models and if not, open download dialog
-        if (LLM.modelList.length === 0 && !firstStartDialog.opened) {
+        if (currentChat.modelList.length === 0 && !firstStartDialog.opened) {
             downloadNewModels.open();
             return;
         }
@@ -125,7 +125,7 @@ Window {
         anchors.horizontalCenter: parent.horizontalCenter
         font.pixelSize: theme.fontSizeLarge
         spacing: 0
-        model: LLM.modelList
+        model: currentChat.modelList
         Accessible.role: Accessible.ComboBox
         Accessible.name: qsTr("ComboBox for displaying/picking the current model")
         Accessible.description: qsTr("Use this for picking the current model to use; the first item is the current model")
@@ -367,9 +367,9 @@ Window {
         text: qsTr("Recalculating context.")
 
         Connections {
-            target: LLM
+            target: currentChat
             function onRecalcChanged() {
-                if (LLM.isRecalc)
+                if (currentChat.isRecalc)
                     recalcPopup.open()
                 else
                     recalcPopup.close()
@@ -422,9 +422,6 @@ Window {
             var item = chatModel.get(i)
             var string = item.name;
             var isResponse = item.name === qsTr("Response: ")
-            if (item.currentResponse)
-                string += currentChat.response
-            else
-                string += chatModel.get(i).value
+            string += chatModel.get(i).value
             if (isResponse && item.stopped)
                 string += " <stopped>"
@@ -440,9 +437,6 @@ Window {
             var item = chatModel.get(i)
             var isResponse = item.name === qsTr("Response: ")
             str += "{\"content\": ";
-            if (item.currentResponse)
-                str += JSON.stringify(currentChat.response)
-            else
-                str += JSON.stringify(item.value)
+            str += JSON.stringify(item.value)
             str += ", \"role\": \"" + (isResponse ? "assistant" : "user") + "\"";
             if (isResponse && item.thumbsUpState !== item.thumbsDownState)
@@ -572,14 +566,14 @@ Window {
             Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model")
 
             delegate: TextArea {
-                text: currentResponse ? currentChat.response : (value ? value : "")
+                text: value
                 width: listView.width
                 color: theme.textColor
                 wrapMode: Text.WordWrap
                 focus: false
                 readOnly: true
                 font.pixelSize: theme.fontSizeLarge
-                cursorVisible: currentResponse ? (currentChat.response !== "" ? currentChat.responseInProgress : false) : false
+                cursorVisible: currentResponse ? currentChat.responseInProgress : false
                 cursorPosition: text.length
                 background: Rectangle {
                     color: name === qsTr("Response: ") ? theme.backgroundLighter : theme.backgroundLight
@@ -599,8 +593,8 @@ Window {
                 anchors.leftMargin: 90
                 anchors.top: parent.top
                 anchors.topMargin: 5
-                visible: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress
-                running: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress
+                visible: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
+                running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
 
                 Accessible.role: Accessible.Animation
                 Accessible.name: qsTr("Busy indicator")
@@ -631,7 +625,7 @@ Window {
                     window.height / 2 - height / 2)
                 x: globalPoint.x
                 y: globalPoint.y
-                property string text: currentResponse ? currentChat.response : (value ? value : "")
+                property string text: value
                 response: newResponse === undefined || newResponse === "" ? text : newResponse
                 onAccepted: {
                     var responseHasChanged = response !== text && response !== newResponse
@@ -711,7 +705,7 @@ Window {
             property bool isAutoScrolling: false
 
             Connections {
-                target: LLM
+                target: currentChat
                 function onResponseChanged() {
                     if (listView.shouldAutoScroll) {
                         listView.isAutoScrolling = true
@@ -762,7 +756,6 @@ Window {
                     if (listElement.name === qsTr("Response: ")) {
                         chatModel.updateCurrentResponse(index, true);
                         chatModel.updateStopped(index, false);
-                        chatModel.updateValue(index, currentChat.response);
                         chatModel.updateThumbsUpState(index, false);
                         chatModel.updateThumbsDownState(index, false);
                         chatModel.updateNewResponse(index, "");
@@ -840,7 +833,6 @@ Window {
                     var index = Math.max(0, chatModel.count - 1);
                     var listElement = chatModel.get(index);
                     chatModel.updateCurrentResponse(index, false);
-                    chatModel.updateValue(index, currentChat.response);
                 }
                 currentChat.newPromptResponsePair(textInput.text);
                 currentChat.prompt(textInput.text, settingsDialog.promptTemplate,


@@ -458,7 +458,6 @@ void Network::handleIpifyFinished()
 
 void Network::handleMixpanelFinished()
 {
-    Q_ASSERT(m_usageStatsActive);
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
     if (!reply)
         return;


@@ -83,6 +83,7 @@ Drawer {
                 opacity: 0.9
                 property bool isCurrent: LLM.chatListModel.currentChat === LLM.chatListModel.get(index)
                 property bool trashQuestionDisplayed: false
+                z: isCurrent ? 199 : 1
                 color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter
                 border.width: isCurrent
                 border.color: chatName.readOnly ? theme.assistantColor : theme.userColor
@@ -112,6 +113,11 @@ Drawer {
                         color: "transparent"
                     }
                     onEditingFinished: {
+                        // Work around a bug in qml where we're losing focus when the whole window
+                        // goes out of focus even though this textfield should be marked as not
+                        // having focus
+                        if (chatName.readOnly)
+                            return;
                         changeName();
                         Network.sendRenameChat()
                     }
@@ -188,6 +194,7 @@ Drawer {
                     visible: isCurrent && trashQuestionDisplayed
                     opacity: 1.0
                     radius: 10
+                    z: 200
                     Row {
                         spacing: 10
                         Button {


@@ -12,7 +12,7 @@ Dialog {
     id: modelDownloaderDialog
     modal: true
     opacity: 0.9
-    closePolicy: LLM.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
+    closePolicy: LLM.chatListModel.currentChat.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
     background: Rectangle {
         anchors.fill: parent
         anchors.margins: -20