Use the token cache to infer greater n_past and reuse results (#3073)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
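
The idea behind this change: when a new prompt shares a prefix with the input tokens the model has already evaluated, the KV cache entries for that prefix are still valid, so n_past can start at the length of the shared prefix and only the remaining suffix needs a forward pass. A minimal self-contained sketch of that inference (illustrative only; reusablePrefix and Token here are made-up names, not the repo's code):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <span>

using Token = int32_t;

// Length of the longest prefix shared by the cached input and the new input.
// Everything up to this point is already covered by the KV cache, so decoding
// can resume with n_past set to this value ("faster prefill").
size_t reusablePrefix(std::span<const Token> cached, std::span<const Token> input)
{
    auto [c, i] = std::mismatch(cached.begin(), cached.end(), input.begin(), input.end());
    return static_cast<size_t>(c - cached.begin());
}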
CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 ### Added
 - Add ability to attach text, markdown, and rst files to chat ([#3135](https://github.com/nomic-ai/gpt4all/pull/3135))
 - Add feature to minimize to system tray (by [@bgallois](https://github.com/bgallois) in [#3109](https://github.com/nomic-ai/gpt4all/pull/3109))
+- Basic cache for faster prefill when the input shares a prefix with previous context ([#3073](https://github.com/nomic-ai/gpt4all/pull/3073))
 
 ### Changed
 - Implement Qt 6.8 compatibility ([#3121](https://github.com/nomic-ai/gpt4all/pull/3121))
chatapi.cpp
@@ -51,7 +51,6 @@ bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl)
 void ChatAPI::setThreadCount(int32_t n_threads)
 {
     Q_UNUSED(n_threads);
-    qt_noop();
 }
 
 int32_t ChatAPI::threadCount() const
@@ -68,24 +67,6 @@ bool ChatAPI::isModelLoaded() const
     return true;
 }
 
-// All three of the state virtual functions are handled custom inside of chatllm save/restore
-size_t ChatAPI::stateSize() const
-{
-    throw std::logic_error("not implemented");
-}
-
-size_t ChatAPI::saveState(std::span<uint8_t> dest) const
-{
-    Q_UNUSED(dest);
-    throw std::logic_error("not implemented");
-}
-
-size_t ChatAPI::restoreState(std::span<const uint8_t> src)
-{
-    Q_UNUSED(src);
-    throw std::logic_error("not implemented");
-}
-
 void ChatAPI::prompt(const std::string &prompt,
                      const std::string &promptTemplate,
                      std::function<bool(int32_t)> promptCallback,
chatapi.h
@@ -3,7 +3,7 @@
 
 #include <gpt4all-backend/llmodel.h>
 
-#include <QByteArray>
+#include <QByteArray> // IWYU pragma: keep
 #include <QNetworkReply>
 #include <QObject>
 #include <QString>
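
The `// IWYU pragma: keep` annotations added in this commit (here and in chatllm.h below) are directives for the include-what-you-use tool: they mark an include that the tool would otherwise flag as unused so automated cleanup does not strip it. The syntax is a plain trailing comment:

// Without the pragma, include-what-you-use suggests removing the include;
// with it, the tool treats the include as intentional.
#include <QByteArray> // IWYU pragma: keep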
@@ -13,6 +13,8 @@
 #include <cstddef>
 #include <cstdint>
 #include <functional>
+#include <optional>
 #include <span>
+#include <stdexcept>
 #include <string>
 #include <string_view>
@@ -63,9 +65,15 @@ public:
     bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
     bool isModelLoaded() const override;
     size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
-    size_t stateSize() const override;
-    size_t saveState(std::span<uint8_t> dest) const override;
-    size_t restoreState(std::span<const uint8_t> src) override;
+
+    // All three of the state virtual functions are handled custom inside of chatllm save/restore
+    size_t stateSize() const override
+    { throwNotImplemented(); }
+    size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override
+    { Q_UNUSED(stateOut); Q_UNUSED(inputTokensOut); throwNotImplemented(); }
+    size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
+    { Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
+
     void prompt(const std::string &prompt,
                 const std::string &promptTemplate,
                 std::function<bool(int32_t)> promptCallback,
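
The signature change above is the heart of the refactor: saveState now also exports the input tokens that the serialized KV state corresponds to, and restoreState takes them back so the backend can verify what the state covers. A hedged sketch of a caller round-trip, assuming only the signatures visible in this diff (the model pointer and buffer handling are illustrative, not the app's actual code):

// Sketch only: assumes an LLModel *model with the signatures shown above.
std::vector<uint8_t> buf(model->stateSize());
std::vector<LLModel::Token> inputTokens;                      // filled by saveState
size_t written = model->saveState({buf.data(), buf.size()}, inputTokens);

// ... later, possibly after a process restart ...
size_t read = model->restoreState({buf.data(), written}, inputTokens);
if (read == 0) {
    // State was rejected (mismatch/corrupt); re-evaluate the conversation
    // from text instead, as ChatLLM::restoreState() does below.
}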
@@ -88,6 +96,10 @@ public:
 
     bool callResponse(int32_t token, const std::string &string);
 
+    [[noreturn]]
+    int32_t contextLength() const override
+    { throwNotImplemented(); }
+
 Q_SIGNALS:
     void request(const QString &apiKey,
                  LLModel::PromptContext *ctx,
@@ -98,60 +110,69 @@ protected:
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace
 
+    [[noreturn]]
+    static void throwNotImplemented() { throw std::logic_error("not implemented"); }
+
+    [[noreturn]]
     std::vector<Token> tokenize(std::string_view str, bool special) override
-    {
-        (void)str;
-        (void)special;
-        throw std::logic_error("not implemented");
-    }
+    { Q_UNUSED(str); Q_UNUSED(special); throwNotImplemented(); }
 
+    [[noreturn]]
     bool isSpecialToken(Token id) const override
-    {
-        (void)id;
-        throw std::logic_error("not implemented");
-    }
+    { Q_UNUSED(id); throwNotImplemented(); }
 
+    [[noreturn]]
     std::string tokenToString(Token id) const override
-    {
-        (void)id;
-        throw std::logic_error("not implemented");
-    }
+    { Q_UNUSED(id); throwNotImplemented(); }
 
+    [[noreturn]]
     void initSampler(PromptContext &ctx) override
-    {
-        (void)ctx;
-        throw std::logic_error("not implemented");
-    }
+    { Q_UNUSED(ctx); throwNotImplemented(); }
 
-    Token sampleToken() const override { throw std::logic_error("not implemented"); }
+    [[noreturn]]
+    Token sampleToken() const override
+    { throwNotImplemented(); }
 
-    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override
-    {
-        (void)ctx;
-        (void)tokens;
-        throw std::logic_error("not implemented");
-    }
+    [[noreturn]]
+    bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override
+    { Q_UNUSED(ctx); Q_UNUSED(tokens); throwNotImplemented(); }
 
+    [[noreturn]]
     void shiftContext(PromptContext &promptCtx) override
-    {
-        (void)promptCtx;
-        throw std::logic_error("not implemented");
-    }
+    { Q_UNUSED(promptCtx); throwNotImplemented(); }
 
-    int32_t contextLength() const override
-    {
-        throw std::logic_error("not implemented");
-    }
+    [[noreturn]]
+    int32_t inputLength() const override
+    { throwNotImplemented(); }
+
+    [[noreturn]]
+    void setTokenizeInputPosition(int32_t pos) override
+    { Q_UNUSED(pos); throwNotImplemented(); }
+
+    [[noreturn]]
+    auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
+        -> std::vector<Token>::const_iterator override
+    { Q_UNUSED(ctx); Q_UNUSED(input); throwNotImplemented(); }
+
+    [[noreturn]]
+    void setModelInputPosition(PromptContext &ctx, int32_t pos) override
+    { Q_UNUSED(ctx); Q_UNUSED(pos); throwNotImplemented(); }
+
+    [[noreturn]]
+    void appendInputToken(PromptContext &ctx, Token tok) override
+    { Q_UNUSED(ctx); Q_UNUSED(tok); throwNotImplemented(); }
 
+    [[noreturn]]
     const std::vector<Token> &endTokens() const override
-    {
-        throw std::logic_error("not implemented");
-    }
+    { throwNotImplemented(); }
 
+    [[noreturn]]
     bool shouldAddBOS() const override
-    {
-        throw std::logic_error("not implemented");
-    }
+    { throwNotImplemented(); }
+
+    [[noreturn]]
+    std::span<const Token> inputTokens() const override
+    { throwNotImplemented(); }
 
 private:
     std::function<bool(int32_t, const std::string&)> m_responseCallback;
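
A note on the pattern introduced in this hunk: throwNotImplemented() is marked [[noreturn]], which is what lets all of these stubs collapse to one-line bodies. Because the compiler knows the call never returns, a non-void override needs no return statement and produces no "missing return" warning; it is also why the header now has to include <stdexcept> itself. A standalone illustration with hypothetical names:

#include <stdexcept>

struct Backend {
    virtual ~Backend() = default;
    virtual int contextLength() const = 0;
};

// A remote-API stub that never implements local-inference hooks.
struct RemoteStub : Backend {
    [[noreturn]]
    static void throwNotImplemented() { throw std::logic_error("not implemented"); }

    [[noreturn]]
    int contextLength() const override
    { throwNotImplemented(); } // OK: no return needed after a [[noreturn]] call
};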
chatllm.cpp
@@ -33,6 +33,7 @@
 #include <functional>
 #include <limits>
 #include <optional>
+#include <span>
 #include <string_view>
 #include <utility>
 #include <vector>
@@ -404,7 +405,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps)
 
     QString requestedDevice = MySettings::globalInstance()->device();
     int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
-    m_ctx.n_ctx = n_ctx;
     int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);
 
     std::string backend = "auto";
@@ -632,7 +632,6 @@ void ChatLLM::regenerateResponse()
     else
         m_ctx.n_past -= m_promptResponseTokens;
     m_ctx.n_past = std::max(0, m_ctx.n_past);
-    m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
     m_promptResponseTokens = 0;
     m_promptTokens = 0;
     m_response = m_trimmedResponse = std::string();
@@ -1078,12 +1077,13 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
         stream << responseLogits;
     }
     stream << m_ctx.n_past;
-    if (version >= 7) {
-        stream << m_ctx.n_ctx;
-    }
-    stream << quint64(m_ctx.tokens.size());
-    stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
     saveState();
+    if (version >= 7) {
+        stream << m_stateContextLength;
+    }
+    stream << quint64(m_stateInputTokens.size());
+    stream.writeRawData(reinterpret_cast<const char *>(m_stateInputTokens.data()),
+                        m_stateInputTokens.size() * sizeof(m_stateInputTokens[0]));
     QByteArray compressed = qCompress(m_state);
     stream << compressed;
 #if defined(DEBUG)
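
The serialization change persists the model-owned token cache (m_stateInputTokens) and its context length in place of the old m_ctx fields, still alongside the qCompress'd state blob. A self-contained sketch of the same raw-tokens-plus-compressed-blob round-trip (illustrative names, not the app's on-disk format):

#include <QByteArray>
#include <QDataStream>
#include <QIODevice>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<int32_t> tokens{1, 2, 3};
    QByteArray state(1024, 'x'); // stand-in for opaque KV state bytes

    QByteArray blob;
    {
        QDataStream out(&blob, QIODevice::WriteOnly);
        out << quint64(tokens.size());
        out.writeRawData(reinterpret_cast<const char *>(tokens.data()),
                         int(tokens.size() * sizeof(tokens[0])));
        out << qCompress(state); // zlib-compressed, like m_state above
    }

    QDataStream in(&blob, QIODevice::ReadOnly);
    quint64 n = 0;
    in >> n;
    std::vector<int32_t> tokensIn(n);
    in.readRawData(reinterpret_cast<char *>(tokensIn.data()),
                   int(n * sizeof(tokensIn[0])));
    QByteArray compressed;
    in >> compressed;
    return (tokensIn == tokens && qUncompress(compressed) == state) ? 0 : 1;
}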
@@ -1145,7 +1145,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
     if (version >= 7) {
         uint32_t n_ctx;
         stream >> n_ctx;
-        if (!discardKV) m_ctx.n_ctx = n_ctx;
+        if (!discardKV) m_stateContextLength = n_ctx;
     }
 
     if (version < 9) {
@@ -1157,10 +1157,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
         quint64 tokensSize;
         stream >> tokensSize;
         if (!discardKV) {
-            m_ctx.tokens.resize(tokensSize);
-            stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
+            m_stateInputTokens.resize(tokensSize);
+            stream.readRawData(reinterpret_cast<char *>(m_stateInputTokens.data()), tokensSize * sizeof(m_stateInputTokens[0]));
         } else {
-            stream.skipRawData(tokensSize * sizeof(int));
+            stream.skipRawData(tokensSize * sizeof(m_stateInputTokens[0]));
         }
 
     if (version >= 1) {
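
One detail worth noting in the read path: the discardKV branch must consume exactly as many bytes as the save path wrote, or every later field would be misread, and sizing the skip with sizeof(m_stateInputTokens[0]) rather than a hard-coded sizeof(int) keeps the two sides in lockstep if the Token type ever changes width. In sketch form (mirroring the hunk, not a drop-in excerpt):

// Read or skip the serialized token cache; both branches advance the stream
// by tokensSize * sizeof(Token) bytes so subsequent fields stay aligned.
quint64 tokensSize = 0;
stream >> tokensSize;
if (!discardKV) {
    m_stateInputTokens.resize(tokensSize);
    stream.readRawData(reinterpret_cast<char *>(m_stateInputTokens.data()),
                       tokensSize * sizeof(m_stateInputTokens[0]));
} else {
    stream.skipRawData(tokensSize * sizeof(m_stateInputTokens[0]));
}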
@@ -1202,13 +1202,16 @@ void ChatLLM::saveState()
 #if defined(DEBUG)
     qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
 #endif
-    bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
+    bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
+                                             m_stateInputTokens);
     if (!ok) {
         // FIXME(jared): how badly does this situation break GPT4All?
         qWarning() << "ChatLLM failed to save LLModel state";
         m_state.clear();
         m_state.squeeze();
+        m_stateContextLength = -1;
     }
+    m_stateContextLength = m_llModelInfo.model->contextLength();
 }
 
 void ChatLLM::restoreState()
@@ -1235,13 +1238,22 @@ void ChatLLM::restoreState()
     if (m_state.isEmpty())
         return;
 
-    size_t bytesRead = m_llModelInfo.model->restoreState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
-    if (bytesRead) {
-        m_processedSystemPrompt = true;
-        m_pristineLoadedState = true;
-    } else {
-        qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
+    if (m_llModelInfo.model->contextLength() != m_stateContextLength) {
+        qWarning() << "restoring state from text because of n_ctx mismatch (state"
+                   << m_stateContextLength << "model" << m_llModelInfo.model->contextLength() << ")";
         m_restoreStateFromText = true;
+    } else {
+        size_t bytesRead = m_llModelInfo.model->restoreState(
+            {reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
+            m_stateInputTokens
+        );
+        if (!bytesRead) {
+            qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
+            m_restoreStateFromText = true;
+        } else {
+            m_processedSystemPrompt = true;
+            m_pristineLoadedState = true;
+        }
     }
 
     // free local state copy unless unload is pending
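
Taken together, the new restore path is a two-gate validation before any saved KV state is trusted: first a cheap context-length check against the recorded m_stateContextLength, then the backend's own byte-level verification via restoreState; failing either gate falls back to rebuilding state from the conversation text. A condensed sketch of that policy (member names mirror the diff; this is not the literal function):

// Returns true if the saved KV state was adopted, false if the caller must
// rebuild from text (the m_restoreStateFromText path above).
bool adoptSavedState()
{
    // Gate 1: a model loaded with a different n_ctx lays out KV memory
    // differently, so the saved bytes cannot apply.
    if (m_llModelInfo.model->contextLength() != m_stateContextLength)
        return false;

    // Gate 2: the backend checks the bytes (and the input tokens they
    // claim to cover); 0 bytes read means mismatch or corruption.
    size_t bytesRead = m_llModelInfo.model->restoreState(
        {reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
        m_stateInputTokens);
    return bytesRead != 0;
}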
chatllm.h
@@ -9,7 +9,7 @@
 #include <QByteArray>
 #include <QElapsedTimer>
 #include <QFileInfo>
-#include <QList>
+#include <QList> // IWYU pragma: keep
 #include <QObject>
 #include <QPointer>
 #include <QString>
@@ -22,6 +22,7 @@
 #include <memory>
 #include <optional>
 #include <string>
+#include <vector>
 
 using namespace Qt::Literals::StringLiterals;
 
@@ -277,6 +278,8 @@ private:
     ModelInfo m_modelInfo;
     TokenTimer *m_timer;
    QByteArray m_state;
+    std::vector<LLModel::Token> m_stateInputTokens;
+    int32_t m_stateContextLength = -1;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
     std::atomic<bool> m_shouldBeLoaded;