Mirror of https://github.com/nomic-ai/gpt4all.git
	Use the token cache to infer greater n_past and reuse results (#3073)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
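What the title describes, in practice: the chat keeps a cache of the input tokens that produced the current KV state (the m_stateInputTokens vector that appears throughout the diff below), and a new prompt that shares a prefix with that cache can count the shared portion toward n_past instead of re-decoding it. A minimal sketch of that prefix inference, using a hypothetical helper name that is not part of the gpt4all code:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

using Token = std::int32_t;

// Hypothetical helper: given the tokens that produced the current KV cache and
// the tokens of a new prompt, count the shared prefix. That count can serve as
// n_past, so only the remaining suffix has to be decoded.
static std::size_t inferNPast(const std::vector<Token> &cachedTokens,
                              const std::vector<Token> &promptTokens)
{
    auto [cacheIt, promptIt] = std::mismatch(cachedTokens.begin(), cachedTokens.end(),
                                             promptTokens.begin(), promptTokens.end());
    (void)promptIt;
    // A real implementation would still decode at least one token so fresh
    // logits are produced even when the prompt is an exact prefix match.
    return std::size_t(cacheIt - cachedTokens.begin());
}

The diff below covers the ChatLLM side of the change: persisting the token cache and the context length alongside the saved binary state, and consulting them when restoring.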
@@ -33,6 +33,7 @@
#include <functional>
#include <limits>
#include <optional>
#include <span>
#include <string_view>
#include <utility>
#include <vector>

@@ -404,7 +405,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro

    QString requestedDevice = MySettings::globalInstance()->device();
    int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
    m_ctx.n_ctx = n_ctx;
    int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);

    std::string backend = "auto";

@@ -632,7 +632,6 @@ void ChatLLM::regenerateResponse()
    else
        m_ctx.n_past -= m_promptResponseTokens;
    m_ctx.n_past = std::max(0, m_ctx.n_past);
    m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
    m_promptResponseTokens = 0;
    m_promptTokens = 0;
    m_response = m_trimmedResponse = std::string();

@@ -1078,12 +1077,13 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
        stream << responseLogits;
    }
    stream << m_ctx.n_past;
    if (version >= 7) {
        stream << m_ctx.n_ctx;
    }
    stream << quint64(m_ctx.tokens.size());
    stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
    saveState();
    if (version >= 7) {
        stream << m_stateContextLength;
    }
    stream << quint64(m_stateInputTokens.size());
    stream.writeRawData(reinterpret_cast<const char *>(m_stateInputTokens.data()),
                        m_stateInputTokens.size() * sizeof(m_stateInputTokens[0]));
    QByteArray compressed = qCompress(m_state);
    stream << compressed;
#if defined(DEBUG)

@@ -1145,7 +1145,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
    if (version >= 7) {
        uint32_t n_ctx;
        stream >> n_ctx;
        if (!discardKV) m_ctx.n_ctx = n_ctx;
        if (!discardKV) m_stateContextLength = n_ctx;
    }

    if (version < 9) {

@@ -1157,10 +1157,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
    quint64 tokensSize;
    stream >> tokensSize;
    if (!discardKV) {
        m_ctx.tokens.resize(tokensSize);
        stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
        m_stateInputTokens.resize(tokensSize);
        stream.readRawData(reinterpret_cast<char *>(m_stateInputTokens.data()), tokensSize * sizeof(m_stateInputTokens[0]));
    } else {
        stream.skipRawData(tokensSize * sizeof(int));
        stream.skipRawData(tokensSize * sizeof(m_stateInputTokens[0]));
    }

    if (version >= 1) {

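The serialize and deserialize hunks above handle the token cache symmetrically: a quint64 element count followed by the raw token bytes, with the element size taken from the same type on both sides. A self-contained sketch of that round trip, written as hypothetical free functions rather than the ChatLLM members:

#include <QDataStream>
#include <cstdint>
#include <vector>

using Token = std::int32_t;

// Write the cache the way the serialize hunk does: element count, then raw bytes.
void writeTokenCache(QDataStream &stream, const std::vector<Token> &tokens)
{
    stream << quint64(tokens.size());
    stream.writeRawData(reinterpret_cast<const char *>(tokens.data()),
                        int(tokens.size() * sizeof(Token)));
}

// Read it back in the same order; the element size must match what was written.
std::vector<Token> readTokenCache(QDataStream &stream)
{
    quint64 size = 0;
    stream >> size;
    std::vector<Token> tokens(size);
    stream.readRawData(reinterpret_cast<char *>(tokens.data()),
                       int(size * sizeof(Token)));
    return tokens;
}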
@@ -1202,13 +1202,16 @@ void ChatLLM::saveState()
#if defined(DEBUG)
    qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
    bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
    bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
                                             m_stateInputTokens);
    if (!ok) {
        // FIXME(jared): how badly does this situation break GPT4All?
        qWarning() << "ChatLLM failed to save LLModel state";
        m_state.clear();
        m_state.squeeze();
        m_stateContextLength = -1;
    }
    m_stateContextLength = m_llModelInfo.model->contextLength();
}

void ChatLLM::restoreState()

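The updated call sites pass the token cache to the backend alongside the raw state buffer. Inferred from those calls alone (the real declarations live in llmodel.h and may differ), the backend state API presumably looks roughly like this:

#include <cstddef>
#include <cstdint>
#include <span>
#include <vector>

// Assumed shape of the backend state API, reconstructed from the call sites in
// this diff; treat the names and return types as guesses, not the real llmodel.h.
struct AssumedLLModelStateApi {
    using Token = std::int32_t;

    // Serialize the KV state into stateOut and report which input tokens
    // produced it, so both can be persisted together.
    virtual bool saveState(std::span<uint8_t> stateOut,
                           std::vector<Token> &inputTokensOut) = 0;

    // Restore the KV state and tell the model which tokens it corresponds to;
    // returns the number of bytes consumed (0 on mismatch or corruption).
    virtual size_t restoreState(std::span<const uint8_t> state,
                                std::span<const Token> inputTokens) = 0;

    virtual int32_t contextLength() const = 0;

    virtual ~AssumedLLModelStateApi() = default;
};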
@@ -1235,13 +1238,22 @@ void ChatLLM::restoreState()
    if (m_state.isEmpty())
        return;

    size_t bytesRead = m_llModelInfo.model->restoreState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
    if (bytesRead) {
        m_processedSystemPrompt = true;
        m_pristineLoadedState = true;
    } else {
        qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
    if (m_llModelInfo.model->contextLength() != m_stateContextLength) {
        qWarning() << "restoring state from text because of n_ctx mismatch (state"
                   << m_stateContextLength << "model" << m_llModelInfo.model->contextLength() << ")";
        m_restoreStateFromText = true;
    } else {
        size_t bytesRead = m_llModelInfo.model->restoreState(
            {reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
            m_stateInputTokens
        );
        if (!bytesRead) {
            qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
            m_restoreStateFromText = true;
        } else {
            m_processedSystemPrompt = true;
            m_pristineLoadedState = true;
        }
    }

    // free local state copy unless unload is pending

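Taken together, restoreState() now only trusts the binary state when the model's current context length matches the one recorded at save time, and it falls back to rebuilding the conversation from text whenever that guard or the raw restore fails. A condensed sketch of that decision as a hypothetical free function, not the ChatLLM member:

#include <cstddef>
#include <cstdint>
#include <span>
#include <vector>

using Token = std::int32_t;

enum class RestoreOutcome { ReusedBinaryState, RebuildFromText };

// Minimal stand-in for the two model calls this decision needs (assumed shapes).
struct ModelHandle {
    virtual int32_t contextLength() const = 0;
    virtual std::size_t restoreState(std::span<const uint8_t> state,
                                     std::span<const Token> inputTokens) = 0;
    virtual ~ModelHandle() = default;
};

// Condensed version of the restoreState() logic in the diff above.
RestoreOutcome tryRestore(ModelHandle &model, std::span<const uint8_t> state,
                          const std::vector<Token> &inputTokens,
                          int32_t savedContextLength)
{
    if (model.contextLength() != savedContextLength)
        return RestoreOutcome::RebuildFromText;   // n_ctx mismatch: saved state not reusable
    if (model.restoreState(state, inputTokens) == 0)
        return RestoreOutcome::RebuildFromText;   // mismatched or corrupt state data
    return RestoreOutcome::ReusedBinaryState;     // KV cache and token cache restored
}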