Use the token cache to infer greater n_past and reuse results (#3073)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel
2024-10-31 11:19:12 -04:00
committed by GitHub
parent 62cab695eb
commit f07e2e63df
15 changed files with 320 additions and 169 deletions

View File

@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Added
- Add ability to attach text, markdown, and rst files to chat ([#3135](https://github.com/nomic-ai/gpt4all/pull/3135))
- Add feature to minimize to system tray (by [@bgallois](https://github.com/bgallois) in ([#3109](https://github.com/nomic-ai/gpt4all/pull/3109))
- Basic cache for faster prefill when the input shares a prefix with previous context ([#3073](https://github.com/nomic-ai/gpt4all/pull/3073))
### Changed
- Implement Qt 6.8 compatibility ([#3121](https://github.com/nomic-ai/gpt4all/pull/3121))

View File

@@ -51,7 +51,6 @@ bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl)
void ChatAPI::setThreadCount(int32_t n_threads)
{
Q_UNUSED(n_threads);
qt_noop();
}
int32_t ChatAPI::threadCount() const
@@ -68,24 +67,6 @@ bool ChatAPI::isModelLoaded() const
return true;
}
// All three of the state virtual functions are handled custom inside of chatllm save/restore
size_t ChatAPI::stateSize() const
{
throw std::logic_error("not implemented");
}
size_t ChatAPI::saveState(std::span<uint8_t> dest) const
{
Q_UNUSED(dest);
throw std::logic_error("not implemented");
}
size_t ChatAPI::restoreState(std::span<const uint8_t> src)
{
Q_UNUSED(src);
throw std::logic_error("not implemented");
}
void ChatAPI::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,

View File

@@ -3,7 +3,7 @@
#include <gpt4all-backend/llmodel.h>
#include <QByteArray>
#include <QByteArray> // IWYU pragma: keep
#include <QNetworkReply>
#include <QObject>
#include <QString>
@@ -13,6 +13,8 @@
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <span>
#include <stdexcept>
#include <string>
#include <string_view>
@@ -63,9 +65,15 @@ public:
bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;
// All three of the state virtual functions are handled custom inside of chatllm save/restore
size_t stateSize() const override
{ throwNotImplemented(); }
size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override
{ Q_UNUSED(stateOut); Q_UNUSED(inputTokensOut); throwNotImplemented(); }
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
{ Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
@@ -88,6 +96,10 @@ public:
bool callResponse(int32_t token, const std::string &string);
[[noreturn]]
int32_t contextLength() const override
{ throwNotImplemented(); }
Q_SIGNALS:
void request(const QString &apiKey,
LLModel::PromptContext *ctx,
@@ -98,60 +110,69 @@ protected:
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace
[[noreturn]]
static void throwNotImplemented() { throw std::logic_error("not implemented"); }
[[noreturn]]
std::vector<Token> tokenize(std::string_view str, bool special) override
{
(void)str;
(void)special;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(str); Q_UNUSED(special); throwNotImplemented(); }
[[noreturn]]
bool isSpecialToken(Token id) const override
{
(void)id;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(id); throwNotImplemented(); }
[[noreturn]]
std::string tokenToString(Token id) const override
{
(void)id;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(id); throwNotImplemented(); }
[[noreturn]]
void initSampler(PromptContext &ctx) override
{
(void)ctx;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(ctx); throwNotImplemented(); }
Token sampleToken() const override { throw std::logic_error("not implemented"); }
[[noreturn]]
Token sampleToken() const override
{ throwNotImplemented(); }
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override
{
(void)ctx;
(void)tokens;
throw std::logic_error("not implemented");
}
[[noreturn]]
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override
{ Q_UNUSED(ctx); Q_UNUSED(tokens); throwNotImplemented(); }
[[noreturn]]
void shiftContext(PromptContext &promptCtx) override
{
(void)promptCtx;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(promptCtx); throwNotImplemented(); }
int32_t contextLength() const override
{
throw std::logic_error("not implemented");
}
[[noreturn]]
int32_t inputLength() const override
{ throwNotImplemented(); }
[[noreturn]]
void setTokenizeInputPosition(int32_t pos) override
{ Q_UNUSED(pos); throwNotImplemented(); }
[[noreturn]]
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override
{ Q_UNUSED(ctx); Q_UNUSED(input); throwNotImplemented(); }
[[noreturn]]
void setModelInputPosition(PromptContext &ctx, int32_t pos) override
{ Q_UNUSED(ctx); Q_UNUSED(pos); throwNotImplemented(); }
[[noreturn]]
void appendInputToken(PromptContext &ctx, Token tok) override
{ Q_UNUSED(ctx); Q_UNUSED(tok); throwNotImplemented(); }
[[noreturn]]
const std::vector<Token> &endTokens() const override
{
throw std::logic_error("not implemented");
}
{ throwNotImplemented(); }
[[noreturn]]
bool shouldAddBOS() const override
{
throw std::logic_error("not implemented");
}
{ throwNotImplemented(); }
[[noreturn]]
std::span<const Token> inputTokens() const override
{ throwNotImplemented(); }
private:
std::function<bool(int32_t, const std::string&)> m_responseCallback;

View File

@@ -33,6 +33,7 @@
#include <functional>
#include <limits>
#include <optional>
#include <span>
#include <string_view>
#include <utility>
#include <vector>
@@ -404,7 +405,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
QString requestedDevice = MySettings::globalInstance()->device();
int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
m_ctx.n_ctx = n_ctx;
int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);
std::string backend = "auto";
@@ -632,7 +632,6 @@ void ChatLLM::regenerateResponse()
else
m_ctx.n_past -= m_promptResponseTokens;
m_ctx.n_past = std::max(0, m_ctx.n_past);
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
m_promptResponseTokens = 0;
m_promptTokens = 0;
m_response = m_trimmedResponse = std::string();
@@ -1078,12 +1077,13 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
stream << responseLogits;
}
stream << m_ctx.n_past;
if (version >= 7) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.tokens.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
saveState();
if (version >= 7) {
stream << m_stateContextLength;
}
stream << quint64(m_stateInputTokens.size());
stream.writeRawData(reinterpret_cast<const char *>(m_stateInputTokens.data()),
m_stateInputTokens.size() * sizeof(m_stateInputTokens[0]));
QByteArray compressed = qCompress(m_state);
stream << compressed;
#if defined(DEBUG)
@@ -1145,7 +1145,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
if (version >= 7) {
uint32_t n_ctx;
stream >> n_ctx;
if (!discardKV) m_ctx.n_ctx = n_ctx;
if (!discardKV) m_stateContextLength = n_ctx;
}
if (version < 9) {
@@ -1157,10 +1157,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
quint64 tokensSize;
stream >> tokensSize;
if (!discardKV) {
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
m_stateInputTokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char *>(m_stateInputTokens.data()), tokensSize * sizeof(m_stateInputTokens[0]));
} else {
stream.skipRawData(tokensSize * sizeof(int));
stream.skipRawData(tokensSize * sizeof(m_stateInputTokens[0]));
}
if (version >= 1) {
@@ -1202,13 +1202,16 @@ void ChatLLM::saveState()
#if defined(DEBUG)
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
m_stateInputTokens);
if (!ok) {
// FIXME(jared): how badly does this situation break GPT4All?
qWarning() << "ChatLLM failed to save LLModel state";
m_state.clear();
m_state.squeeze();
m_stateContextLength = -1;
}
m_stateContextLength = m_llModelInfo.model->contextLength();
}
void ChatLLM::restoreState()
@@ -1235,13 +1238,22 @@ void ChatLLM::restoreState()
if (m_state.isEmpty())
return;
size_t bytesRead = m_llModelInfo.model->restoreState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
if (bytesRead) {
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
} else {
qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
if (m_llModelInfo.model->contextLength() != m_stateContextLength) {
qWarning() << "restoring state from text because of n_ctx mismatch (state"
<< m_stateContextLength << "model" << m_llModelInfo.model->contextLength() << ")";
m_restoreStateFromText = true;
} else {
size_t bytesRead = m_llModelInfo.model->restoreState(
{reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
m_stateInputTokens
);
if (!bytesRead) {
qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
m_restoreStateFromText = true;
} else {
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
}
}
// free local state copy unless unload is pending

View File

@@ -9,7 +9,7 @@
#include <QByteArray>
#include <QElapsedTimer>
#include <QFileInfo>
#include <QList>
#include <QList> // IWYU pragma: keep
#include <QObject>
#include <QPointer>
#include <QString>
@@ -22,6 +22,7 @@
#include <memory>
#include <optional>
#include <string>
#include <vector>
using namespace Qt::Literals::StringLiterals;
@@ -277,6 +278,8 @@ private:
ModelInfo m_modelInfo;
TokenTimer *m_timer;
QByteArray m_state;
std::vector<LLModel::Token> m_stateInputTokens;
int32_t m_stateContextLength = -1;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;