diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index af958afd..f6740293 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -244,6 +244,7 @@ qt_add_executable(chat src/localdocsmodel.cpp src/localdocsmodel.h src/logger.cpp src/logger.h src/modellist.cpp src/modellist.h + src/mwhttpserver.cpp src/mwhttpserver.h src/mysettings.cpp src/mysettings.h src/network.cpp src/network.h src/server.cpp src/server.h diff --git a/gpt4all-chat/src/mwhttpserver.cpp b/gpt4all-chat/src/mwhttpserver.cpp new file mode 100644 index 00000000..920ab0d1 --- /dev/null +++ b/gpt4all-chat/src/mwhttpserver.cpp @@ -0,0 +1,15 @@ +#include + +#include "mwhttpserver.h" + + +namespace gpt4all::ui { + + +MwHttpServer::MwHttpServer() + : m_httpServer() + , m_tcpServer (new QTcpServer(&m_httpServer)) + {} + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/mwhttpserver.h b/gpt4all-chat/src/mwhttpserver.h new file mode 100644 index 00000000..9edbecdd --- /dev/null +++ b/gpt4all-chat/src/mwhttpserver.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +class QHttpServerResponse; +class QHttpServerRouterRule; +class QString; + + +namespace gpt4all::ui { + + +/// @brief QHttpServer wrapper with middleware support. +/// +/// This class wraps QHttpServer and provides addBeforeRequestHandler() to add middleware. +class MwHttpServer +{ + using BeforeRequestHandler = std::function(const QHttpServerRequest &)>; + +public: + explicit MwHttpServer(); + + bool bind() { return m_httpServer.bind(m_tcpServer); } + + void addBeforeRequestHandler(BeforeRequestHandler handler) + { m_beforeRequestHandlers.push_back(std::move(handler)); } + + template + void addAfterRequestHandler( + const typename QtPrivate::ContextTypeForFunctor::ContextType *context, Handler &&handler + ) { + return m_httpServer.addAfterRequestHandler(context, std::forward(handler)); + } + + template + QHttpServerRouterRule *route( + const QString &pathPattern, + QHttpServerRequest::Methods method, + std::function viewHandler + ); + + QTcpServer *tcpServer() { return m_tcpServer; } + +private: + QHttpServer m_httpServer; + QTcpServer *m_tcpServer; + std::vector m_beforeRequestHandlers; +}; + + +} // namespace gpt4all::ui + + +#include "mwhttpserver.inl" // IWYU pragma: export diff --git a/gpt4all-chat/src/mwhttpserver.inl b/gpt4all-chat/src/mwhttpserver.inl new file mode 100644 index 00000000..aff946c7 --- /dev/null +++ b/gpt4all-chat/src/mwhttpserver.inl @@ -0,0 +1,20 @@ +namespace gpt4all::ui { + + +template +QHttpServerRouterRule *MwHttpServer::route( + const QString &pathPattern, + QHttpServerRequest::Methods method, + std::function viewHandler +) { + auto wrapped = [this, vh = std::move(viewHandler)](Args ...args, const QHttpServerRequest &req) { + for (auto &handler : m_beforeRequestHandlers) + if (auto resp = handler(req)) + return *std::move(resp); + return vh(std::forward(args)..., req); + }; + return m_httpServer.route(pathPattern, method, std::move(wrapped)); +} + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/server.cpp b/gpt4all-chat/src/server.cpp index 5f04b6f7..db2c28f6 100644 --- a/gpt4all-chat/src/server.cpp +++ b/gpt4all-chat/src/server.cpp @@ -3,12 +3,14 @@ #include "chat.h" #include "chatmodel.h" #include "modellist.h" +#include "mwhttpserver.h" #include "mysettings.h" #include "utils.h" // IWYU pragma: keep #include #include +#include #include #include #include @@ -51,6 +53,7 @@ using namespace std::string_literals; using namespace Qt::Literals::StringLiterals; +using namespace gpt4all::ui; //#define DEBUG @@ -443,6 +446,8 @@ Server::Server(Chat *chat) connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection); } +Server::~Server() = default; + static QJsonObject requestFromJson(const QByteArray &request) { QJsonParseError err; @@ -455,17 +460,57 @@ static QJsonObject requestFromJson(const QByteArray &request) return document.object(); } +/// @brief Check if a host is safe to use to connect to the server. +/// +/// GPT4All's local server is not safe to expose to the internet, as it does not provide +/// any form of authentication. DNS rebind attacks bypass CORS and without additional host +/// header validation, malicious websites can access the server in client-side js. +/// +/// @param host The value of the "Host" header or ":authority" pseudo-header +/// @return true if the host is unsafe, false otherwise +static bool isHostUnsafe(const QString &host) +{ + QHostAddress addr; + if (addr.setAddress(host) && addr.protocol() == QAbstractSocket::IPv4Protocol) + return false; // ipv4 + + // ipv6 host is wrapped in square brackets + static const QRegularExpression ipv6Re(uR"(^\[(.+)\]$)"_s); + if (auto match = ipv6Re.match(host); match.hasMatch()) { + auto ipv6 = match.captured(1); + if (addr.setAddress(ipv6) && addr.protocol() == QAbstractSocket::IPv6Protocol) + return false; // ipv6 + } + + if (!host.contains('.')) + return false; // dotless hostname + + static const QStringList allowedTlds { u".local"_s, u".test"_s, u".internal"_s }; + for (auto &tld : allowedTlds) + if (host.endsWith(tld, Qt::CaseInsensitive)) + return false; // local TLD + + return true; // unsafe +} + void Server::start() { - m_server = std::make_unique(this); - auto *tcpServer = new QTcpServer(m_server.get()); + m_server = std::make_unique(); + + m_server->addBeforeRequestHandler([](const QHttpServerRequest &req) -> std::optional { + // this works for HTTP/1.1 "Host" header and HTTP/2 ":authority" pseudo-header + auto host = req.url().host(); + if (!host.isEmpty() && isHostUnsafe(host)) + return QHttpServerResponse(QHttpServerResponder::StatusCode::Forbidden); + return std::nullopt; + }); auto port = MySettings::globalInstance()->networkPort(); - if (!tcpServer->listen(QHostAddress::LocalHost, port)) { + if (!m_server->tcpServer()->listen(QHostAddress::LocalHost, port)) { qWarning() << "Server ERROR: Failed to listen on port" << port; return; } - if (!m_server->bind(tcpServer)) { + if (!m_server->bind()) { qWarning() << "Server ERROR: Failed to HTTP server to socket" << port; return; } @@ -490,7 +535,7 @@ void Server::start() } ); - m_server->route("/v1/models/", QHttpServerRequest::Method::Get, + m_server->route("/v1/models/", QHttpServerRequest::Method::Get, [](const QString &model, const QHttpServerRequest &) { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); @@ -562,7 +607,7 @@ void Server::start() // Respond with code 405 to wrong HTTP methods: m_server->route("/v1/models", QHttpServerRequest::Method::Post, - [] { + [](const QHttpServerRequest &) { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); return QHttpServerResponse( @@ -573,8 +618,8 @@ void Server::start() } ); - m_server->route("/v1/models/", QHttpServerRequest::Method::Post, - [](const QString &model) { + m_server->route("/v1/models/", QHttpServerRequest::Method::Post, + [](const QString &model, const QHttpServerRequest &) { (void)model; if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); @@ -587,7 +632,7 @@ void Server::start() ); m_server->route("/v1/completions", QHttpServerRequest::Method::Get, - [] { + [](const QHttpServerRequest &) { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); return QHttpServerResponse( @@ -598,7 +643,7 @@ void Server::start() ); m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Get, - [] { + [](const QHttpServerRequest &) { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); return QHttpServerResponse( diff --git a/gpt4all-chat/src/server.h b/gpt4all-chat/src/server.h index 465bf524..1aa6b186 100644 --- a/gpt4all-chat/src/server.h +++ b/gpt4all-chat/src/server.h @@ -4,7 +4,6 @@ #include "chatllm.h" #include "database.h" -#include #include #include #include @@ -18,6 +17,7 @@ class Chat; class ChatRequest; class CompletionRequest; +namespace gpt4all::ui { class MwHttpServer; } class Server : public ChatLLM @@ -26,7 +26,7 @@ class Server : public ChatLLM public: explicit Server(Chat *chat); - ~Server() override = default; + ~Server() override; public Q_SLOTS: void start(); @@ -44,7 +44,7 @@ private Q_SLOTS: private: Chat *m_chat; - std::unique_ptr m_server; + std::unique_ptr m_server; QList m_databaseResults; QList m_collections; };