server: block server access via non-local domains

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2025-05-27 15:31:08 -04:00
parent b666d16db5
commit faec111ce7
6 changed files with 154 additions and 13 deletions

View File

@ -244,6 +244,7 @@ qt_add_executable(chat
src/localdocsmodel.cpp src/localdocsmodel.h
src/logger.cpp src/logger.h
src/modellist.cpp src/modellist.h
src/mwhttpserver.cpp src/mwhttpserver.h
src/mysettings.cpp src/mysettings.h
src/network.cpp src/network.h
src/server.cpp src/server.h

View File

@ -0,0 +1,15 @@
#include <QTcpServer>
#include "mwhttpserver.h"
namespace gpt4all::ui {
MwHttpServer::MwHttpServer()
: m_httpServer()
, m_tcpServer (new QTcpServer(&m_httpServer))
{}
} // namespace gpt4all::ui

View File

@ -0,0 +1,60 @@
#pragma once
#include <QHttpServer>
#include <QHttpServerRequest>
#include <functional>
#include <optional>
#include <utility>
#include <vector>
class QHttpServerResponse;
class QHttpServerRouterRule;
class QString;
namespace gpt4all::ui {
/// @brief QHttpServer wrapper with middleware support.
///
/// This class wraps QHttpServer and provides addBeforeRequestHandler() to add middleware.
class MwHttpServer
{
using BeforeRequestHandler = std::function<std::optional<QHttpServerResponse>(const QHttpServerRequest &)>;
public:
explicit MwHttpServer();
bool bind() { return m_httpServer.bind(m_tcpServer); }
void addBeforeRequestHandler(BeforeRequestHandler handler)
{ m_beforeRequestHandlers.push_back(std::move(handler)); }
template <typename Handler>
void addAfterRequestHandler(
const typename QtPrivate::ContextTypeForFunctor<Handler>::ContextType *context, Handler &&handler
) {
return m_httpServer.addAfterRequestHandler(context, std::forward<Handler>(handler));
}
template <typename... Args>
QHttpServerRouterRule *route(
const QString &pathPattern,
QHttpServerRequest::Methods method,
std::function<QHttpServerResponse(Args..., const QHttpServerRequest &)> viewHandler
);
QTcpServer *tcpServer() { return m_tcpServer; }
private:
QHttpServer m_httpServer;
QTcpServer *m_tcpServer;
std::vector<BeforeRequestHandler> m_beforeRequestHandlers;
};
} // namespace gpt4all::ui
#include "mwhttpserver.inl" // IWYU pragma: export

View File

@ -0,0 +1,20 @@
namespace gpt4all::ui {
template <typename... Args>
QHttpServerRouterRule *MwHttpServer::route(
const QString &pathPattern,
QHttpServerRequest::Methods method,
std::function<QHttpServerResponse(Args..., const QHttpServerRequest &)> viewHandler
) {
auto wrapped = [this, vh = std::move(viewHandler)](Args ...args, const QHttpServerRequest &req) {
for (auto &handler : m_beforeRequestHandlers)
if (auto resp = handler(req))
return *std::move(resp);
return vh(std::forward<Args>(args)..., req);
};
return m_httpServer.route(pathPattern, method, std::move(wrapped));
}
} // namespace gpt4all::ui

View File

@ -3,12 +3,14 @@
#include "chat.h"
#include "chatmodel.h"
#include "modellist.h"
#include "mwhttpserver.h"
#include "mysettings.h"
#include "utils.h" // IWYU pragma: keep
#include <fmt/format.h>
#include <gpt4all-backend/llmodel.h>
#include <QAbstractSocket>
#include <QByteArray>
#include <QCborArray>
#include <QCborMap>
@ -51,6 +53,7 @@
using namespace std::string_literals;
using namespace Qt::Literals::StringLiterals;
using namespace gpt4all::ui;
//#define DEBUG
@ -443,6 +446,8 @@ Server::Server(Chat *chat)
connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
}
Server::~Server() = default;
static QJsonObject requestFromJson(const QByteArray &request)
{
QJsonParseError err;
@ -455,17 +460,57 @@ static QJsonObject requestFromJson(const QByteArray &request)
return document.object();
}
/// @brief Check if a host is safe to use to connect to the server.
///
/// GPT4All's local server is not safe to expose to the internet, as it does not provide
/// any form of authentication. DNS rebind attacks bypass CORS and without additional host
/// header validation, malicious websites can access the server in client-side js.
///
/// @param host The value of the "Host" header or ":authority" pseudo-header
/// @return true if the host is unsafe, false otherwise
static bool isHostUnsafe(const QString &host)
{
QHostAddress addr;
if (addr.setAddress(host) && addr.protocol() == QAbstractSocket::IPv4Protocol)
return false; // ipv4
// ipv6 host is wrapped in square brackets
static const QRegularExpression ipv6Re(uR"(^\[(.+)\]$)"_s);
if (auto match = ipv6Re.match(host); match.hasMatch()) {
auto ipv6 = match.captured(1);
if (addr.setAddress(ipv6) && addr.protocol() == QAbstractSocket::IPv6Protocol)
return false; // ipv6
}
if (!host.contains('.'))
return false; // dotless hostname
static const QStringList allowedTlds { u".local"_s, u".test"_s, u".internal"_s };
for (auto &tld : allowedTlds)
if (host.endsWith(tld, Qt::CaseInsensitive))
return false; // local TLD
return true; // unsafe
}
void Server::start()
{
m_server = std::make_unique<QHttpServer>(this);
auto *tcpServer = new QTcpServer(m_server.get());
m_server = std::make_unique<MwHttpServer>();
m_server->addBeforeRequestHandler([](const QHttpServerRequest &req) -> std::optional<QHttpServerResponse> {
// this works for HTTP/1.1 "Host" header and HTTP/2 ":authority" pseudo-header
auto host = req.url().host();
if (!host.isEmpty() && isHostUnsafe(host))
return QHttpServerResponse(QHttpServerResponder::StatusCode::Forbidden);
return std::nullopt;
});
auto port = MySettings::globalInstance()->networkPort();
if (!tcpServer->listen(QHostAddress::LocalHost, port)) {
if (!m_server->tcpServer()->listen(QHostAddress::LocalHost, port)) {
qWarning() << "Server ERROR: Failed to listen on port" << port;
return;
}
if (!m_server->bind(tcpServer)) {
if (!m_server->bind()) {
qWarning() << "Server ERROR: Failed to HTTP server to socket" << port;
return;
}
@ -490,7 +535,7 @@ void Server::start()
}
);
m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Get,
m_server->route<const QString &>("/v1/models/<arg>", QHttpServerRequest::Method::Get,
[](const QString &model, const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
@ -562,7 +607,7 @@ void Server::start()
// Respond with code 405 to wrong HTTP methods:
m_server->route("/v1/models", QHttpServerRequest::Method::Post,
[] {
[](const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(
@ -573,8 +618,8 @@ void Server::start()
}
);
m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Post,
[](const QString &model) {
m_server->route<const QString &>("/v1/models/<arg>", QHttpServerRequest::Method::Post,
[](const QString &model, const QHttpServerRequest &) {
(void)model;
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
@ -587,7 +632,7 @@ void Server::start()
);
m_server->route("/v1/completions", QHttpServerRequest::Method::Get,
[] {
[](const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(
@ -598,7 +643,7 @@ void Server::start()
);
m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Get,
[] {
[](const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(

View File

@ -4,7 +4,6 @@
#include "chatllm.h"
#include "database.h"
#include <QHttpServer>
#include <QHttpServerResponse>
#include <QJsonObject>
#include <QList>
@ -18,6 +17,7 @@
class Chat;
class ChatRequest;
class CompletionRequest;
namespace gpt4all::ui { class MwHttpServer; }
class Server : public ChatLLM
@ -26,7 +26,7 @@ class Server : public ChatLLM
public:
explicit Server(Chat *chat);
~Server() override = default;
~Server() override;
public Q_SLOTS:
void start();
@ -44,7 +44,7 @@ private Q_SLOTS:
private:
Chat *m_chat;
std::unique_ptr<QHttpServer> m_server;
std::unique_ptr<gpt4all::ui::MwHttpServer> m_server;
QList<ResultInfo> m_databaseResults;
QList<QString> m_collections;
};