From 1c5dd6710d19176d4de83477290bb8c8d554c805 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Fri, 14 Apr 2023 20:34:42 -0400
Subject: [PATCH] When regenerating erase the previous response and prompt from
 the context.

---
 gptj.cpp | 8 +++++---
 gptj.h   | 4 ++--
 llm.cpp  | 5 +++++
 llm.h    | 1 +
 4 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/gptj.cpp b/gptj.cpp
index aa7db13e..34aa16f9 100644
--- a/gptj.cpp
+++ b/gptj.cpp
@@ -707,9 +707,11 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
diff --git a/gptj.h b/gptj.h
--- a/gptj.h
+++ b/gptj.h
@@ -19,5 +19,5 @@
-        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
-        float temp = 0.9f, int32_t n_batch = 9) override;
+        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
+        float temp = 0.0f, int32_t n_batch = 9) override;
 
 private:
     GPTJPrivate *d_ptr;
diff --git a/llm.cpp b/llm.cpp
index 38ea62e3..b44d9ca3 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -19,6 +19,7 @@ static LLModel::PromptContext s_ctx;
 LLMObject::LLMObject()
     : QObject{nullptr}
     , m_llmodel(new GPTJ)
+    , m_responseTokens(0)
 {
     moveToThread(&m_llmThread);
     connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
@@ -64,6 +65,9 @@ bool LLMObject::isModelLoaded() const
 
 void LLMObject::resetResponse()
 {
+    s_ctx.n_past -= m_responseTokens;
+    s_ctx.logits.erase(s_ctx.logits.end() -= m_responseTokens, s_ctx.logits.end());
+    m_responseTokens = 0;
     m_response = std::string();
     emit responseChanged();
 }
@@ -89,6 +93,7 @@ bool LLMObject::handleResponse(const std::string &response)
     printf("%s", response.c_str());
     fflush(stdout);
 #endif
+    ++m_responseTokens;
     if (!response.empty()) {
         m_response.append(response);
         emit responseChanged();
diff --git a/llm.h b/llm.h
index 3740723d..d47ab148 100644
--- a/llm.h
+++ b/llm.h
@@ -41,6 +41,7 @@ private:
 private:
     LLModel *m_llmodel;
     std::string m_response;
+    quint32 m_responseTokens;
     QString m_modelName;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
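
Note: the rollback in resetResponse() is the heart of this patch. A regenerated
answer should start from the same model state the original prompt produced, so
the patch counts every streamed response token (m_responseTokens) and, on
reset, walks s_ctx.n_past back by that count and drops the matching tail of
the cached logits. Below is a minimal standalone sketch of that bookkeeping,
not the patch's actual classes: ResponseTracker and onToken are made-up names,
and it assumes one cached logit per generated token, which may not match how
GPTJ::prompt fills PromptContext::logits (a GPT-J eval yields n_vocab logits
per evaluated token).

#include <cstdint>
#include <string>
#include <vector>

// Stand-in for the patch's LLModel::PromptContext.
struct PromptContext {
    std::vector<float> logits; // cached logits; one entry per token assumed here
    int32_t n_past = 0;        // number of tokens already evaluated into the context
};

// Hypothetical helper mirroring LLMObject's m_response / m_responseTokens pair.
struct ResponseTracker {
    PromptContext ctx;
    std::string response;
    uint32_t responseTokens = 0; // tokens generated since the last reset

    // Called once per generated token, as LLMObject::handleResponse does.
    void onToken(const std::string &piece) {
        ++responseTokens;
        response += piece;
    }

    // Called before regenerating, as LLMObject::resetResponse does: forget the
    // previous answer so generation restarts from the prompt alone.
    void resetResponse() {
        ctx.n_past -= static_cast<int32_t>(responseTokens);
        ctx.logits.erase(ctx.logits.end() - responseTokens, ctx.logits.end());
        responseTokens = 0;
        response.clear();
    }
};

One stylistic aside: the hunk spells the erase bound as
s_ctx.logits.end() -= m_responseTokens, which compound-assigns a temporary
iterator; it behaves like plain subtraction, and
s_ctx.logits.end() - m_responseTokens is the conventional form.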