When regenerating, erase the previous response and prompt from the context.

Adam Treat 2023-04-14 20:34:42 -04:00
parent aa836fa6d5
commit f8005cff45
4 changed files with 13 additions and 5 deletions
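The mechanism: LLMObject gains an m_responseTokens counter that is bumped once per response callback, including the empty per-prompt-token polls added in GPTJ::prompt below, so by the end of a turn it covers both the prompt and the generated reply. On regenerate, resetResponse() rewinds s_ctx.n_past by that count and erases the last m_responseTokens entries of s_ctx.logits, so the previous turn no longer influences the next generation. A minimal standalone sketch of the idea, with simplified names and types that are not the project's actual classes (in particular, logits here grows by exactly one entry per token, which need not match the real PromptContext layout):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct PromptContext {
    std::vector<float> logits; // grows by one entry per token in this sketch
    int32_t n_past = 0;        // number of tokens currently in the model's context
};

struct Responder {
    std::string response;
    uint32_t responseTokens = 0;

    // Called once per token; an empty string is a poll issued for prompt tokens.
    bool onToken(const std::string &token) {
        ++responseTokens;              // counts prompt polls and response tokens alike
        if (!token.empty())
            response += token;
        return true;                   // returning false would ask generation to stop
    }

    // Regenerate: drop the previous turn's tokens from the shared context.
    void resetResponse(PromptContext &ctx) {
        ctx.n_past -= static_cast<int32_t>(responseTokens);
        ctx.logits.erase(ctx.logits.end() - responseTokens, ctx.logits.end());
        responseTokens = 0;
        response.clear();
    }
};

int main() {
    PromptContext ctx;
    Responder r;
    auto evalOne = [&](const std::string &token) {
        ctx.logits.push_back(0.0f);    // stand-in for one model evaluation step
        ++ctx.n_past;
        r.onToken(token);
    };
    for (int i = 0; i < 3; ++i) evalOne("");    // prompt tokens: empty polls
    evalOne("Hello"); evalOne(" world");        // generated tokens
    r.resetResponse(ctx);
    std::cout << "n_past after reset: " << ctx.n_past << "\n"; // prints 0
}

The real resetResponse() in llm.cpp below does the same bookkeeping against the shared s_ctx.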

gptj.cpp

@@ -707,7 +707,9 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
-        // We pass a null string to see if the user has asked us to stop...
-        if (!response(""))
-            return;
+        // We pass a null string for each token to see if the user has asked us to stop...
+        size_t tokens = batch_end - i;
+        for (size_t t = 0; t < tokens; ++t)
+            if (!response(""))
+                return;
         ctx.n_past += batch.size();
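The loop now polls the response callback once per prompt token with an empty string, so the caller can abort during long prompt ingestion and, with the counter above, account for prompt tokens in the rollback. A compilable sketch of the surrounding batch loop shape, using assumed names (evalBatch stands in for gptj_eval, whose real signature differs):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

struct Ctx { int32_t n_past = 0; };

static bool evalBatch(Ctx &, const std::vector<int> &) { return true; } // stand-in for gptj_eval

void ingestPrompt(Ctx &ctx, const std::vector<int> &embd_inp, size_t n_batch,
                  const std::function<bool(const std::string &)> &response) {
    for (size_t i = 0; i < embd_inp.size(); i += n_batch) {
        size_t batch_end = std::min(i + n_batch, embd_inp.size());
        std::vector<int> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
        if (!evalBatch(ctx, batch))
            return;
        size_t tokens = batch_end - i;
        for (size_t t = 0; t < tokens; ++t)
            if (!response(""))               // empty poll, one per prompt token
                return;
        ctx.n_past += static_cast<int32_t>(batch.size());
    }
}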

gptj.h

@@ -15,8 +15,8 @@ public:
     bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
     void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
-        float temp = 0.9f, int32_t n_batch = 9) override;
+        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
+        float temp = 0.0f, int32_t n_batch = 9) override;

 private:
     GPTJPrivate *d_ptr;
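On the new defaults: 50400 is GPT-J's vocabulary size, so a top_k of 50400 keeps every candidate token, and top_p = 1.0f likewise leaves nucleus truncation inactive. A tiny illustration, not project code, of why a top-k cut with k at least the candidate count filters nothing:

#include <algorithm>
#include <functional>
#include <vector>

// Keep the k largest logits; when k >= logits.size() nothing is filtered out,
// which is the effect of top_k = 50400 against GPT-J's 50400-token vocabulary.
std::vector<float> topK(std::vector<float> logits, size_t k) {
    if (k >= logits.size())
        return logits;
    std::partial_sort(logits.begin(), logits.begin() + k, logits.end(),
                      std::greater<float>());
    logits.resize(k);
    return logits;
}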

llm.cpp

@@ -19,6 +19,7 @@ static LLModel::PromptContext s_ctx;
 LLMObject::LLMObject()
     : QObject{nullptr}
     , m_llmodel(new GPTJ)
+    , m_responseTokens(0)
 {
     moveToThread(&m_llmThread);
     connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
@@ -64,6 +65,9 @@ bool LLMObject::isModelLoaded() const

 void LLMObject::resetResponse()
 {
+    s_ctx.n_past -= m_responseTokens;
+    s_ctx.logits.erase(s_ctx.logits.end() -= m_responseTokens, s_ctx.logits.end());
+    m_responseTokens = 0;
     m_response = std::string();
     emit responseChanged();
 }
@@ -89,6 +93,7 @@ bool LLMObject::handleResponse(const std::string &response)
     printf("%s", response.c_str());
     fflush(stdout);
 #endif
+    ++m_responseTokens;
     if (!response.empty()) {
         m_response.append(response);
         emit responseChanged();

llm.h

@@ -41,6 +41,7 @@ private:
 private:
     LLModel *m_llmodel;
     std::string m_response;
+    quint32 m_responseTokens;
     QString m_modelName;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;