diff --git a/gpt4all-backend-test/src/main.cpp b/gpt4all-backend-test/src/main.cpp
index 559e26af..6cd2cefc 100644
--- a/gpt4all-backend-test/src/main.cpp
+++ b/gpt4all-backend-test/src/main.cpp
@@ -21,13 +21,25 @@ static void run()
 {
     fmt::print("Connecting to server at {}\n", OLLAMA_URL);
     OllamaClient provider(OLLAMA_URL);
-    auto version = QCoro::waitFor(provider.getVersion());
-    if (version) {
-        fmt::print("Server version: {}\n", version->version);
+
+    auto versionResp = QCoro::waitFor(provider.getVersion());
+    if (versionResp) {
+        fmt::print("Server version: {}\n", versionResp->version);
     } else {
-        fmt::print("Error retrieving version: {}\n", version.error().errorString);
+        fmt::print("Error retrieving version: {}\n", versionResp.error().errorString);
         return QCoreApplication::exit(1);
     }
+
+    auto modelsResponse = QCoro::waitFor(provider.listModels());
+    if (modelsResponse) {
+        fmt::print("Available models:\n");
+        for (const auto & model : modelsResponse->models)
+            fmt::print("{}\n", model.model);
+    } else {
+        fmt::print("Error retrieving models: {}\n", modelsResponse.error().errorString);
+        return QCoreApplication::exit(1);
+    }
+
     QCoreApplication::exit(0);
 }
 
diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt
index fcc46726..e38a8a51 100644
--- a/gpt4all-backend/CMakeLists.txt
+++ b/gpt4all-backend/CMakeLists.txt
@@ -14,4 +14,5 @@ target_sources(gpt4all-backend PUBLIC
     FILE_SET public_headers TYPE HEADERS BASE_DIRS include
     FILES include/gpt4all-backend/formatters.h
           include/gpt4all-backend/ollama_client.h
+          include/gpt4all-backend/ollama_responses.h
 )
diff --git a/gpt4all-backend/include/gpt4all-backend/ollama_client.h b/gpt4all-backend/include/gpt4all-backend/ollama_client.h
index f1773b23..6982cfba 100644
--- a/gpt4all-backend/include/gpt4all-backend/ollama_client.h
+++ b/gpt4all-backend/include/gpt4all-backend/ollama_client.h
@@ -1,7 +1,8 @@
 #pragma once
 
+#include "ollama_responses.h"
+
 #include // IWYU pragma: keep
-#include
 
 #include
 #include
@@ -13,6 +14,8 @@
 #include
 #include
 
+namespace boost::json { class value; }
+
 
 namespace gpt4all::backend {
 
@@ -45,9 +48,6 @@ public:
 template <typename T>
 using DataOrRespErr = std::expected<T, ResponseError>;
 
-struct VersionResponse { QString version; };
-BOOST_DESCRIBE_STRUCT(VersionResponse, (), (version))
-
 class OllamaClient {
 public:
     OllamaClient(QUrl baseUrl)
@@ -57,12 +57,26 @@ public:
     const QUrl &baseUrl() const { return m_baseUrl; }
     void getBaseUrl(QUrl value) { m_baseUrl = std::move(value); }
 
-    /// Retrieve the Ollama version, e.g. "0.5.1"
-    auto getVersion() -> QCoro::Task<DataOrRespErr<VersionResponse>>;
+    /// Returns the version of the Ollama server.
+    auto getVersion() -> QCoro::Task<DataOrRespErr<ollama::VersionResponse>>
+    { return getSimple<ollama::VersionResponse>(QStringLiteral("version")); }
+
+    /// List models that are available locally.
+    auto listModels() -> QCoro::Task<DataOrRespErr<ollama::ModelsResponse>>
+    { return getSimple<ollama::ModelsResponse>(QStringLiteral("tags")); }
+
+private:
+    template <typename Resp>
+    auto getSimple(const QString &endpoint) -> QCoro::Task<DataOrRespErr<Resp>>;
+
+    auto getSimpleGeneric(const QString &endpoint) -> QCoro::Task<DataOrRespErr<boost::json::value>>;
 
 private:
     QUrl m_baseUrl;
     QNetworkAccessManager m_nam;
 };
 
+extern template auto OllamaClient::getSimple(const QString &) -> QCoro::Task<DataOrRespErr<ollama::VersionResponse>>;
+extern template auto OllamaClient::getSimple(const QString &) -> QCoro::Task<DataOrRespErr<ollama::ModelsResponse>>;
+
 } // namespace gpt4all::backend
diff --git a/gpt4all-backend/include/gpt4all-backend/ollama_responses.h b/gpt4all-backend/include/gpt4all-backend/ollama_responses.h
new file mode 100644
index 00000000..69e75207
--- /dev/null
+++ b/gpt4all-backend/include/gpt4all-backend/ollama_responses.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include <boost/describe/class.hpp>
+
+#include <QString>
+#include <QtTypes>
+
+#include <vector>
+
+
+namespace gpt4all::backend::ollama {
+
+/// Details about a model.
+struct ModelDetails {
+    QString parent_model; /// The parent of the model.
+    QString format; /// The format of the model.
+    QString family; /// The family of the model.
+    std::vector<QString> families; /// The families of the model.
+    QString parameter_size; /// The size of the model's parameters.
+    QString quantization_level; /// The quantization level of the model.
+};
+BOOST_DESCRIBE_STRUCT(ModelDetails, (), (parent_model, format, family, families, parameter_size, quantization_level))
+
+/// A model available locally.
+struct Model {
+    QString model; /// The model name.
+    QString modified_at; /// Model modification date.
+    quint64 size; /// Size of the model on disk.
+    QString digest; /// The model's digest.
+    ModelDetails details; /// The model's details.
+};
+BOOST_DESCRIBE_STRUCT(Model, (), (model, modified_at, size, digest, details))
+
+
+/// The response class for the version endpoint.
+struct VersionResponse {
+    QString version; /// The version of the Ollama server.
+};
+BOOST_DESCRIBE_STRUCT(VersionResponse, (), (version))
+
+/// Response class for the list models endpoint.
+struct ModelsResponse {
+    std::vector<Model> models; /// List of models available locally.
+};
+BOOST_DESCRIBE_STRUCT(ModelsResponse, (), (models))
+
+} // namespace gpt4all::backend::ollama
diff --git a/gpt4all-backend/src/ollama_client.cpp b/gpt4all-backend/src/ollama_client.cpp
index f33d0cc4..66e91624 100644
--- a/gpt4all-backend/src/ollama_client.cpp
+++ b/gpt4all-backend/src/ollama_client.cpp
@@ -15,14 +15,29 @@
 #include
 
 using namespace Qt::Literals::StringLiterals;
+using namespace gpt4all::backend::ollama;
 namespace json = boost::json;
 
 
 namespace gpt4all::backend {
 
-auto OllamaClient::getVersion() -> QCoro::Task<DataOrRespErr<VersionResponse>>
+template <typename Resp>
+auto OllamaClient::getSimple(const QString &endpoint) -> QCoro::Task<DataOrRespErr<Resp>>
 {
-    std::unique_ptr<QNetworkReply> reply(m_nam.get(QNetworkRequest(m_baseUrl.resolved(u"/api/version"_s))));
+    auto value = co_await getSimpleGeneric(endpoint);
+    if (value)
+        co_return boost::json::value_to<Resp>(*value);
+    co_return std::unexpected(value.error());
+}
+
+template auto OllamaClient::getSimple(const QString &) -> QCoro::Task<DataOrRespErr<VersionResponse>>;
+template auto OllamaClient::getSimple(const QString &) -> QCoro::Task<DataOrRespErr<ModelsResponse>>;
+
+auto OllamaClient::getSimpleGeneric(const QString &endpoint) -> QCoro::Task<DataOrRespErr<boost::json::value>>
+{
+    std::unique_ptr<QNetworkReply> reply(m_nam.get(
+        QNetworkRequest(m_baseUrl.resolved(u"/api/%1"_s.arg(endpoint)))
+    ));
 
     if (reply->error())
         co_return std::unexpected(reply.get());
@@ -36,7 +51,7 @@ auto OllamaClient::getVersion() -> QCoro::Task<DataOrRespErr<VersionResponse>>
         p.write(chunk.data(), chunk.size());
     } while (!reply->atEnd());
 
-    co_return json::value_to<VersionResponse>(p.release());
+    co_return p.release();
 } catch (const std::exception &e) {
     co_return std::unexpected(ResponseError(e, std::current_exception()));
 }
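
Note on the deserialization path this patch relies on: any struct declared with BOOST_DESCRIBE_STRUCT can be populated directly by boost::json::value_to<T>(), which is how getSimple() turns the parsed reply into a response object. The sketch below is not part of the patch and only illustrates that mechanism on a /api/tags-shaped payload; it uses std::string in place of QString (mapping onto QString needs a tag_invoke conversion the sketch omits), and the model entries are made-up examples.

// Standalone illustration (not part of the patch) of Boost.JSON + Boost.Describe
// deserialization. Build by linking Boost.JSON (or including <boost/json/src.hpp>
// in one translation unit).
#include <boost/describe/class.hpp>
#include <boost/json.hpp>

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

namespace sketch {

// Simplified stand-ins for the ollama_responses.h structs, using std::string
// instead of QString so the example needs no Qt conversion helpers.
struct Model {
    std::string model;
    std::uint64_t size;
};
BOOST_DESCRIBE_STRUCT(Model, (), (model, size))

struct ModelsResponse {
    std::vector<Model> models;
};
BOOST_DESCRIBE_STRUCT(ModelsResponse, (), (models))

} // namespace sketch

int main()
{
    // Illustrative payload in the shape returned by GET /api/tags.
    auto parsed = boost::json::parse(R"({
        "models": [
            { "model": "llama3.2:latest", "size": 2019393189 },
            { "model": "qwen2.5:7b",      "size": 4683087332 }
        ]
    })");

    // One call maps the parsed document onto the described struct hierarchy,
    // mirroring what getSimple() does with the server reply.
    auto resp = boost::json::value_to<sketch::ModelsResponse>(parsed);
    for (const auto &m : resp.models)
        std::cout << m.model << " (" << m.size << " bytes)\n";
    return 0;
}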