WIP: need to run rr on a real computer since this bug is confusing

This commit is contained in:
Jared Van Bortel 2025-03-17 13:11:55 -04:00
parent 371971e6ac
commit 8294a5cd58
18 changed files with 500 additions and 276 deletions

1
deps/CMakeLists.txt vendored
View File

@ -8,7 +8,6 @@ set(QCORO_BUILD_EXAMPLES OFF)
set(QCORO_WITH_QTDBUS OFF) set(QCORO_WITH_QTDBUS OFF)
set(QCORO_WITH_QTWEBSOCKETS OFF) set(QCORO_WITH_QTWEBSOCKETS OFF)
set(QCORO_WITH_QTQUICK OFF) set(QCORO_WITH_QTQUICK OFF)
set(QCORO_WITH_QML OFF)
set(QCORO_WITH_QTTEST OFF) set(QCORO_WITH_QTTEST OFF)
add_subdirectory(qcoro) add_subdirectory(qcoro)

View File

@ -249,7 +249,7 @@ qt_add_executable(chat
src/localdocs.cpp src/localdocs.h src/localdocs.cpp src/localdocs.h
src/localdocsmodel.cpp src/localdocsmodel.h src/localdocsmodel.cpp src/localdocsmodel.h
src/logger.cpp src/logger.h src/logger.cpp src/logger.h
src/main.cpp src/main.cpp src/main.h
src/modellist.cpp src/modellist.h src/modellist.cpp src/modellist.h
src/mysettings.cpp src/mysettings.h src/mysettings.cpp src/mysettings.h
src/network.cpp src/network.h src/network.cpp src/network.h
@ -467,7 +467,7 @@ else()
endif() endif()
target_link_libraries(chat PRIVATE target_link_libraries(chat PRIVATE
Boost::describe Boost::json Boost::system Boost::describe Boost::json Boost::system
QCoro6::Core QCoro6::Network QCoro6::Core QCoro6::Network QCoro6::Qml
QXlsx QXlsx
SingleApplication SingleApplication
duckx::duckx duckx::duckx

View File

@ -48,92 +48,41 @@ ColumnLayout {
bottomPadding: 20 bottomPadding: 20
property int childWidth: 330 * theme.fontScale property int childWidth: 330 * theme.fontScale
property int childHeight: 400 + 166 * theme.fontScale property int childHeight: 400 + 166 * theme.fontScale
RemoteModelCard { Repeater {
width: parent.childWidth model: BuiltinProviderList
height: parent.childHeight delegate: RemoteModelCard {
providerBaseUrl: "https://api.groq.com/openai/v1/" required property var data
providerName: qsTr("Groq") width: parent.childWidth
providerImage: "qrc:/gpt4all/icons/groq.svg" height: parent.childHeight
providerDesc: qsTr('Groq offers a high-performance AI inference engine designed for low-latency and efficient processing. Optimized for real-time applications, Groqs technology is ideal for users who need fast responses from open large language models and other AI workloads.<br><br>Get your API key: <a href="https://console.groq.com/keys">https://groq.com/</a>') provider: data
modelWhitelist: [ providerBaseUrl: data.baseUrl
// last updated 2025-02-24 providerName: data.name
"deepseek-r1-distill-llama-70b", providerImage: data.icon
"deepseek-r1-distill-qwen-32b", providerDesc: ({
"gemma2-9b-it", '{20f963dc-1f99-441e-ad80-f30a0a06bcac}': qsTr(
"llama-3.1-8b-instant", 'Groq offers a high-performance AI inference engine designed for low-latency and ' +
"llama-3.2-1b-preview", 'efficient processing. Optimized for real-time applications, Groqs technology is ideal ' +
"llama-3.2-3b-preview", 'for users who need fast responses from open large language models and other AI ' +
"llama-3.3-70b-specdec", 'workloads.<br><br>Get your API key: ' +
"llama-3.3-70b-versatile", '<a href="https://console.groq.com/keys">https://groq.com/</a>'
"llama3-70b-8192", ),
"llama3-8b-8192", '{6f874c3a-f1ad-47f7-9129-755c5477146c}': qsTr(
"mixtral-8x7b-32768", 'OpenAI provides access to advanced AI models, including GPT-4 supporting a wide range ' +
"qwen-2.5-32b", 'of applications, from conversational AI to content generation and code completion.' +
"qwen-2.5-coder-32b", '<br><br>Get your API key: ' +
] '<a href="https://platform.openai.com/signup">https://openai.com/</a>'
} ),
RemoteModelCard { '{7ae617b3-c0b2-4d2c-9ff2-bc3f049494cc}': qsTr(
width: parent.childWidth 'Mistral AI specializes in efficient, open-weight language models optimized for various ' +
height: parent.childHeight 'natural language processing tasks. Their models are designed for flexibility and ' +
providerBaseUrl: "https://api.openai.com/v1/" 'performance, making them a solid option for applications requiring scalable AI ' +
providerName: qsTr("OpenAI") 'solutions.<br><br>Get your API key: <a href="https://mistral.ai/">https://mistral.ai/</a>'
providerImage: "qrc:/gpt4all/icons/openai.svg" ),
providerDesc: qsTr('OpenAI provides access to advanced AI models, including GPT-4 supporting a wide range of applications, from conversational AI to content generation and code completion.<br><br>Get your API key: <a href="https://platform.openai.com/signup">https://openai.com/</a>') })[data.id.toString()]
modelWhitelist: [ modelWhitelist: data.modelWhitelist
// last updated 2025-02-24 }
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-turbo",
"gpt-4o",
]
}
RemoteModelCard {
width: parent.childWidth
height: parent.childHeight
providerBaseUrl: "https://api.mistral.ai/v1/"
providerName: qsTr("Mistral")
providerImage: "qrc:/gpt4all/icons/mistral.svg"
providerDesc: qsTr('Mistral AI specializes in efficient, open-weight language models optimized for various natural language processing tasks. Their models are designed for flexibility and performance, making them a solid option for applications requiring scalable AI solutions.<br><br>Get your API key: <a href="https://mistral.ai/">https://mistral.ai/</a>')
modelWhitelist: [
// last updated 2025-02-24
"codestral-2405",
"codestral-2411-rc5",
"codestral-2412",
"codestral-2501",
"codestral-latest",
"codestral-mamba-2407",
"codestral-mamba-latest",
"ministral-3b-2410",
"ministral-3b-latest",
"ministral-8b-2410",
"ministral-8b-latest",
"mistral-large-2402",
"mistral-large-2407",
"mistral-large-2411",
"mistral-large-latest",
"mistral-medium-2312",
"mistral-medium-latest",
"mistral-saba-2502",
"mistral-saba-latest",
"mistral-small-2312",
"mistral-small-2402",
"mistral-small-2409",
"mistral-small-2501",
"mistral-small-latest",
"mistral-tiny-2312",
"mistral-tiny-2407",
"mistral-tiny-latest",
"open-codestral-mamba",
"open-mistral-7b",
"open-mistral-nemo",
"open-mistral-nemo-2407",
"open-mixtral-8x22b",
"open-mixtral-8x22b-2404",
"open-mixtral-8x7b",
]
} }
/*
RemoteModelCard { RemoteModelCard {
width: parent.childWidth width: parent.childWidth
height: parent.childHeight height: parent.childHeight
@ -142,6 +91,7 @@ ColumnLayout {
providerImage: "qrc:/gpt4all/icons/antenna_3.svg" providerImage: "qrc:/gpt4all/icons/antenna_3.svg"
providerDesc: qsTr("The custom provider option allows users to connect their own OpenAI-compatible AI models or third-party inference services. This is useful for organizations with proprietary models or those leveraging niche AI providers not listed here.") providerDesc: qsTr("The custom provider option allows users to connect their own OpenAI-compatible AI models or third-party inference services. This is useful for organizations with proprietary models or those leveraging niche AI providers not listed here.")
} }
*/
} }
} }
} }

View File

@ -18,6 +18,7 @@ import localdocs
Rectangle { Rectangle {
required property var provider
property alias providerName: providerNameLabel.text property alias providerName: providerNameLabel.text
property alias providerImage: myimage.source property alias providerImage: myimage.source
property alias providerDesc: providerDescLabel.text property alias providerDesc: providerDescLabel.text
@ -100,18 +101,23 @@ Rectangle {
Layout.fillWidth: true Layout.fillWidth: true
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
wrapMode: Text.WrapAnywhere wrapMode: Text.WrapAnywhere
echoMode: TextField.Password
function showError() { function showError() {
messageToast.show(qsTr("ERROR: $API_KEY is empty.")); messageToast.show(qsTr("ERROR: $API_KEY is empty."));
apiKeyField.placeholderTextColor = theme.textErrorColor; apiKeyField.placeholderTextColor = theme.textErrorColor;
} }
Component.onCompleted: { text = provider.apiKey; }
onTextChanged: { onTextChanged: {
apiKeyField.placeholderTextColor = theme.mutedTextColor; apiKeyField.placeholderTextColor = theme.mutedTextColor;
if (!providerIsCustom) { if (!providerIsCustom && provider.setApiKeyQml(text)) {
let models = ModelList.remoteModelList(apiKeyField.text, providerBaseUrl); provider.listModelsQml().then(modelList => {
if (modelWhitelist !== null) if (modelList !== null) {
models = models.filter(m => modelWhitelist.includes(m)); if (modelWhitelist !== null)
myModelList.model = models; models = models.filter(m => modelWhitelist.includes(m));
myModelList.currentIndex = -1; myModelList.model = models;
myModelList.currentIndex = -1;
}
});
} }
} }
placeholderText: qsTr("enter $API_KEY") placeholderText: qsTr("enter $API_KEY")

View File

@ -38,7 +38,8 @@ protected:
}; };
class OllamaProvider : public QObject, public virtual ModelProvider { class OllamaProvider : public QObject, public virtual ModelProvider {
Q_OBJECT Q_GADGET
Q_PROPERTY(QUuid id READ id CONSTANT)
public: public:
~OllamaProvider() noexcept override = 0; ~OllamaProvider() noexcept override = 0;
@ -63,6 +64,8 @@ public:
class OllamaProviderCustom final : public OllamaProvider, public ModelProviderCustom { class OllamaProviderCustom final : public OllamaProvider, public ModelProviderCustom {
Q_OBJECT Q_OBJECT
Q_PROPERTY(QString name READ name NOTIFY nameChanged )
Q_PROPERTY(QUrl baseUrl READ baseUrl NOTIFY baseUrlChanged)
public: public:
/// Load an existing OllamaProvider from disk. /// Load an existing OllamaProvider from disk.

View File

@ -1,10 +1,13 @@
#include "llmodel_openai.h" #include "llmodel_openai.h"
#include "main.h"
#include "mysettings.h" #include "mysettings.h"
#include "utils.h" #include "utils.h"
#include <QCoro/QCoroAsyncGenerator> // IWYU pragma: keep #include <QCoro/QCoroAsyncGenerator> // IWYU pragma: keep
#include <QCoro/QCoroNetworkReply> // IWYU pragma: keep #include <QCoro/QCoroNetworkReply> // IWYU pragma: keep
#include <QCoro/QCoroTask> // IWYU pragma: keep
#include <boost/json.hpp> // IWYU pragma: keep
#include <fmt/format.h> #include <fmt/format.h>
#include <gpt4all-backend/formatters.h> // IWYU pragma: keep #include <gpt4all-backend/formatters.h> // IWYU pragma: keep
#include <gpt4all-backend/rest.h> #include <gpt4all-backend/rest.h>
@ -36,6 +39,7 @@
#include <stdexcept> #include <stdexcept>
#include <utility> #include <utility>
namespace json = boost::json;
using namespace Qt::Literals::StringLiterals; using namespace Qt::Literals::StringLiterals;
//#define DEBUG //#define DEBUG
@ -86,6 +90,14 @@ auto OpenaiGenerationParams::toMap() const -> QMap<QLatin1StringView, QVariant>
OpenaiProvider::~OpenaiProvider() noexcept = default; OpenaiProvider::~OpenaiProvider() noexcept = default;
Q_INVOKABLE bool OpenaiProvider::setApiKeyQml(QString value)
{
auto res = setApiKey(std::move(value));
if (!res)
qWarning().noquote() << "setApiKey failed:" << res.error().errorString();
return bool(res);
}
auto OpenaiProvider::supportedGenerationParams() const -> QSet<GenerationParam> auto OpenaiProvider::supportedGenerationParams() const -> QSet<GenerationParam>
{ {
using enum GenerationParam; using enum GenerationParam;
@ -96,9 +108,61 @@ auto OpenaiProvider::makeGenerationParams(const QMap<GenerationParam, QVariant>
-> OpenaiGenerationParams * -> OpenaiGenerationParams *
{ return new OpenaiGenerationParams(values); } { return new OpenaiGenerationParams(values); }
OpenaiProviderBuiltin::OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl baseUrl) auto OpenaiProvider::listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>>
{
auto *nam = networkAccessManager();
QNetworkRequest request(m_baseUrl.resolved(u"models"_s));
request.setHeader (QNetworkRequest::ContentTypeHeader, "application/json"_ba);
request.setRawHeader("Authorization"_ba, fmt::format("Bearer {}", m_apiKey).c_str());
std::unique_ptr<QNetworkReply> reply(nam->get(request));
QRestReply restReply(reply.get());
if (reply->error())
co_return std::unexpected(&restReply);
QStringList models;
try {
json::stream_parser parser;
auto coroReply = qCoro(*reply);
for (;;) {
auto chunk = co_await coroReply.readAll();
if (!restReply.isSuccess())
co_return std::unexpected(&restReply);
if (chunk.isEmpty()) {
Q_ASSERT(reply->atEnd());
break;
}
parser.write(chunk.data(), chunk.size());
}
parser.finish();
auto resp = parser.release().as_object();
for (auto &entry : resp.at("data").as_array())
models << json::value_to<QString>(entry.at("id"));
} catch (const boost::system::system_error &e) {
co_return std::unexpected(e);
}
co_return models;
}
QCoro::QmlTask OpenaiProvider::listModelsQml()
{
return [this]() -> QCoro::Task<QVariant> {
auto result = co_await listModels();
if (result)
co_return *result;
qWarning().noquote() << "OpenaiProvider::listModels failed:" << result.error().errorString();
co_return QVariant::fromValue(nullptr);
}();
}
OpenaiProviderBuiltin::OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl icon, QUrl baseUrl,
QStringList modelWhitelist)
: ModelProvider(std::move(id), std::move(name), std::move(baseUrl)) : ModelProvider(std::move(id), std::move(name), std::move(baseUrl))
, ModelProviderBuiltin(std::move(icon))
, ModelProviderMutable(store) , ModelProviderMutable(store)
, m_modelWhitelist(std::move(modelWhitelist))
{ {
auto res = m_store->acquire(m_id); auto res = m_store->acquire(m_id);
if (!res) if (!res)

View File

@ -4,9 +4,13 @@
#include "llmodel_description.h" #include "llmodel_description.h"
#include "llmodel_provider.h" #include "llmodel_provider.h"
#include <QCoro/QCoroQmlTask> // IWYU pragma: keep
#include <gpt4all-backend/ollama-client.h>
#include <QLatin1StringView> // IWYU pragma: keep #include <QLatin1StringView> // IWYU pragma: keep
#include <QObject> // IWYU pragma: keep #include <QObject> // IWYU pragma: keep
#include <QString> #include <QString>
#include <QStringList> // IWYU pragma: keep
#include <QUrl> #include <QUrl>
#include <QVariant> #include <QVariant>
#include <QtTypes> // IWYU pragma: keep #include <QtTypes> // IWYU pragma: keep
@ -17,6 +21,7 @@
class QNetworkAccessManager; class QNetworkAccessManager;
template <typename Key, typename T> class QMap; template <typename Key, typename T> class QMap;
template <typename T> class QSet; template <typename T> class QSet;
namespace QCoro { template <typename T> class Task; }
namespace gpt4all::ui { namespace gpt4all::ui {
@ -42,7 +47,8 @@ protected:
class OpenaiProvider : public QObject, public virtual ModelProvider { class OpenaiProvider : public QObject, public virtual ModelProvider {
Q_OBJECT Q_OBJECT
Q_PROPERTY(QString apiKey READ apiKey WRITE setApiKey NOTIFY apiKeyChanged) Q_PROPERTY(QUuid id READ id CONSTANT )
Q_PROPERTY(QString apiKey READ apiKey NOTIFY apiKeyChanged)
protected: protected:
explicit OpenaiProvider() = default; explicit OpenaiProvider() = default;
@ -57,11 +63,15 @@ public:
[[nodiscard]] const QString &apiKey() const { return m_apiKey; } [[nodiscard]] const QString &apiKey() const { return m_apiKey; }
virtual void setApiKey(QString value) = 0; [[nodiscard]] virtual DataStoreResult<> setApiKey(QString value) = 0;
Q_INVOKABLE bool setApiKeyQml(QString value);
auto supportedGenerationParams() const -> QSet<GenerationParam> override; auto supportedGenerationParams() const -> QSet<GenerationParam> override;
auto makeGenerationParams(const QMap<GenerationParam, QVariant> &values) const -> OpenaiGenerationParams * override; auto makeGenerationParams(const QMap<GenerationParam, QVariant> &values) const -> OpenaiGenerationParams * override;
auto listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>>;
Q_INVOKABLE QCoro::QmlTask listModelsQml();
Q_SIGNALS: Q_SIGNALS:
void apiKeyChanged(const QString &value); void apiKeyChanged(const QString &value);
@ -69,23 +79,33 @@ protected:
QString m_apiKey; QString m_apiKey;
}; };
class OpenaiProviderBuiltin : public OpenaiProvider, public ModelProviderMutable { class OpenaiProviderBuiltin : public OpenaiProvider, public ModelProviderBuiltin, public ModelProviderMutable {
Q_OBJECT Q_OBJECT
Q_PROPERTY(QString name READ name CONSTANT) Q_PROPERTY(QString name READ name CONSTANT)
Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT) Q_PROPERTY(QUrl icon READ icon CONSTANT)
Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT)
Q_PROPERTY(QStringList modelWhitelist READ modelWhitelist CONSTANT)
public: public:
/// Create a new built-in OpenAI provider, loading its API key from disk if known. /// Create a new built-in OpenAI provider, loading its API key from disk if known.
explicit OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl baseUrl); explicit OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl icon, QUrl baseUrl,
QStringList modelWhitelist);
void setApiKey(QString value) override { setMemberProp<QString>(&OpenaiProviderBuiltin::m_apiKey, "apiKey", std::move(value)); } [[nodiscard]] const QStringList &modelWhitelist() { return m_modelWhitelist; }
[[nodiscard]] DataStoreResult<> setApiKey(QString value) override
{ return setMemberProp<QString>(&OpenaiProviderBuiltin::m_apiKey, "apiKey", std::move(value), /*createName*/ m_name); }
protected: protected:
auto asData() -> ModelProviderData override; auto asData() -> ModelProviderData override;
QStringList m_modelWhitelist;
}; };
class OpenaiProviderCustom final : public OpenaiProvider, public ModelProviderCustom { class OpenaiProviderCustom final : public OpenaiProvider, public ModelProviderCustom {
Q_OBJECT Q_OBJECT
Q_PROPERTY(QString name READ name NOTIFY nameChanged )
Q_PROPERTY(QUrl baseUrl READ baseUrl NOTIFY baseUrlChanged)
public: public:
/// Load an existing OpenaiProvider from disk. /// Load an existing OpenaiProvider from disk.
@ -94,7 +114,8 @@ public:
/// Create a new OpenaiProvider on disk. /// Create a new OpenaiProvider on disk.
explicit OpenaiProviderCustom(ProviderStore *store, QString name, QUrl baseUrl, QString apiKey); explicit OpenaiProviderCustom(ProviderStore *store, QString name, QUrl baseUrl, QString apiKey);
void setApiKey(QString value) override { setMemberProp<QString>(&OpenaiProviderCustom::m_apiKey, "apiKey", std::move(value)); } [[nodiscard]] DataStoreResult<> setApiKey(QString value) override
{ return setMemberProp<QString>(&OpenaiProviderCustom::m_apiKey, "apiKey", std::move(value)); }
Q_SIGNALS: Q_SIGNALS:
void nameChanged (const QString &value); void nameChanged (const QString &value);

View File

@ -60,13 +60,26 @@ ProviderRegistry::ProviderRegistry(PathSet paths)
load(); load();
} }
namespace {
class ProviderRegistryInternal : public ProviderRegistry {};
Q_GLOBAL_STATIC(ProviderRegistryInternal, providerRegistry)
}
ProviderRegistry *ProviderRegistry::globalInstance()
{ return providerRegistry(); }
void ProviderRegistry::load() void ProviderRegistry::load()
{ {
size_t i = 0;
for (auto &p : s_builtinProviders) { // (not all builtin providers are stored) for (auto &p : s_builtinProviders) { // (not all builtin providers are stored)
auto provider = std::make_shared<OpenaiProviderBuiltin>(&m_builtinStore, p.id, p.name, p.base_url); auto provider = std::make_shared<OpenaiProviderBuiltin>(
&m_builtinStore, p.id, p.name, p.icon, p.base_url,
QStringList(p.model_whitelist.begin(), p.model_whitelist.end())
);
auto [_, unique] = m_providers.emplace(p.id, std::move(provider)); auto [_, unique] = m_providers.emplace(p.id, std::move(provider));
if (!unique) if (!unique)
throw std::logic_error(fmt::format("duplicate builtin provider id: {}", p.id.toString())); throw std::logic_error(fmt::format("duplicate builtin provider id: {}", p.id.toString()));
m_builtinProviders[i++] = p.id;
} }
for (auto &p : m_customStore.list()) { // disk is source of truth for custom providers for (auto &p : m_customStore.list()) { // disk is source of truth for custom providers
if (!p.custom_details) { if (!p.custom_details) {
@ -91,11 +104,12 @@ void ProviderRegistry::load()
auto [_, unique] = m_providers.emplace(p.id, std::move(provider)); auto [_, unique] = m_providers.emplace(p.id, std::move(provider));
if (!unique) if (!unique)
qWarning() << "ignoring duplicate custom provider with id:" << p.id; qWarning() << "ignoring duplicate custom provider with id:" << p.id;
m_customProviders.push_back(std::make_unique<QUuid>(p.id));
} }
} }
[[nodiscard]] [[nodiscard]]
bool ProviderRegistry::add(std::unique_ptr<ModelProviderCustom> provider) bool ProviderRegistry::add(std::shared_ptr<ModelProviderCustom> provider)
{ {
auto [it, unique] = m_providers.emplace(provider->id(), std::move(provider)); auto [it, unique] = m_providers.emplace(provider->id(), std::move(provider));
if (unique) { if (unique) {
@ -105,13 +119,21 @@ bool ProviderRegistry::add(std::unique_ptr<ModelProviderCustom> provider)
return unique; return unique;
} }
auto ProviderRegistry::customProviderAt(size_t i) const -> const ModelProviderCustom * auto ProviderRegistry::customProviderAt(size_t i) const -> ModelProviderCustom *
{ {
auto it = m_providers.find(*m_customProviders.at(i)); auto it = m_providers.find(*m_customProviders.at(i));
Q_ASSERT(it != m_providers.end()); Q_ASSERT(it != m_providers.end());
return &dynamic_cast<ModelProviderCustom &>(*it->second); return &dynamic_cast<ModelProviderCustom &>(*it->second);
} }
auto ProviderRegistry::builtinProviderAt(size_t i) const -> ModelProviderBuiltin *
{
auto it = m_providers.find(m_builtinProviders.at(i));
Q_ASSERT(it != m_providers.end());
return &dynamic_cast<ModelProviderBuiltin &>(*it->second);
}
auto ProviderRegistry::getSubdirs() -> PathSet auto ProviderRegistry::getSubdirs() -> PathSet
{ {
auto *mysettings = MySettings::globalInstance(); auto *mysettings = MySettings::globalInstance();
@ -135,19 +157,31 @@ void ProviderRegistry::onModelPathChanged()
} }
} }
CustomProviderList::CustomProviderList(QPointer<ProviderRegistry> registry) auto BuiltinProviderList::roleNames() const -> QHash<int, QByteArray>
: m_registry(std::move(registry) ) { return { { Qt::DisplayRole, "data"_ba } }; }
, m_size (m_registry->customProviderCount())
QVariant BuiltinProviderList::data(const QModelIndex &index, int role) const
{ {
connect(m_registry, &ProviderRegistry::customProviderAdded, this, &CustomProviderList::onCustomProviderAdded); auto *registry = ProviderRegistry::globalInstance();
connect(m_registry, &ProviderRegistry::aboutToBeCleared, this, &CustomProviderList::onAboutToBeCleared, if (index.isValid() && index.row() < rowCount() && role == Qt::DisplayRole)
return QVariant::fromValue(registry->builtinProviderAt(index.row())->asQObject());
return {};
}
CustomProviderList::CustomProviderList()
: m_size(ProviderRegistry::globalInstance()->customProviderCount())
{
auto *registry = ProviderRegistry::globalInstance();
connect(registry, &ProviderRegistry::customProviderAdded, this, &CustomProviderList::onCustomProviderAdded);
connect(registry, &ProviderRegistry::aboutToBeCleared, this, &CustomProviderList::onAboutToBeCleared,
Qt::DirectConnection); Qt::DirectConnection);
} }
QVariant CustomProviderList::data(const QModelIndex &index, int role) const QVariant CustomProviderList::data(const QModelIndex &index, int role) const
{ {
auto *registry = ProviderRegistry::globalInstance();
if (index.isValid() && index.row() < rowCount() && role == Qt::DisplayRole) if (index.isValid() && index.row() < rowCount() && role == Qt::DisplayRole)
return QVariant::fromValue(m_registry->customProviderAt(index.row())); return QVariant::fromValue(registry->customProviderAt(index.row())->asQObject());
return {}; return {};
} }
@ -165,10 +199,10 @@ void CustomProviderList::onAboutToBeCleared()
endResetModel(); endResetModel();
} }
bool CustomProviderListSort::lessThan(const QModelIndex &left, const QModelIndex &right) const bool ProviderListSort::lessThan(const QModelIndex &left, const QModelIndex &right) const
{ {
auto *leftData = sourceModel()->data(left ).value<ModelProviderCustom *>(); auto *leftData = sourceModel()->data(left ).value<ModelProvider *>();
auto *rightData = sourceModel()->data(right).value<ModelProviderCustom *>(); auto *rightData = sourceModel()->data(right).value<ModelProvider *>();
if (leftData && rightData) if (leftData && rightData)
return QString::localeAwareCompare(leftData->name(), rightData->name()) < 0; return QString::localeAwareCompare(leftData->name(), rightData->name()) < 0;
return true; return true;

View File

@ -6,7 +6,6 @@
#include <QAbstractListModel> #include <QAbstractListModel>
#include <QObject> #include <QObject>
#include <QPointer>
#include <QQmlEngine> // IWYU pragma: keep #include <QQmlEngine> // IWYU pragma: keep
#include <QSortFilterProxyModel> #include <QSortFilterProxyModel>
#include <QString> #include <QString>
@ -18,12 +17,15 @@
#include <cstddef> #include <cstddef>
#include <filesystem> #include <filesystem>
#include <memory> #include <memory>
#include <optional>
#include <string_view> #include <string_view>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
class QByteArray;
class QJSEngine; class QJSEngine;
template <typename Key, typename T> class QHash;
namespace gpt4all::ui { namespace gpt4all::ui {
@ -60,9 +62,6 @@ protected:
}; };
class ModelProvider { class ModelProvider {
Q_GADGET
Q_PROPERTY(QUuid id READ id CONSTANT)
protected: protected:
explicit ModelProvider(QUuid id, QString name, QUrl baseUrl) // create built-in or load explicit ModelProvider(QUuid id, QString name, QUrl baseUrl) // create built-in or load
: m_id(std::move(id)), m_name(std::move(name)), m_baseUrl(std::move(baseUrl)) {} : m_id(std::move(id)), m_name(std::move(name)), m_baseUrl(std::move(baseUrl)) {}
@ -92,10 +91,20 @@ protected:
QUrl m_baseUrl; QUrl m_baseUrl;
}; };
class ModelProviderBuiltin : public virtual ModelProvider {
protected:
explicit ModelProviderBuiltin(QUrl icon)
: m_icon(std::move(icon)) {}
public:
[[nodiscard]] const QUrl &icon() const { return m_icon; }
protected:
QUrl m_icon;
};
// Mixin with no public interface providing basic load/save // Mixin with no public interface providing basic load/save
class ModelProviderMutable : public virtual ModelProvider { class ModelProviderMutable : public virtual ModelProvider {
Q_GADGET
protected: protected:
explicit ModelProviderMutable(ProviderStore *store) explicit ModelProviderMutable(ProviderStore *store)
: m_store(store) {} : m_store(store) {}
@ -107,53 +116,54 @@ protected:
virtual auto asData() -> ModelProviderData = 0; virtual auto asData() -> ModelProviderData = 0;
template <typename T, typename S, typename C> template <typename T, typename S, typename C>
void setMemberProp(this S &self, T C::* member, std::string_view name, T value); [[nodiscard]] DataStoreResult<> setMemberProp(this S &self, T C::* member, std::string_view name, T value,
std::optional<QString> createName = {});
ProviderStore *m_store; ProviderStore *m_store;
}; };
class ModelProviderCustom : public ModelProviderMutable { class ModelProviderCustom : public ModelProviderMutable {
Q_GADGET
Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged )
Q_PROPERTY(QUrl baseUrl READ baseUrl WRITE setBaseUrl NOTIFY baseUrlChanged)
protected: protected:
explicit ModelProviderCustom(ProviderStore *store) explicit ModelProviderCustom(ProviderStore *store)
: ModelProviderMutable(store) {} : ModelProviderMutable(store) {}
public: public:
// setters // setters
void setName (QString value) { setMemberProp<QString>(&ModelProviderCustom::m_name, "name", std::move(value)); } [[nodiscard]] DataStoreResult<> setName (QString value)
void setBaseUrl(QUrl value) { setMemberProp<QUrl >(&ModelProviderCustom::m_baseUrl, "baseUrl", std::move(value)); } { return setMemberProp<QString>(&ModelProviderCustom::m_name, "name", std::move(value)); }
[[nodiscard]] DataStoreResult<> setBaseUrl(QUrl value)
{ return setMemberProp<QUrl >(&ModelProviderCustom::m_baseUrl, "baseUrl", std::move(value)); }
}; };
class ProviderRegistry : public QObject { class ProviderRegistry : public QObject {
Q_OBJECT Q_OBJECT
QML_ELEMENT
QML_SINGLETON
private: private:
struct PathSet { std::filesystem::path builtin, custom; }; struct PathSet { std::filesystem::path builtin, custom; };
struct BuiltinProviderData { struct BuiltinProviderData {
QUuid id; QUuid id;
QString name; QString name;
QUrl base_url; QUrl icon;
QUrl base_url;
std::span<const QString> model_whitelist;
}; };
protected: protected:
explicit ProviderRegistry(PathSet paths); explicit ProviderRegistry(PathSet paths);
explicit ProviderRegistry(): ProviderRegistry(getSubdirs()) {}
public: public:
static ProviderRegistry *create(QQmlEngine *, QJSEngine *) { return new ProviderRegistry(getSubdirs()); } static ProviderRegistry *globalInstance();
[[nodiscard]] bool add(std::unique_ptr<ModelProviderCustom> provider);
[[nodiscard]] bool add(std::shared_ptr<ModelProviderCustom> provider);
auto operator[](const QUuid &id) -> const ModelProvider * { return m_providers.at(id).get(); }
// TODO(jared): implement a way to remove custom providers via the model // TODO(jared): implement a way to remove custom providers via the model
[[nodiscard]] size_t customProviderCount() const [[nodiscard]] size_t customProviderCount () const { return m_customProviders.size(); }
{ return m_customProviders.size(); } [[nodiscard]] auto customProviderAt (size_t i) const -> ModelProviderCustom *;
[[nodiscard]] auto customProviderAt(size_t i) const -> const ModelProviderCustom *; [[nodiscard]] size_t builtinProviderCount() const { return m_builtinProviders.size(); }
auto operator[](const QUuid &id) -> ModelProviderCustom * [[nodiscard]] auto builtinProviderAt (size_t i) const -> ModelProviderBuiltin *;
{ return &dynamic_cast<ModelProviderCustom &>(*m_providers.at(id)); }
Q_SIGNALS: Q_SIGNALS:
void customProviderAdded(size_t index); void customProviderAdded(size_t index);
@ -174,15 +184,36 @@ private:
ProviderStore m_builtinStore; ProviderStore m_builtinStore;
std::unordered_map<QUuid, std::shared_ptr<ModelProvider>> m_providers; std::unordered_map<QUuid, std::shared_ptr<ModelProvider>> m_providers;
std::vector<std::unique_ptr<QUuid>> m_customProviders; std::vector<std::unique_ptr<QUuid>> m_customProviders;
std::array<QUuid, N_BUILTIN> m_builtinProviders;
};
// TODO: api keys are allowed to change for here and also below. That should emit dataChanged.
class BuiltinProviderList : public QAbstractListModel {
Q_OBJECT
QML_SINGLETON
QML_ELEMENT
public:
explicit BuiltinProviderList()
: m_size(ProviderRegistry::globalInstance()->builtinProviderCount()) {}
static BuiltinProviderList *create(QQmlEngine *, QJSEngine *) { return new BuiltinProviderList(); }
auto roleNames() const -> QHash<int, QByteArray> override;
int rowCount(const QModelIndex &parent = {}) const override
{ Q_UNUSED(parent) return int(m_size); }
QVariant data(const QModelIndex &index, int role) const override;
private:
size_t m_size;
}; };
class CustomProviderList : public QAbstractListModel { class CustomProviderList : public QAbstractListModel {
Q_OBJECT Q_OBJECT
protected:
explicit CustomProviderList(QPointer<ProviderRegistry> registry);
public: public:
explicit CustomProviderList();
int rowCount(const QModelIndex &parent = {}) const override int rowCount(const QModelIndex &parent = {}) const override
{ Q_UNUSED(parent) return int(m_size); } { Q_UNUSED(parent) return int(m_size); }
QVariant data(const QModelIndex &index, int role) const override; QVariant data(const QModelIndex &index, int role) const override;
@ -192,15 +223,28 @@ private Q_SLOTS:
void onAboutToBeCleared(); void onAboutToBeCleared();
private: private:
QPointer<ProviderRegistry> m_registry; size_t m_size;
size_t m_size;
}; };
class CustomProviderListSort : public QSortFilterProxyModel { // todo: don't have singletons use singletons directly
// TODO: actually use the provider sort, here, rather than unsorted, for builtins
class ProviderListSort : public QSortFilterProxyModel {
Q_OBJECT Q_OBJECT
QML_SINGLETON
QML_ELEMENT
private:
explicit ProviderListSort() { setSourceModel(&m_model); }
public:
static ProviderListSort *create(QQmlEngine *, QJSEngine *) { return new ProviderListSort(); }
protected: protected:
bool lessThan(const QModelIndex &left, const QModelIndex &right) const override; bool lessThan(const QModelIndex &left, const QModelIndex &right) const override;
private:
// TODO: support custom providers as well
BuiltinProviderList m_model;
}; };

View File

@ -13,17 +13,19 @@ void GenerationParams::tryParseValue(this S &self, QMap<GenerationParam, QVarian
} }
template <typename T, typename S, typename C> template <typename T, typename S, typename C>
void ModelProviderMutable::setMemberProp(this S &self, T C::* member, std::string_view name, T value) auto ModelProviderMutable::setMemberProp(this S &self, T C::* member, std::string_view name, T value,
std::optional<QString> createName) -> DataStoreResult<>
{ {
auto &mpc = static_cast<ModelProviderMutable &>(self); auto &mpc = static_cast<ModelProviderMutable &>(self);
auto &cur = self.*member; auto &cur = self.*member;
if (cur != value) { if (cur != value) {
cur = std::move(value); cur = std::move(value);
auto data = mpc.asData(); auto data = mpc.asData();
if (auto res = mpc.m_store->setData(std::move(data)); !res) if (auto res = mpc.m_store->setData(std::move(data), createName); !res)
res.error().raise(); return res;
QMetaObject::invokeMethod(self.asQObject(), fmt::format("{}Changed", name).c_str(), cur); QMetaObject::invokeMethod(self.asQObject(), fmt::format("{}Changed", name).c_str(), cur);
} }
return {};
} }

View File

@ -6,27 +6,94 @@ using namespace Qt::StringLiterals;
namespace gpt4all::ui { namespace gpt4all::ui {
// TODO: use these in the constructor of ProviderRegistry static const QString MODEL_WHITELIST_GROQ[] {
// TODO: we have to be careful to reserve these names for ProviderStore purposes, so the user can't write JSON files that alias them. // last updated 2025-02-24
// this *is a problem*, because we want to be able to safely introduce these. u"deepseek-r1-distill-llama-70b"_s,
// so we need a different namespace, i.e. a *different directory*. u"deepseek-r1-distill-qwen-32b"_s,
u"gemma2-9b-it"_s,
u"llama-3.1-8b-instant"_s,
u"llama-3.2-1b-preview"_s,
u"llama-3.2-3b-preview"_s,
u"llama-3.3-70b-specdec"_s,
u"llama-3.3-70b-versatile"_s,
u"llama3-70b-8192"_s,
u"llama3-8b-8192"_s,
u"mixtral-8x7b-32768"_s,
u"qwen-2.5-32b"_s,
u"qwen-2.5-coder-32b"_s,
};
static const QString MODEL_WHITELIST_OPENAI[] {
// last updated 2025-02-24
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-turbo",
"gpt-4o",
};
static const QString MODEL_WHITELIST_MISTRAL[] {
// last updated 2025-02-24
"codestral-2405",
"codestral-2411-rc5",
"codestral-2412",
"codestral-2501",
"codestral-latest",
"codestral-mamba-2407",
"codestral-mamba-latest",
"ministral-3b-2410",
"ministral-3b-latest",
"ministral-8b-2410",
"ministral-8b-latest",
"mistral-large-2402",
"mistral-large-2407",
"mistral-large-2411",
"mistral-large-latest",
"mistral-medium-2312",
"mistral-medium-latest",
"mistral-saba-2502",
"mistral-saba-latest",
"mistral-small-2312",
"mistral-small-2402",
"mistral-small-2409",
"mistral-small-2501",
"mistral-small-latest",
"mistral-tiny-2312",
"mistral-tiny-2407",
"mistral-tiny-latest",
"open-codestral-mamba",
"open-mistral-7b",
"open-mistral-nemo",
"open-mistral-nemo-2407",
"open-mixtral-8x22b",
"open-mixtral-8x22b-2404",
"open-mixtral-8x7b",
};
const std::array< const std::array<
ProviderRegistry::BuiltinProviderData, ProviderRegistry::N_BUILTIN ProviderRegistry::BuiltinProviderData, ProviderRegistry::N_BUILTIN
> ProviderRegistry::s_builtinProviders { > ProviderRegistry::s_builtinProviders {
BuiltinProviderData { BuiltinProviderData {
.id = QUuid("20f963dc-1f99-441e-ad80-f30a0a06bcac"), .id = QUuid("20f963dc-1f99-441e-ad80-f30a0a06bcac"),
.name = u"Groq"_s, .name = u"Groq"_s,
.base_url = u"https://api.groq.com/openai/v1/"_s, .icon = u"qrc:/gpt4all/icons/groq.svg"_s,
.base_url = u"https://api.groq.com/openai/v1/"_s,
.model_whitelist = MODEL_WHITELIST_GROQ,
}, },
BuiltinProviderData { BuiltinProviderData {
.id = QUuid("6f874c3a-f1ad-47f7-9129-755c5477146c"), .id = QUuid("6f874c3a-f1ad-47f7-9129-755c5477146c"),
.name = u"OpenAI"_s, .name = u"OpenAI"_s,
.base_url = u"https://api.openai.com/v1/"_s, .icon = u"qrc:/gpt4all/icons/openai.svg"_s,
.base_url = u"https://api.openai.com/v1/"_s,
.model_whitelist = MODEL_WHITELIST_OPENAI,
}, },
BuiltinProviderData { BuiltinProviderData {
.id = QUuid("7ae617b3-c0b2-4d2c-9ff2-bc3f049494cc"), .id = QUuid("7ae617b3-c0b2-4d2c-9ff2-bc3f049494cc"),
.name = u"Mistral"_s, .name = u"Mistral"_s,
.base_url = u"https://api.mistral.ai/v1/"_s, .icon = u"qrc:/gpt4all/icons/mistral.svg"_s,
.base_url = u"https://api.mistral.ai/v1/"_s,
.model_whitelist = MODEL_WHITELIST_MISTRAL,
}, },
}; };

View File

@ -52,6 +52,9 @@
using namespace Qt::Literals::StringLiterals; using namespace Qt::Literals::StringLiterals;
namespace gpt4all::ui {
static void raiseWindow(QWindow *window) static void raiseWindow(QWindow *window)
{ {
#ifdef Q_OS_WINDOWS #ifdef Q_OS_WINDOWS
@ -70,8 +73,19 @@ static void raiseWindow(QWindow *window)
#endif #endif
} }
Q_GLOBAL_STATIC(QNetworkAccessManager, globalNetworkAccessManager)
QNetworkAccessManager *networkAccessManager()
{ return globalNetworkAccessManager(); }
} // namespace gpt4all::ui
int main(int argc, char *argv[]) int main(int argc, char *argv[])
{ {
using namespace gpt4all::ui;
#ifndef GPT4ALL_USE_QTPDF #ifndef GPT4ALL_USE_QTPDF
FPDF_InitLibrary(); FPDF_InitLibrary();
#endif #endif

12
gpt4all-chat/src/main.h Normal file
View File

@ -0,0 +1,12 @@
#pragma once
class QNetworkAccessManager;
namespace gpt4all::ui {
QNetworkAccessManager *networkAccessManager();
} // namespace gpt4all::ui

View File

@ -2365,56 +2365,3 @@ void ModelList::handleDiscoveryItemErrorOccurred(QNetworkReply::NetworkError cod
qWarning() << u"ERROR: Discovery item failed with error code \"%1-%2\""_s qWarning() << u"ERROR: Discovery item failed with error code \"%1-%2\""_s
.arg(code).arg(reply->errorString()).toStdString(); .arg(code).arg(reply->errorString()).toStdString();
} }
QStringList ModelList::remoteModelList(const QString &apiKey, const QUrl &baseUrl)
{
QStringList modelList;
// Create the request
QNetworkRequest request;
request.setUrl(baseUrl.resolved(QUrl("models")));
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
// Add the Authorization header
const QString bearerToken = QString("Bearer %1").arg(apiKey);
request.setRawHeader("Authorization", bearerToken.toUtf8());
// Make the GET request
QNetworkReply *reply = m_networkManager.get(request);
// We use a local event loop to wait for the request to complete
QEventLoop loop;
connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
loop.exec();
// Check for errors
if (reply->error() == QNetworkReply::NoError) {
// Parse the JSON response
const QByteArray responseData = reply->readAll();
const QJsonDocument jsonDoc = QJsonDocument::fromJson(responseData);
if (!jsonDoc.isNull() && jsonDoc.isObject()) {
QJsonObject rootObj = jsonDoc.object();
QJsonValue dataValue = rootObj.value("data");
if (dataValue.isArray()) {
QJsonArray dataArray = dataValue.toArray();
for (const QJsonValue &val : dataArray) {
if (val.isObject()) {
QJsonObject obj = val.toObject();
const QString modelId = obj.value("id").toString();
modelList.append(modelId);
}
}
}
}
} else {
// Handle network error (e.g. print it to qDebug)
qWarning() << "Error retrieving models:" << reply->errorString();
}
// Clean up
reply->deleteLater();
return modelList;
}

View File

@ -546,8 +546,6 @@ public:
Q_INVOKABLE void discoverSearch(const QString &discover); Q_INVOKABLE void discoverSearch(const QString &discover);
Q_INVOKABLE QStringList remoteModelList(const QString &apiKey, const QUrl &baseUrl);
Q_SIGNALS: Q_SIGNALS:
void countChanged(); void countChanged();
void installedModelsChanged(); void installedModelsChanged();

View File

@ -26,6 +26,18 @@ using namespace Qt::StringLiterals;
namespace gpt4all::ui { namespace gpt4all::ui {
DataStoreError::DataStoreError(std::error_code e)
: m_error(e)
, m_errorString(QString::fromStdString(e.message()))
{}
DataStoreError::DataStoreError(const sys::system_error &e)
: m_error(e.code())
, m_errorString(QString::fromUtf8(e.what()))
{
Q_ASSERT(e.code());
}
DataStoreError::DataStoreError(const QFileDevice *file) DataStoreError::DataStoreError(const QFileDevice *file)
: m_error(file->error()) : m_error(file->error())
, m_errorString(file->errorString()) , m_errorString(file->errorString())
@ -33,13 +45,6 @@ DataStoreError::DataStoreError(const QFileDevice *file)
Q_ASSERT(file->error()); Q_ASSERT(file->error());
} }
DataStoreError::DataStoreError(const boost::system::system_error &e)
: m_error(e.code())
, m_errorString(QString::fromUtf8(e.what()))
{
Q_ASSERT(e.code());
}
DataStoreError::DataStoreError(QString e) DataStoreError::DataStoreError(QString e)
: m_error() : m_error()
, m_errorString(e) , m_errorString(e)
@ -48,9 +53,10 @@ DataStoreError::DataStoreError(QString e)
void DataStoreError::raise() const void DataStoreError::raise() const
{ {
std::visit(Overloaded { std::visit(Overloaded {
[&](QFileDevice::FileError e) { throw FileError(m_errorString, e); }, [&](std::error_code e) { throw std::system_error(e); },
[&](boost::system::error_code e) { throw std::runtime_error(m_errorString.toUtf8().constData()); }, [&](sys::error_code e) { throw std::runtime_error(m_errorString.toUtf8().constData()); },
[&](std::monostate ) { throw std::runtime_error(m_errorString.toUtf8().constData()); }, [&](QFileDevice::FileError e) { throw FileError(m_errorString, e); },
[&](std::monostate ) { throw std::runtime_error(m_errorString.toUtf8().constData()); },
}, m_error); }, m_error);
Q_UNREACHABLE(); Q_UNREACHABLE();
} }
@ -63,7 +69,20 @@ auto DataStoreBase::reload() -> DataStoreResult<>
json::stream_parser parser; json::stream_parser parser;
QFile file; QFile file;
for (auto &entry : fs::directory_iterator(m_path)) { fs::directory_iterator it;
try {
it = fs::directory_iterator(m_path);
} catch (const fs::filesystem_error &e) {
if (e.code() == std::errc::no_such_file_or_directory) {
fs::create_directories(m_path);
return {}; // brand new dir, nothing to load
}
throw;
}
for (auto &entry : it) {
if (!entry.is_regular_file())
continue; // skip directories and such
file.setFileName(entry.path()); file.setFileName(entry.path());
if (!file.open(QFile::ReadOnly)) { if (!file.open(QFile::ReadOnly)) {
qWarning().noquote() << "skipping unopenable file:" << file.fileName(); qWarning().noquote() << "skipping unopenable file:" << file.fileName();
@ -71,7 +90,7 @@ auto DataStoreBase::reload() -> DataStoreResult<>
} }
auto jv = read(file, parser); auto jv = read(file, parser);
if (!jv) { if (!jv) {
(qWarning().nospace() << "skipping " << file.fileName() << "because of read error: ").noquote() (qWarning().nospace() << "skipping " << file.fileName() << " because of read error: ").noquote()
<< jv.error().errorString(); << jv.error().errorString();
} else if (auto [unique, uuid] = cacheInsert(*jv); !unique) } else if (auto [unique, uuid] = cacheInsert(*jv); !unique)
qWarning() << "skipping duplicate data store entry:" << uuid; qWarning() << "skipping duplicate data store entry:" << uuid;
@ -89,7 +108,7 @@ auto DataStoreBase::setPath(fs::path path) -> DataStoreResult<>
return {}; return {};
} }
auto DataStoreBase::getFilePath(const QString &name) -> std::filesystem::path auto DataStoreBase::getFilePath(const QString &name) -> fs::path
{ return m_path / fmt::format("{}.json", QLatin1StringView(normalizeName(name))); } { return m_path / fmt::format("{}.json", QLatin1StringView(normalizeName(name))); }
auto DataStoreBase::openNew(const QString &name) -> DataStoreResult<std::unique_ptr<QFile>> auto DataStoreBase::openNew(const QString &name) -> DataStoreResult<std::unique_ptr<QFile>>
@ -106,7 +125,7 @@ auto DataStoreBase::openNew(const QString &name) -> DataStoreResult<std::unique_
auto DataStoreBase::openExisting(const QString &name, bool allowCreate) -> DataStoreResult<std::unique_ptr<QSaveFile>> auto DataStoreBase::openExisting(const QString &name, bool allowCreate) -> DataStoreResult<std::unique_ptr<QSaveFile>>
{ {
auto path = getFilePath(name); auto path = getFilePath(name);
if (!QFile::exists(path)) if (!allowCreate && !QFile::exists(path))
return std::unexpected(sys::system_error( return std::unexpected(sys::system_error(
std::make_error_code(std::errc::no_such_file_or_directory), path.string() std::make_error_code(std::errc::no_such_file_or_directory), path.string()
)); ));
@ -119,33 +138,81 @@ auto DataStoreBase::openExisting(const QString &name, bool allowCreate) -> DataS
return file; return file;
} }
auto DataStoreBase::read(QFileDevice &file, boost::json::stream_parser &parser) -> DataStoreResult<boost::json::value> auto DataStoreBase::read(QFileDevice &file, json::stream_parser &parser) -> DataStoreResult<json::value>
{ {
for (;;) { // chunk stream
auto chunk = file.read(JSON_BUFSIZ); auto iterChunks = [&] -> tl::generator<DataStoreResult<QByteArray>> {
if (file.error()) for (;;) {
return std::unexpected(&file); auto chunk = file.read(JSON_BUFSIZ);
if (chunk.isEmpty()) { if (file.error()) {
Q_ASSERT(file.atEnd()); DataStoreResult<QByteArray> res(std::unexpect, &file);
break; co_yield res;
}
if (chunk.isEmpty()) {
Q_ASSERT(file.atEnd());
break;
}
DataStoreResult<QByteArray> res(std::move(chunk));
co_yield res;
} }
parser.write(chunk.data(), chunk.size()); };
auto inner = [&] -> DataStoreResult<> {
bool partialRead = false;
auto chunkIt = iterChunks();
// read JSON data
for (auto &chunk : chunkIt) {
if (!chunk)
return std::unexpected(chunk.error());
size_t nRead = parser.write_some(chunk->data(), chunk->size());
// consume trailing whitespace in chunk
if (nRead < chunk->size()) {
auto rest = QByteArrayView(*chunk).slice(nRead);
if (!rest.trimmed().isEmpty())
return std::unexpected(u"unexpected data after json: \"%1\""_s.arg(QByteArray(rest)));
partialRead = true;
break;
}
}
// consume trailing whitespace in file
if (partialRead) {
for (auto &chunk : chunkIt) {
if (!chunk)
return std::unexpected(chunk.error());
if (!chunk->trimmed().isEmpty())
return std::unexpected(u"unexpected data after json: \"%1\""_s.arg(*chunk));
}
}
return {};
};
auto res = inner();
if (!res) {
parser.reset();
return std::unexpected(res.error());
} }
return parser.release(); return parser.release();
} }
auto DataStoreBase::write(const json::value &value, QFileDevice &file) -> DataStoreResult<> auto DataStoreBase::write(const json::value &value, QFileDevice &file) -> DataStoreResult<>
{ {
qint64 nWritten;
m_serializer.reset(&value); m_serializer.reset(&value);
std::array<char, JSON_BUFSIZ> buf; std::array<char, JSON_BUFSIZ> buf;
while (!m_serializer.done()) { while (!m_serializer.done()) {
auto chunk = m_serializer.read(buf.data(), buf.size()); auto chunk = m_serializer.read(buf.data(), buf.size());
qint64 nWritten = file.write(chunk.data(), chunk.size()); nWritten = file.write(chunk.data(), chunk.size());
if (nWritten < 0) if (nWritten < 0)
return std::unexpected(&file); return std::unexpected(&file);
Q_ASSERT(nWritten == chunk.size()); Q_ASSERT(nWritten == chunk.size());
} }
// write trailing newline to make it a valid text file
nWritten = file.write("\n"_ba);
if (nWritten < 0)
return std::unexpected(&file);
Q_ASSERT(nWritten == 1);
if (!file.flush()) if (!file.flush())
return std::unexpected(&file); return std::unexpected(&file);

View File

@ -16,6 +16,7 @@
#include <filesystem> #include <filesystem>
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <system_error>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
@ -31,13 +32,15 @@ namespace gpt4all::ui {
class DataStoreError { class DataStoreError {
public: public:
using ErrorCode = std::variant< using ErrorCode = std::variant<
QFileDevice::FileError, std::monostate,
std::error_code,
boost::system::error_code, boost::system::error_code,
std::monostate QFileDevice::FileError
>; >;
DataStoreError(const QFileDevice *file); DataStoreError(std::error_code e);
DataStoreError(const boost::system::system_error &e); DataStoreError(const boost::system::system_error &e);
DataStoreError(const QFileDevice *file);
DataStoreError(QString e); DataStoreError(QString e);
[[nodiscard]] const ErrorCode &error () const { return m_error; } [[nodiscard]] const ErrorCode &error () const { return m_error; }
@ -94,8 +97,7 @@ public:
explicit DataStore(std::filesystem::path path); explicit DataStore(std::filesystem::path path);
auto list() -> tl::generator<const T &>; auto list() -> tl::generator<const T &>;
auto setData(T data) -> DataStoreResult<>; auto setData(T data, std::optional<QString> createName = {}) -> DataStoreResult<>;
auto createOrSetData(T data, const QString &name) -> DataStoreResult<>;
auto remove(const QUuid &id) -> DataStoreResult<>; auto remove(const QUuid &id) -> DataStoreResult<>;
auto acquire(QUuid id) -> DataStoreResult<std::optional<const T *>>; auto acquire(QUuid id) -> DataStoreResult<std::optional<const T *>>;

View File

@ -6,6 +6,8 @@
#include <QSaveFile> #include <QSaveFile>
#include <QtAssert> #include <QtAssert>
#include <system_error>
namespace gpt4all::ui { namespace gpt4all::ui {
@ -49,38 +51,19 @@ auto DataStore<T>::createImpl(T data, const QString &name) -> DataStoreResult<co
} }
template <typename T> template <typename T>
auto DataStore<T>::setData(T data) -> DataStoreResult<> auto DataStore<T>::setData(T data, std::optional<QString> createName) -> DataStoreResult<>
{ {
const QString *openName;
auto name_it = m_names.find(data.id); auto name_it = m_names.find(data.id);
if (name_it == m_names.end()) if (name_it != m_names.end()) {
openName = &name_it->second;
} else if (createName) {
openName = &*createName;
} else
return std::unexpected(QStringLiteral("id not found: %1").arg(data.id.toString())); return std::unexpected(QStringLiteral("id not found: %1").arg(data.id.toString()));
// acquire path // acquire path
auto file = openExisting(name_it->second); auto file = openExisting(*openName, !!createName);
if (!file)
return std::unexpected(file.error());
// serialize
if (auto res = write(boost::json::value_from(data), **file); !res)
return std::unexpected(res.error());
if (!(*file)->commit())
return std::unexpected(file->get());
// update
m_entries.at(data.id) = std::move(data);
return {};
}
template <typename T>
auto DataStore<T>::createOrSetData(T data, const QString &name) -> DataStoreResult<>
{
auto name_it = m_names.find(data.id);
if (name_it != m_names.end() && name_it->second != name)
return std::unexpected(QStringLiteral("name conflict for id %1: old=%2, new=%3")
.arg(data.id.toString(), name_it->second, name));
// acquire path
auto file = openExisting(name, /*allowCreate*/ true);
if (!file) if (!file)
return std::unexpected(file.error()); return std::unexpected(file.error());
@ -92,8 +75,19 @@ auto DataStore<T>::createOrSetData(T data, const QString &name) -> DataStoreResu
// update // update
m_entries[data.id] = std::move(data); m_entries[data.id] = std::move(data);
if (name_it == m_names.end())
m_names.emplace(data.id, name); // rename if necessary
if (name_it == m_names.end()) {
m_names.emplace(data.id, std::move(*createName));
} else if (*createName != name_it->second) {
std::error_code ec;
auto newPath = getFilePath(*createName);
std::filesystem::rename(getFilePath(name_it->second), newPath, ec);
if (ec)
return std::unexpected(ec);
m_names.at(data.id) = std::move(*createName);
}
return {}; return {};
} }