diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index bf987db8..d29af3ef 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -245,6 +245,7 @@ qt_add_executable(chat src/llmodel_ollama.cpp src/llmodel_ollama.h src/llmodel_openai.cpp src/llmodel_openai.h src/llmodel_provider.cpp src/llmodel_provider.h + src/llmodel_provider_builtins.cpp src/localdocs.cpp src/localdocs.h src/localdocsmodel.cpp src/localdocsmodel.h src/logger.cpp src/logger.h diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp index b5491727..8eb4ba99 100644 --- a/gpt4all-chat/src/chatllm.cpp +++ b/gpt4all-chat/src/chatllm.cpp @@ -330,10 +330,8 @@ void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo) emit trySwitchContextOfLoadedModelCompleted(0); } -// TODO: always call with a resource guard held since this didn't previously use coroutines auto ChatLLM::loadModel(const ModelInfo &modelInfo) -> QCoro::Task { - // TODO: get the description from somewhere bool alreadyAcquired = isModelLoaded(); if (alreadyAcquired && *modelInfo.modelDesc() == *m_modelInfo.modelDesc()) { // already acquired -> keep it diff --git a/gpt4all-chat/src/llmodel_chat.h b/gpt4all-chat/src/llmodel_chat.h index a44ec3c4..0e9ba155 100644 --- a/gpt4all-chat/src/llmodel_chat.h +++ b/gpt4all-chat/src/llmodel_chat.h @@ -19,7 +19,6 @@ struct ChatResponseMetadata { int nResponseTokens; }; -// TODO: implement two of these; one based on Ollama (TBD) and the other based on OpenAI (chatapi.h) class ChatLLMInstance { public: virtual ~ChatLLMInstance() noexcept = 0; diff --git a/gpt4all-chat/src/llmodel_description.h b/gpt4all-chat/src/llmodel_description.h index 56004a31..fc2811fb 100644 --- a/gpt4all-chat/src/llmodel_description.h +++ b/gpt4all-chat/src/llmodel_description.h @@ -15,12 +15,14 @@ namespace gpt4all::ui { class ChatLLMInstance; -// TODO: implement shared_from_this guidance for restricted construction class ModelDescription : public std::enable_shared_from_this { Q_GADGET Q_PROPERTY(const ModelProvider *provider READ provider CONSTANT) Q_PROPERTY(QVariant key READ key CONSTANT) +protected: + struct protected_t { explicit protected_t() = default; }; + public: virtual ~ModelDescription() noexcept = 0; diff --git a/gpt4all-chat/src/llmodel_ollama.cpp b/gpt4all-chat/src/llmodel_ollama.cpp index 56bea6aa..a05a1556 100644 --- a/gpt4all-chat/src/llmodel_ollama.cpp +++ b/gpt4all-chat/src/llmodel_ollama.cpp @@ -34,14 +34,17 @@ auto OllamaProvider::makeGenerationParams(const QMap { return new OllamaGenerationParams(values); } /// load -OllamaProviderCustom::OllamaProviderCustom(std::shared_ptr store, QUuid id) - : ModelProvider(std::move(id)) +OllamaProviderCustom::OllamaProviderCustom(std::shared_ptr store, QUuid id, QString name, QUrl baseUrl) + : ModelProvider (std::move(id), std::move(name), std::move(baseUrl)) , ModelProviderCustom(std::move(store)) -{ load(); } +{ + if (auto res = m_store->acquire(m_id); !res) + res.error().raise(); +} /// create OllamaProviderCustom::OllamaProviderCustom(std::shared_ptr store, QString name, QUrl baseUrl) - : ModelProvider(std::move(name), std::move(baseUrl)) + : ModelProvider (std::move(name), std::move(baseUrl)) , ModelProviderCustom(std::move(store)) { auto data = m_store->create(m_name, m_baseUrl); @@ -50,8 +53,20 @@ OllamaProviderCustom::OllamaProviderCustom(std::shared_ptr store, m_id = (*data)->id; } -OllamaModelDescription::OllamaModelDescription(std::shared_ptr provider, QByteArray modelHash) - : m_provider(std::move(provider)) +auto OllamaProviderCustom::asData() -> ModelProviderData +{ + return { + .id = m_id, + .builtin = false, + .type = ProviderType::ollama, + .custom_details = CustomProviderDetails { m_name, m_baseUrl }, + .provider_details = {}, + }; +} + +OllamaModelDescription::OllamaModelDescription(protected_t, std::shared_ptr provider, + QByteArray modelHash) + : m_provider (std::move(provider )) , m_modelHash(std::move(modelHash)) {} @@ -63,7 +78,7 @@ auto OllamaModelDescription::newInstanceImpl(QNetworkAccessManager *nam) const - OllamaChatModel::OllamaChatModel(std::shared_ptr description, QNetworkAccessManager *nam) : m_description(std::move(description)) - , m_nam(nam) + , m_nam (nam ) {} auto OllamaChatModel::preload() -> QCoro::Task<> diff --git a/gpt4all-chat/src/llmodel_ollama.h b/gpt4all-chat/src/llmodel_ollama.h index 761b4db0..bfd94334 100644 --- a/gpt4all-chat/src/llmodel_ollama.h +++ b/gpt4all-chat/src/llmodel_ollama.h @@ -50,8 +50,10 @@ public: auto makeGenerationParams(const QMap &values) const -> OllamaGenerationParams * override; }; -class OllamaProviderBuiltin : public OllamaProvider, public ModelProviderBuiltin { +class OllamaProviderBuiltin : public OllamaProvider { Q_GADGET + Q_PROPERTY(QString name READ name CONSTANT) + Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT) public: /// Create a new built-in Ollama provider (transient). @@ -64,18 +66,17 @@ class OllamaProviderCustom final : public OllamaProvider, public ModelProviderCu public: /// Load an existing OllamaProvider from disk. - explicit OllamaProviderCustom(std::shared_ptr store, QUuid id); + explicit OllamaProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl); /// Create a new OllamaProvider on disk. - explicit OllamaProviderCustom(std::shared_ptr store, QString name, QUrl baseUrl); + explicit OllamaProviderCustom(ProviderStore *store, QString name, QUrl baseUrl); Q_SIGNALS: void nameChanged (const QString &value); void baseUrlChanged(const QUrl &value); protected: - auto asData() -> ModelProviderData override - { return { m_id, ProviderType::ollama, m_name, m_baseUrl, {} }; } + auto asData() -> ModelProviderData override; }; class OllamaModelDescription : public ModelDescription { @@ -83,7 +84,11 @@ class OllamaModelDescription : public ModelDescription { Q_PROPERTY(QByteArray modelHash READ modelHash CONSTANT) public: - explicit OllamaModelDescription(std::shared_ptr provider, QByteArray modelHash); + explicit OllamaModelDescription(protected_t, std::shared_ptr provider, QByteArray modelHash); + + static auto create(std::shared_ptr provider, QByteArray modelHash) + -> std::shared_ptr + { return std::make_shared(protected_t(), std::move(provider), std::move(modelHash)); } // getters [[nodiscard]] auto provider () const -> const OllamaProvider * override { return m_provider.get(); } diff --git a/gpt4all-chat/src/llmodel_openai.cpp b/gpt4all-chat/src/llmodel_openai.cpp index 53b51e85..d47dfb43 100644 --- a/gpt4all-chat/src/llmodel_openai.cpp +++ b/gpt4all-chat/src/llmodel_openai.cpp @@ -96,23 +96,38 @@ auto OpenaiProvider::makeGenerationParams(const QMap -> OpenaiGenerationParams * { return new OpenaiGenerationParams(values); } -OpenaiProviderBuiltin::OpenaiProviderBuiltin(QUuid id, QString name, QUrl baseUrl, QString apiKey) +OpenaiProviderBuiltin::OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl baseUrl) : ModelProvider(std::move(id), std::move(name), std::move(baseUrl)) - , OpenaiProvider(std::move(apiKey)) - {} - -/// load -OpenaiProviderCustom::OpenaiProviderCustom(std::shared_ptr store, QUuid id) - : ModelProvider(std::move(id)) - , ModelProviderCustom(std::move(store)) + , ModelProviderMutable(store) { - auto &details = load(); - m_apiKey = std::get(details).api_key; + auto res = m_store->acquire(m_id); + if (!res) + res.error().raise(); + if (auto maybeData = *res) { + auto &details = (*maybeData)->openai_details.value(); + m_apiKey = details.api_key; + } } +auto OpenaiProviderBuiltin::asData() -> ModelProviderData +{ + return { + .id = m_id, + .type = ProviderType::openai, + .custom_details = {}, + .openai_details = OpenaiProviderDetails { m_apiKey }, + }; +} + +/// load +OpenaiProviderCustom::OpenaiProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl, QString apiKey) + : ModelProvider(std::move(id), std::move(name), std::move(baseUrl)) + , OpenaiProvider(std::move(apiKey)) + , ModelProviderCustom(store) + {} + /// create -OpenaiProviderCustom::OpenaiProviderCustom(std::shared_ptr store, QString name, QUrl baseUrl, - QString apiKey) +OpenaiProviderCustom::OpenaiProviderCustom(ProviderStore *store, QString name, QUrl baseUrl, QString apiKey) : ModelProvider(std::move(name), std::move(baseUrl)) , ModelProviderCustom(std::move(store)) , OpenaiProvider(std::move(apiKey)) @@ -123,8 +138,19 @@ OpenaiProviderCustom::OpenaiProviderCustom(std::shared_ptr store, m_id = (*data)->id; } -OpenaiModelDescription::OpenaiModelDescription(std::shared_ptr provider, QString modelName) - : m_provider(std::move(provider)) +auto OpenaiProviderCustom::asData() -> ModelProviderData +{ + return { + .id = m_id, + .type = ProviderType::openai, + .custom_details = CustomProviderDetails { m_name, m_baseUrl }, + .openai_details = OpenaiProviderDetails { m_apiKey }, + }; +} + +OpenaiModelDescription::OpenaiModelDescription(protected_t, std::shared_ptr provider, + QString modelName) + : m_provider (std::move(provider )) , m_modelName(std::move(modelName)) {} diff --git a/gpt4all-chat/src/llmodel_openai.h b/gpt4all-chat/src/llmodel_openai.h index 28c2cee9..a5a9fa6e 100644 --- a/gpt4all-chat/src/llmodel_openai.h +++ b/gpt4all-chat/src/llmodel_openai.h @@ -42,12 +42,13 @@ protected: class OpenaiProvider : public QObject, public virtual ModelProvider { Q_OBJECT + Q_PROPERTY(QString apiKey READ apiKey WRITE setApiKey NOTIFY apiKeyChanged) protected: - explicit OpenaiProvider() = default; // custom - explicit OpenaiProvider(QString apiKey) // built-in - : m_apiKey(std::move(apiKey)) - {} + explicit OpenaiProvider() = default; + explicit OpenaiProvider(QString apiKey) + : m_apiKey(std::move(apiKey)) {} + public: ~OpenaiProvider() noexcept override = 0; @@ -56,6 +57,8 @@ public: [[nodiscard]] const QString &apiKey() const { return m_apiKey; } + virtual void setApiKey(QString value) = 0; + auto supportedGenerationParams() const -> QSet override; auto makeGenerationParams(const QMap &values) const -> OpenaiGenerationParams * override; @@ -63,28 +66,32 @@ protected: QString m_apiKey; }; -class OpenaiProviderBuiltin : public OpenaiProvider, public ModelProviderBuiltin { +class OpenaiProviderBuiltin : public OpenaiProvider, private ModelProviderMutable { Q_GADGET - Q_PROPERTY(QString apiKey READ apiKey CONSTANT) + Q_PROPERTY(QString name READ name CONSTANT) + Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT) public: - /// Create a new built-in OpenAI provider (transient). - explicit OpenaiProviderBuiltin(QUuid id, QString name, QUrl baseUrl, QString apiKey); + /// Create a new built-in OpenAI provider, loading its API key from disk if known. + explicit OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl baseUrl); + + void setApiKey(QString value) override { setMemberProp(&OpenaiProviderBuiltin::m_apiKey, "apiKey", std::move(value)); } + +protected: + auto asData() -> ModelProviderData override; }; class OpenaiProviderCustom final : public OpenaiProvider, public ModelProviderCustom { Q_OBJECT - Q_PROPERTY(QString apiKey READ apiKey WRITE setApiKey NOTIFY apiKeyChanged) - public: /// Load an existing OpenaiProvider from disk. - explicit OpenaiProviderCustom(std::shared_ptr store, QUuid id); + explicit OpenaiProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl, QString apiKey); /// Create a new OpenaiProvider on disk. - explicit OpenaiProviderCustom(std::shared_ptr store, QString name, QUrl baseUrl, QString apiKey); + explicit OpenaiProviderCustom(ProviderStore *store, QString name, QUrl baseUrl, QString apiKey); - void setApiKey(QString value) { setMemberProp(&OpenaiProviderCustom::m_apiKey, "apiKey", std::move(value)); } + void setApiKey(QString value) override { setMemberProp(&OpenaiProviderCustom::m_apiKey, "apiKey", std::move(value)); } Q_SIGNALS: void nameChanged (const QString &value); @@ -92,8 +99,7 @@ Q_SIGNALS: void apiKeyChanged (const QString &value); protected: - auto asData() -> ModelProviderData override - { return { m_id, ProviderType::openai, m_name, m_baseUrl, OpenaiProviderDetails { m_apiKey } }; } + auto asData() -> ModelProviderData override; }; class OpenaiModelDescription : public ModelDescription { @@ -101,7 +107,11 @@ class OpenaiModelDescription : public ModelDescription { Q_PROPERTY(QString modelName READ modelName CONSTANT) public: - explicit OpenaiModelDescription(std::shared_ptr provider, QString modelName); + explicit OpenaiModelDescription(protected_t, std::shared_ptr provider, QString modelName); + + static auto create(std::shared_ptr provider, QByteArray modelHash) + -> std::shared_ptr + { return std::make_shared(protected_t(), std::move(provider), std::move(modelHash)); } // getters [[nodiscard]] auto provider () const -> const OpenaiProvider * override { return m_provider.get(); } diff --git a/gpt4all-chat/src/llmodel_provider.cpp b/gpt4all-chat/src/llmodel_provider.cpp index ae88c3a5..fbf847ef 100644 --- a/gpt4all-chat/src/llmodel_provider.cpp +++ b/gpt4all-chat/src/llmodel_provider.cpp @@ -1,5 +1,8 @@ #include "llmodel_provider.h" +#include "llmodel_ollama.h" +#include "llmodel_openai.h" + #include "mysettings.h" #include @@ -42,67 +45,97 @@ QVariant GenerationParams::tryParseValue(QMap &values ModelProvider::~ModelProvider() noexcept = default; -ModelProviderCustom::~ModelProviderCustom() noexcept +ModelProviderMutable::~ModelProviderMutable() noexcept { if (auto res = m_store->release(m_id); !res) res.error().raise(); // should not happen - will terminate program } -auto ModelProviderCustom::load() -> const ModelProviderData::Details & -{ - auto data = m_store->acquire(m_id); - if (!data) - data.error().raise(); - m_name = (*data)->name; - m_baseUrl = (*data)->base_url; - return (*data)->details; -} - -ProviderRegistry::ProviderRegistry(fs::path path) - : m_store(std::move(path)) +ProviderRegistry::ProviderRegistry(PathSet paths) + : m_customStore (std::move(paths.custom )) + , m_builtinStore(std::move(paths.builtin)) { auto *mysettings = MySettings::globalInstance(); connect(mysettings, &MySettings::modelPathChanged, this, &ProviderRegistry::onModelPathChanged); + load(); } -Q_INVOKABLE void ProviderRegistry::registerBuiltinProvider(ModelProviderBuiltin *provider) +void ProviderRegistry::load() { - auto [_, unique] = m_providers.emplace(provider->id(), provider->asQObject()); - if (!unique) - qWarning() << "ignoring duplicate provider:" << provider->id(); + for (auto &p : s_builtinProviders) { // (not all builtin providers are stored) + auto provider = std::make_shared(m_builtinStore, p.id, p.name, p.base_url); + auto [_, unique] = m_providers.emplace(p.id, std::move(provider)); + if (!unique) + throw std::logic_error(fmt::format("duplicate builtin provider id: {}", p.id)); + } + for (auto &p : m_customStore.list()) { // disk is source of truth for custom providers + if (!p.custom_details) { + qWarning() << "ignoring builtin provider in custom store:" << p.id; + continue; + } + auto &cust = *p.custom_details; + std::shared_ptr provider; + switch (p.type) { + using enum ProviderType; + case ollama: + provider = std::make_shared( + &m_customStore, p.id, cust.name, cust.base_url + ); + case openai: + provider = std::make_shared( + &m_customStore, p.id, cust.name, cust.base_url, p.openai_details.value().api_key + ); + } + auto [_, unique] = m_providers.emplace(p.id, std::move(provider)); + if (!unique) + qWarning() << "ignoring duplicate custom provider with id:" << p.id; + } } [[nodiscard]] -bool ProviderRegistry::registerCustomProvider(std::unique_ptr provider) +bool ProviderRegistry::add(std::unique_ptr provider) { - auto [_, unique] = m_providers.emplace(provider->id(), provider->asQObject()); + auto [it, unique] = m_providers.emplace(provider->id(), std::move(provider)); if (unique) { - m_customProviders.push_back(std::move(provider)); + m_customProviders.push_back(std::make_unique(it->first)); emit customProviderAdded(m_customProviders.size() - 1); } return unique; } -fs::path ProviderRegistry::getSubdir() +auto ProviderRegistry::customProviderAt(size_t i) const -> const ModelProviderCustom * +{ + auto it = m_providers.find(*m_customProviders.at(i)); + Q_ASSERT(it != m_providers.end()); + return &dynamic_cast(*it->second); +} + +auto ProviderRegistry::getSubdirs() -> PathSet { auto *mysettings = MySettings::globalInstance(); - return toFSPath(mysettings->modelPath()) / "providers"; + auto parent = toFSPath(mysettings->modelPath()) / "providers"; + return { .builtin = parent, .custom = parent / "custom" }; } void ProviderRegistry::onModelPathChanged() { - auto path = getSubdir(); - if (path != m_store.path()) { + auto paths = getSubdirs(); + if (paths.builtin != m_builtinStore.path()) { emit aboutToBeCleared(); - m_customProviders.clear(); // delete custom providers to release store locks - if (auto res = m_store.setPath(path); !res) + // delete providers to release store locks + m_customProviders.clear(); + m_providers.clear(); + if (auto res = m_builtinStore.setPath(paths.builtin); !res) res.error().raise(); // should not happen + if (auto res = m_customStore.setPath(paths.custom); !res) + res.error().raise(); // should not happen + load(); } } CustomProviderList::CustomProviderList(QPointer registry) - : m_registry(std::move(registry)) - , m_size(m_registry->customProviderCount()) + : m_registry(std::move(registry) ) + , m_size (m_registry->customProviderCount()) { connect(m_registry, &ProviderRegistry::customProviderAdded, this, &CustomProviderList::onCustomProviderAdded); connect(m_registry, &ProviderRegistry::aboutToBeCleared, this, &CustomProviderList::onAboutToBeCleared, diff --git a/gpt4all-chat/src/llmodel_provider.h b/gpt4all-chat/src/llmodel_provider.h index a4aea566..6ff6b9ef 100644 --- a/gpt4all-chat/src/llmodel_provider.h +++ b/gpt4all-chat/src/llmodel_provider.h @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -63,9 +64,7 @@ class ModelProvider { Q_PROPERTY(QUuid id READ id CONSTANT) protected: - explicit ModelProvider(QUuid id) // load - : m_id(std::move(id)) {} - explicit ModelProvider(QUuid id, QString name, QUrl baseUrl) // create built-in + explicit ModelProvider(QUuid id, QString name, QUrl baseUrl) // create built-in or load : m_id(std::move(id)), m_name(std::move(name)), m_baseUrl(std::move(baseUrl)) {} explicit ModelProvider(QString name, QUrl baseUrl) // create custom : m_name(std::move(name)), m_baseUrl(std::move(baseUrl)) {} @@ -93,36 +92,39 @@ protected: QUrl m_baseUrl; }; -class ModelProviderBuiltin : public virtual ModelProvider { +// Mixin with no public interface providing basic load/save +class ModelProviderMutable : public virtual ModelProvider { Q_GADGET - Q_PROPERTY(QString name READ name CONSTANT) - Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT) -}; - -class ModelProviderCustom : public virtual ModelProvider { - Q_GADGET - Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged ) - Q_PROPERTY(QUrl baseUrl READ baseUrl WRITE setBaseUrl NOTIFY baseUrlChanged) protected: - explicit ModelProviderCustom(std::shared_ptr store) - : m_store(std::move(store)) {} + explicit ModelProviderMutable(ProviderStore *store) + : m_store(store) {} public: - ~ModelProviderCustom() noexcept override; - - // setters - void setName (QString value) { setMemberProp(&ModelProviderCustom::m_name, "name", std::move(value)); } - void setBaseUrl(QUrl value) { setMemberProp(&ModelProviderCustom::m_baseUrl, "baseUrl", std::move(value)); } + ~ModelProviderMutable() noexcept override; protected: - virtual auto load() -> const ModelProviderData::Details &; virtual auto asData() -> ModelProviderData = 0; template void setMemberProp(this S &self, T C::* member, std::string_view name, T value); - std::shared_ptr m_store; + ProviderStore *m_store; +}; + +class ModelProviderCustom : public ModelProviderMutable { + Q_GADGET + Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged ) + Q_PROPERTY(QUrl baseUrl READ baseUrl WRITE setBaseUrl NOTIFY baseUrlChanged) + +protected: + explicit ModelProviderCustom(ProviderStore *store) + : ModelProviderMutable(store) {} + +public: + // setters + void setName (QString value) { setMemberProp(&ModelProviderCustom::m_name, "name", std::move(value)); } + void setBaseUrl(QUrl value) { setMemberProp(&ModelProviderCustom::m_baseUrl, "baseUrl", std::move(value)); } }; class ProviderRegistry : public QObject { @@ -130,18 +132,26 @@ class ProviderRegistry : public QObject { QML_ELEMENT QML_SINGLETON +private: + struct PathSet { std::filesystem::path builtin, custom; }; + + struct BuiltinProviderData { + QUuid id; + QString name; + QUrl base_url; + }; + protected: - explicit ProviderRegistry(std::filesystem::path path); + explicit ProviderRegistry(PathSet paths); public: - static ProviderRegistry *create(QQmlEngine *, QJSEngine *) { return new ProviderRegistry(getSubdir()); } - Q_INVOKABLE void registerBuiltinProvider(ModelProviderBuiltin *provider); - [[nodiscard]] bool registerCustomProvider (std::unique_ptr provider); + static ProviderRegistry *create(QQmlEngine *, QJSEngine *) { return new ProviderRegistry(getSubdirs()); } + [[nodiscard]] bool add(std::unique_ptr provider); - size_t customProviderCount() const + // TODO(jared): implement a way to remove custom providers via the model + [[nodiscard]] size_t customProviderCount() const { return m_customProviders.size(); } - auto customProviderAt(size_t i) const -> const ModelProviderCustom * - { return m_customProviders.at(i).get(); } + [[nodiscard]] auto customProviderAt(size_t i) const -> const ModelProviderCustom *; auto operator[](const QUuid &id) -> ModelProviderCustom * { return &dynamic_cast(*m_providers.at(id)); } @@ -150,20 +160,24 @@ Q_SIGNALS: void aboutToBeCleared(); private: - static auto getSubdir() -> std::filesystem::path; + void load(); + static PathSet getSubdirs(); private Q_SLOTS: void onModelPathChanged(); private: - ProviderStore m_store; - std::unordered_map> m_providers; - std::vector> m_customProviders; + static constexpr size_t N_BUILTIN = 3; + static const std::array s_builtinProviders; + + ProviderStore m_customStore; + ProviderStore m_builtinStore; + std::unordered_map> m_providers; + std::vector> m_customProviders; }; class CustomProviderList : public QAbstractListModel { Q_OBJECT - QML_ELEMENT protected: explicit CustomProviderList(QPointer registry); diff --git a/gpt4all-chat/src/llmodel_provider.inl b/gpt4all-chat/src/llmodel_provider.inl index a707f6b2..94a078c3 100644 --- a/gpt4all-chat/src/llmodel_provider.inl +++ b/gpt4all-chat/src/llmodel_provider.inl @@ -13,9 +13,9 @@ void GenerationParams::tryParseValue(this S &self, QMap -void ModelProviderCustom::setMemberProp(this S &self, T C::* member, std::string_view name, T value) +void ModelProviderMutable::setMemberProp(this S &self, T C::* member, std::string_view name, T value) { - auto &mpc = static_cast(self); + auto &mpc = static_cast(self); auto &cur = self.*member; if (cur != value) { cur = std::move(value); diff --git a/gpt4all-chat/src/llmodel_provider_builtins.cpp b/gpt4all-chat/src/llmodel_provider_builtins.cpp new file mode 100644 index 00000000..77367ee1 --- /dev/null +++ b/gpt4all-chat/src/llmodel_provider_builtins.cpp @@ -0,0 +1,34 @@ +#include "llmodel_provider.h" + +using namespace Qt::StringLiterals; + + +namespace gpt4all::ui { + + +// TODO: use these in the constructor of ProviderRegistry +// TODO: we have to be careful to reserve these names for ProviderStore purposes, so the user can't write JSON files that alias them. +// this *is a problem*, because we want to be able to safely introduce these. +// so we need a different namespace, i.e. a *different directory*. +const std::array< + ProviderRegistry::BuiltinProviderData, ProviderRegistry::N_BUILTIN +> ProviderRegistry::s_builtinProviders { + BuiltinProviderData { + .id = QUuid("20f963dc-1f99-441e-ad80-f30a0a06bcac"), + .name = u"Groq"_s, + .base_url = u"https://api.groq.com/openai/v1/"_s, + }, + BuiltinProviderData { + .id = QUuid("6f874c3a-f1ad-47f7-9129-755c5477146c"), + .name = u"OpenAI"_s, + .base_url = u"https://api.openai.com/v1/"_s, + }, + BuiltinProviderData { + .id = QUuid("7ae617b3-c0b2-4d2c-9ff2-bc3f049494cc"), + .name = u"Mistral"_s, + .base_url = u"https://api.mistral.ai/v1/"_s, + }, +}; + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/store_base.cpp b/gpt4all-chat/src/store_base.cpp index f3d49eb6..c084e605 100644 --- a/gpt4all-chat/src/store_base.cpp +++ b/gpt4all-chat/src/store_base.cpp @@ -73,7 +73,7 @@ auto DataStoreBase::reload() -> DataStoreResult<> if (!jv) { (qWarning().nospace() << "skipping " << file.fileName() << "because of read error: ").noquote() << jv.error().errorString(); - } else if (auto [unique, uuid] = insert(*jv); !unique) + } else if (auto [unique, uuid] = cacheInsert(*jv); !unique) qWarning() << "skipping duplicate data store entry:" << uuid; file.close(); } @@ -103,7 +103,7 @@ auto DataStoreBase::openNew(const QString &name) -> DataStoreResult DataStoreResult> +auto DataStoreBase::openExisting(const QString &name, bool allowCreate) -> DataStoreResult> { auto path = getFilePath(name); if (!QFile::exists(path)) @@ -111,7 +111,10 @@ auto DataStoreBase::openExisting(const QString &name) -> DataStoreResult(toQString(path)); - if (!file->open(QSaveFile::WriteOnly | QSaveFile::ExistingOnly)) + QFile::OpenMode flags = QSaveFile::WriteOnly; + if (!allowCreate) + flags |= QSaveFile::ExistingOnly; + if (!file->open(flags)) return std::unexpected(&*file); return file; } diff --git a/gpt4all-chat/src/store_base.h b/gpt4all-chat/src/store_base.h index 79ba2300..720fe797 100644 --- a/gpt4all-chat/src/store_base.h +++ b/gpt4all-chat/src/store_base.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -65,13 +66,13 @@ public: protected: auto reload() -> DataStoreResult<>; virtual auto clear() -> DataStoreResult<> = 0; - struct InsertResult { bool unique; QUuid uuid; }; - virtual InsertResult insert(const boost::json::value &jv) = 0; + struct CacheInsertResult { bool unique; QUuid uuid; }; + virtual CacheInsertResult cacheInsert(const boost::json::value &jv) = 0; // helpers auto getFilePath(const QString &name) -> std::filesystem::path; auto openNew(const QString &name) -> DataStoreResult>; - auto openExisting(const QString &name) -> DataStoreResult>; + auto openExisting(const QString &name, bool allowCreate = false) -> DataStoreResult>; static auto read(QFileDevice &file, boost::json::stream_parser &parser) -> DataStoreResult; auto write(const boost::json::value &value, QFileDevice &file) -> DataStoreResult<>; @@ -94,23 +95,26 @@ public: auto list() -> tl::generator; auto setData(T data) -> DataStoreResult<>; + auto createOrSetData(T data, const QString &name) -> DataStoreResult<>; auto remove(const QUuid &id) -> DataStoreResult<>; - auto acquire(QUuid id) -> DataStoreResult; + auto acquire(QUuid id) -> DataStoreResult>; auto release(const QUuid &id) -> DataStoreResult<>; - [[nodiscard]] - auto operator[](const QUuid &id) const -> const T & + [[nodiscard]] auto operator[](const QUuid &id) const -> const T & { return m_entries.at(id); } + [[nodiscard]] auto find(const QUuid &id) const -> std::optional + { auto it = m_entries.find(id); return it == m_entries.end() ? std::nullopt : std::optional(&*it); } protected: auto createImpl(T data, const QString &name) -> DataStoreResult; auto clear() -> DataStoreResult<> final; - InsertResult insert(const boost::json::value &jv) override; + CacheInsertResult cacheInsert(const boost::json::value &jv) override; private: - std::unordered_map m_entries; - std::unordered_set m_acquired; + std::unordered_map m_entries; + std::unordered_set m_acquired; + std::unordered_map m_names; }; diff --git a/gpt4all-chat/src/store_base.inl b/gpt4all-chat/src/store_base.inl index 840afd0d..641ad905 100644 --- a/gpt4all-chat/src/store_base.inl +++ b/gpt4all-chat/src/store_base.inl @@ -51,8 +51,12 @@ auto DataStore::createImpl(T data, const QString &name) -> DataStoreResult auto DataStore::setData(T data) -> DataStoreResult<> { + auto name_it = m_names.find(data.id); + if (name_it == m_names.end()) + return std::unexpected(QStringLiteral("id not found: %1").arg(data.id.toString())); + // acquire path - auto file = openExisting(data.name); + auto file = openExisting(name_it->second); if (!file) return std::unexpected(file.error()); @@ -67,6 +71,32 @@ auto DataStore::setData(T data) -> DataStoreResult<> return {}; } +template +auto DataStore::createOrSetData(T data, const QString &name) -> DataStoreResult<> +{ + auto name_it = m_names.find(data.id); + if (name_it != m_names.end() && name_it->second != name) + return std::unexpected(QStringLiteral("name conflict for id %1: old=%2, new=%3") + .arg(data.id.toString(), name_it->second, name)); + + // acquire path + auto file = openExisting(name, /*allowCreate*/ true); + if (!file) + return std::unexpected(file.error()); + + // serialize + if (auto res = write(boost::json::value_from(data), **file); !res) + return std::unexpected(res.error()); + if (!(*file)->commit()) + return std::unexpected(file->get()); + + // update + m_entries[data.id] = std::move(data); + if (name_it == m_names.end()) + m_names.emplace(data.id, name); + return {}; +} + template auto DataStore::remove(const QUuid &id) -> DataStoreResult<> { @@ -89,12 +119,12 @@ auto DataStore::remove(const QUuid &id) -> DataStoreResult<> } template -auto DataStore::acquire(QUuid id) -> DataStoreResult +auto DataStore::acquire(QUuid id) -> DataStoreResult> { auto [it, unique] = m_acquired.insert(std::move(id)); if (!unique) return std::unexpected(QStringLiteral("id already acquired: %1").arg(id.toString())); - return &(*this)[*it]; + return find(*it); } template @@ -115,7 +145,7 @@ auto DataStore::clear() -> DataStoreResult<> } template -auto DataStore::insert(const boost::json::value &jv) -> InsertResult +auto DataStore::cacheInsert(const boost::json::value &jv) -> CacheInsertResult { auto data = boost::json::value_to(jv); auto id = data.id; @@ -124,5 +154,4 @@ auto DataStore::insert(const boost::json::value &jv) -> InsertResult } - } // namespace gpt4all::ui diff --git a/gpt4all-chat/src/store_provider.cpp b/gpt4all-chat/src/store_provider.cpp index 6c42a437..30a26aed 100644 --- a/gpt4all-chat/src/store_provider.cpp +++ b/gpt4all-chat/src/store_provider.cpp @@ -1,23 +1,81 @@ #include "store_provider.h" +#include "json-helpers.h" // IWYU pragma: keep + +#include // IWYU pragma: keep + #include +namespace json = boost::json; + namespace gpt4all::ui { +void tag_invoke(const boost::json::value_from_tag &, boost::json::value &jv, ModelProviderData data) +{ + auto &obj = jv.emplace_object(); + obj = { { "id", data.id }, + { "builtin", !data.custom_details }, + { "type", data.type() } }; + if (auto custom = data.custom_details) { + obj.emplace("name", custom->name); + obj.emplace("base_url", custom->base_url); + } + switch (data.type()) { + using enum ProviderType; + case openai: + obj.emplace("api_key", std::get(data.provider_details).api_key); + case ollama: + ; + } +} + +auto tag_invoke(const boost::json::value_to_tag &, const boost::json::value &jv) + -> ModelProviderData +{ + auto &obj = jv.as_object(); + auto type = json::value_to(jv.at("type")); + std::optional custom_details; + if (!jv.at("builtin").as_bool()) + custom_details.emplace(CustomProviderDetails { + json::value_to(jv.at("name" )), + json::value_to(jv.at("base_url")), + }); + ModelProviderData::ProviderDetails provider_details; + switch (type) { + using enum ProviderType; + case openai: + provider_details = OpenaiProviderDetails { json::value_to(jv.at("api_key")) }; + case ollama: + ; + } + return { + .id = json::value_to(obj.at("id")), + .custom_details = std::move(custom_details), + .provider_details = std::move(provider_details) + }; +} + auto ProviderStore::create(QString name, QUrl base_url, QString api_key) -> DataStoreResult { - ModelProviderData data { QUuid::createUuid(), ProviderType::openai, name, std::move(base_url), - OpenaiProviderDetails { std::move(api_key) } }; + ModelProviderData data { + .id = QUuid::createUuid(), + .custom_details = CustomProviderDetails { name, std::move(base_url) }, + .provider_details = OpenaiProviderDetails { std::move(api_key) }, + }; return createImpl(std::move(data), name); } auto ProviderStore::create(QString name, QUrl base_url) -> DataStoreResult { - ModelProviderData data { QUuid::createUuid(), ProviderType::ollama, name, std::move(base_url), {} }; + ModelProviderData data { + .id = QUuid::createUuid(), + .custom_details = CustomProviderDetails { name, std::move(base_url) }, + .provider_details = {}, + }; return createImpl(std::move(data), name); } diff --git a/gpt4all-chat/src/store_provider.h b/gpt4all-chat/src/store_provider.h index 02666a10..078a3ec4 100644 --- a/gpt4all-chat/src/store_provider.h +++ b/gpt4all-chat/src/store_provider.h @@ -4,33 +4,43 @@ #include #include +#include // IWYU pragma: keep #include #include #include -#include - namespace gpt4all::ui { -BOOST_DEFINE_ENUM_CLASS(ProviderType, openai, ollama) +// indices of this enum should be consistent with indices of ProviderDetails +enum class ProviderType { + openai = 0, + ollama = 1, +}; +BOOST_DESCRIBE_ENUM(ProviderType, openai, ollama) + +struct CustomProviderDetails { + QString name; + QUrl base_url; +}; struct OpenaiProviderDetails { QString api_key; }; -BOOST_DESCRIBE_STRUCT(OpenaiProviderDetails, (), (api_key)) struct ModelProviderData { - using Details = std::variant; - QUuid id; - ProviderType type; - QString name; - QUrl base_url; - Details details; + using ProviderDetails = std::variant; + QUuid id; + std::optional custom_details; + ProviderDetails provider_details; + + ProviderType type() const { return ProviderType(provider_details.index()); } }; -BOOST_DESCRIBE_STRUCT(ModelProviderData, (), (id, type, name, base_url, details)) +void tag_invoke(const boost::json::value_from_tag &, boost::json::value &jv, ModelProviderData data); +auto tag_invoke(const boost::json::value_to_tag &, const boost::json::value &jv) + -> ModelProviderData; class ProviderStore : public DataStore { private: