mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-08-12 13:21:58 +00:00
server: block server access via non-local domains
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
parent
b666d16db5
commit
faec111ce7
@ -244,6 +244,7 @@ qt_add_executable(chat
|
|||||||
src/localdocsmodel.cpp src/localdocsmodel.h
|
src/localdocsmodel.cpp src/localdocsmodel.h
|
||||||
src/logger.cpp src/logger.h
|
src/logger.cpp src/logger.h
|
||||||
src/modellist.cpp src/modellist.h
|
src/modellist.cpp src/modellist.h
|
||||||
|
src/mwhttpserver.cpp src/mwhttpserver.h
|
||||||
src/mysettings.cpp src/mysettings.h
|
src/mysettings.cpp src/mysettings.h
|
||||||
src/network.cpp src/network.h
|
src/network.cpp src/network.h
|
||||||
src/server.cpp src/server.h
|
src/server.cpp src/server.h
|
||||||
|
15
gpt4all-chat/src/mwhttpserver.cpp
Normal file
15
gpt4all-chat/src/mwhttpserver.cpp
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#include <QTcpServer>
|
||||||
|
|
||||||
|
#include "mwhttpserver.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace gpt4all::ui {
|
||||||
|
|
||||||
|
|
||||||
|
MwHttpServer::MwHttpServer()
|
||||||
|
: m_httpServer()
|
||||||
|
, m_tcpServer (new QTcpServer(&m_httpServer))
|
||||||
|
{}
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace gpt4all::ui
|
60
gpt4all-chat/src/mwhttpserver.h
Normal file
60
gpt4all-chat/src/mwhttpserver.h
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QHttpServer>
|
||||||
|
#include <QHttpServerRequest>
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <optional>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
class QHttpServerResponse;
|
||||||
|
class QHttpServerRouterRule;
|
||||||
|
class QString;
|
||||||
|
|
||||||
|
|
||||||
|
namespace gpt4all::ui {
|
||||||
|
|
||||||
|
|
||||||
|
/// @brief QHttpServer wrapper with middleware support.
|
||||||
|
///
|
||||||
|
/// This class wraps QHttpServer and provides addBeforeRequestHandler() to add middleware.
|
||||||
|
class MwHttpServer
|
||||||
|
{
|
||||||
|
using BeforeRequestHandler = std::function<std::optional<QHttpServerResponse>(const QHttpServerRequest &)>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit MwHttpServer();
|
||||||
|
|
||||||
|
bool bind() { return m_httpServer.bind(m_tcpServer); }
|
||||||
|
|
||||||
|
void addBeforeRequestHandler(BeforeRequestHandler handler)
|
||||||
|
{ m_beforeRequestHandlers.push_back(std::move(handler)); }
|
||||||
|
|
||||||
|
template <typename Handler>
|
||||||
|
void addAfterRequestHandler(
|
||||||
|
const typename QtPrivate::ContextTypeForFunctor<Handler>::ContextType *context, Handler &&handler
|
||||||
|
) {
|
||||||
|
return m_httpServer.addAfterRequestHandler(context, std::forward<Handler>(handler));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
QHttpServerRouterRule *route(
|
||||||
|
const QString &pathPattern,
|
||||||
|
QHttpServerRequest::Methods method,
|
||||||
|
std::function<QHttpServerResponse(Args..., const QHttpServerRequest &)> viewHandler
|
||||||
|
);
|
||||||
|
|
||||||
|
QTcpServer *tcpServer() { return m_tcpServer; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
QHttpServer m_httpServer;
|
||||||
|
QTcpServer *m_tcpServer;
|
||||||
|
std::vector<BeforeRequestHandler> m_beforeRequestHandlers;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace gpt4all::ui
|
||||||
|
|
||||||
|
|
||||||
|
#include "mwhttpserver.inl" // IWYU pragma: export
|
20
gpt4all-chat/src/mwhttpserver.inl
Normal file
20
gpt4all-chat/src/mwhttpserver.inl
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
namespace gpt4all::ui {
|
||||||
|
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
QHttpServerRouterRule *MwHttpServer::route(
|
||||||
|
const QString &pathPattern,
|
||||||
|
QHttpServerRequest::Methods method,
|
||||||
|
std::function<QHttpServerResponse(Args..., const QHttpServerRequest &)> viewHandler
|
||||||
|
) {
|
||||||
|
auto wrapped = [this, vh = std::move(viewHandler)](Args ...args, const QHttpServerRequest &req) {
|
||||||
|
for (auto &handler : m_beforeRequestHandlers)
|
||||||
|
if (auto resp = handler(req))
|
||||||
|
return *std::move(resp);
|
||||||
|
return vh(std::forward<Args>(args)..., req);
|
||||||
|
};
|
||||||
|
return m_httpServer.route(pathPattern, method, std::move(wrapped));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace gpt4all::ui
|
@ -3,12 +3,14 @@
|
|||||||
#include "chat.h"
|
#include "chat.h"
|
||||||
#include "chatmodel.h"
|
#include "chatmodel.h"
|
||||||
#include "modellist.h"
|
#include "modellist.h"
|
||||||
|
#include "mwhttpserver.h"
|
||||||
#include "mysettings.h"
|
#include "mysettings.h"
|
||||||
#include "utils.h" // IWYU pragma: keep
|
#include "utils.h" // IWYU pragma: keep
|
||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <gpt4all-backend/llmodel.h>
|
#include <gpt4all-backend/llmodel.h>
|
||||||
|
|
||||||
|
#include <QAbstractSocket>
|
||||||
#include <QByteArray>
|
#include <QByteArray>
|
||||||
#include <QCborArray>
|
#include <QCborArray>
|
||||||
#include <QCborMap>
|
#include <QCborMap>
|
||||||
@ -51,6 +53,7 @@
|
|||||||
|
|
||||||
using namespace std::string_literals;
|
using namespace std::string_literals;
|
||||||
using namespace Qt::Literals::StringLiterals;
|
using namespace Qt::Literals::StringLiterals;
|
||||||
|
using namespace gpt4all::ui;
|
||||||
|
|
||||||
//#define DEBUG
|
//#define DEBUG
|
||||||
|
|
||||||
@ -443,6 +446,8 @@ Server::Server(Chat *chat)
|
|||||||
connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
|
connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Server::~Server() = default;
|
||||||
|
|
||||||
static QJsonObject requestFromJson(const QByteArray &request)
|
static QJsonObject requestFromJson(const QByteArray &request)
|
||||||
{
|
{
|
||||||
QJsonParseError err;
|
QJsonParseError err;
|
||||||
@ -455,17 +460,57 @@ static QJsonObject requestFromJson(const QByteArray &request)
|
|||||||
return document.object();
|
return document.object();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// @brief Check if a host is safe to use to connect to the server.
|
||||||
|
///
|
||||||
|
/// GPT4All's local server is not safe to expose to the internet, as it does not provide
|
||||||
|
/// any form of authentication. DNS rebind attacks bypass CORS and without additional host
|
||||||
|
/// header validation, malicious websites can access the server in client-side js.
|
||||||
|
///
|
||||||
|
/// @param host The value of the "Host" header or ":authority" pseudo-header
|
||||||
|
/// @return true if the host is unsafe, false otherwise
|
||||||
|
static bool isHostUnsafe(const QString &host)
|
||||||
|
{
|
||||||
|
QHostAddress addr;
|
||||||
|
if (addr.setAddress(host) && addr.protocol() == QAbstractSocket::IPv4Protocol)
|
||||||
|
return false; // ipv4
|
||||||
|
|
||||||
|
// ipv6 host is wrapped in square brackets
|
||||||
|
static const QRegularExpression ipv6Re(uR"(^\[(.+)\]$)"_s);
|
||||||
|
if (auto match = ipv6Re.match(host); match.hasMatch()) {
|
||||||
|
auto ipv6 = match.captured(1);
|
||||||
|
if (addr.setAddress(ipv6) && addr.protocol() == QAbstractSocket::IPv6Protocol)
|
||||||
|
return false; // ipv6
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!host.contains('.'))
|
||||||
|
return false; // dotless hostname
|
||||||
|
|
||||||
|
static const QStringList allowedTlds { u".local"_s, u".test"_s, u".internal"_s };
|
||||||
|
for (auto &tld : allowedTlds)
|
||||||
|
if (host.endsWith(tld, Qt::CaseInsensitive))
|
||||||
|
return false; // local TLD
|
||||||
|
|
||||||
|
return true; // unsafe
|
||||||
|
}
|
||||||
|
|
||||||
void Server::start()
|
void Server::start()
|
||||||
{
|
{
|
||||||
m_server = std::make_unique<QHttpServer>(this);
|
m_server = std::make_unique<MwHttpServer>();
|
||||||
auto *tcpServer = new QTcpServer(m_server.get());
|
|
||||||
|
m_server->addBeforeRequestHandler([](const QHttpServerRequest &req) -> std::optional<QHttpServerResponse> {
|
||||||
|
// this works for HTTP/1.1 "Host" header and HTTP/2 ":authority" pseudo-header
|
||||||
|
auto host = req.url().host();
|
||||||
|
if (!host.isEmpty() && isHostUnsafe(host))
|
||||||
|
return QHttpServerResponse(QHttpServerResponder::StatusCode::Forbidden);
|
||||||
|
return std::nullopt;
|
||||||
|
});
|
||||||
|
|
||||||
auto port = MySettings::globalInstance()->networkPort();
|
auto port = MySettings::globalInstance()->networkPort();
|
||||||
if (!tcpServer->listen(QHostAddress::LocalHost, port)) {
|
if (!m_server->tcpServer()->listen(QHostAddress::LocalHost, port)) {
|
||||||
qWarning() << "Server ERROR: Failed to listen on port" << port;
|
qWarning() << "Server ERROR: Failed to listen on port" << port;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (!m_server->bind(tcpServer)) {
|
if (!m_server->bind()) {
|
||||||
qWarning() << "Server ERROR: Failed to HTTP server to socket" << port;
|
qWarning() << "Server ERROR: Failed to HTTP server to socket" << port;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -490,7 +535,7 @@ void Server::start()
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Get,
|
m_server->route<const QString &>("/v1/models/<arg>", QHttpServerRequest::Method::Get,
|
||||||
[](const QString &model, const QHttpServerRequest &) {
|
[](const QString &model, const QHttpServerRequest &) {
|
||||||
if (!MySettings::globalInstance()->serverChat())
|
if (!MySettings::globalInstance()->serverChat())
|
||||||
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
||||||
@ -562,7 +607,7 @@ void Server::start()
|
|||||||
|
|
||||||
// Respond with code 405 to wrong HTTP methods:
|
// Respond with code 405 to wrong HTTP methods:
|
||||||
m_server->route("/v1/models", QHttpServerRequest::Method::Post,
|
m_server->route("/v1/models", QHttpServerRequest::Method::Post,
|
||||||
[] {
|
[](const QHttpServerRequest &) {
|
||||||
if (!MySettings::globalInstance()->serverChat())
|
if (!MySettings::globalInstance()->serverChat())
|
||||||
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
||||||
return QHttpServerResponse(
|
return QHttpServerResponse(
|
||||||
@ -573,8 +618,8 @@ void Server::start()
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Post,
|
m_server->route<const QString &>("/v1/models/<arg>", QHttpServerRequest::Method::Post,
|
||||||
[](const QString &model) {
|
[](const QString &model, const QHttpServerRequest &) {
|
||||||
(void)model;
|
(void)model;
|
||||||
if (!MySettings::globalInstance()->serverChat())
|
if (!MySettings::globalInstance()->serverChat())
|
||||||
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
||||||
@ -587,7 +632,7 @@ void Server::start()
|
|||||||
);
|
);
|
||||||
|
|
||||||
m_server->route("/v1/completions", QHttpServerRequest::Method::Get,
|
m_server->route("/v1/completions", QHttpServerRequest::Method::Get,
|
||||||
[] {
|
[](const QHttpServerRequest &) {
|
||||||
if (!MySettings::globalInstance()->serverChat())
|
if (!MySettings::globalInstance()->serverChat())
|
||||||
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
||||||
return QHttpServerResponse(
|
return QHttpServerResponse(
|
||||||
@ -598,7 +643,7 @@ void Server::start()
|
|||||||
);
|
);
|
||||||
|
|
||||||
m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Get,
|
m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Get,
|
||||||
[] {
|
[](const QHttpServerRequest &) {
|
||||||
if (!MySettings::globalInstance()->serverChat())
|
if (!MySettings::globalInstance()->serverChat())
|
||||||
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
||||||
return QHttpServerResponse(
|
return QHttpServerResponse(
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
#include "chatllm.h"
|
#include "chatllm.h"
|
||||||
#include "database.h"
|
#include "database.h"
|
||||||
|
|
||||||
#include <QHttpServer>
|
|
||||||
#include <QHttpServerResponse>
|
#include <QHttpServerResponse>
|
||||||
#include <QJsonObject>
|
#include <QJsonObject>
|
||||||
#include <QList>
|
#include <QList>
|
||||||
@ -18,6 +17,7 @@
|
|||||||
class Chat;
|
class Chat;
|
||||||
class ChatRequest;
|
class ChatRequest;
|
||||||
class CompletionRequest;
|
class CompletionRequest;
|
||||||
|
namespace gpt4all::ui { class MwHttpServer; }
|
||||||
|
|
||||||
|
|
||||||
class Server : public ChatLLM
|
class Server : public ChatLLM
|
||||||
@ -26,7 +26,7 @@ class Server : public ChatLLM
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
explicit Server(Chat *chat);
|
explicit Server(Chat *chat);
|
||||||
~Server() override = default;
|
~Server() override;
|
||||||
|
|
||||||
public Q_SLOTS:
|
public Q_SLOTS:
|
||||||
void start();
|
void start();
|
||||||
@ -44,7 +44,7 @@ private Q_SLOTS:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
Chat *m_chat;
|
Chat *m_chat;
|
||||||
std::unique_ptr<QHttpServer> m_server;
|
std::unique_ptr<gpt4all::ui::MwHttpServer> m_server;
|
||||||
QList<ResultInfo> m_databaseResults;
|
QList<ResultInfo> m_databaseResults;
|
||||||
QList<QString> m_collections;
|
QList<QString> m_collections;
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user