Implement repeat penalty for both llama and gptj in gui.

This commit is contained in:
Adam Treat
2023-04-25 08:38:29 -04:00
parent cd2e559db4
commit 8b1ddabe3e
9 changed files with 107 additions and 50 deletions

20
llm.cpp
View File

@@ -124,6 +124,7 @@ void LLMObject::regenerateResponse()
s_ctx.n_past = std::max(0, s_ctx.n_past);
// FIXME: This does not seem to be needed in my testing and llama models don't do it. Remove?
s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end());
m_responseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
@@ -243,12 +244,20 @@ QList<QString> LLMObject::modelList() const
return list;
}
bool LLMObject::handleResponse(const std::string &response)
bool LLMObject::handleResponse(int32_t token, const std::string &response)
{
#if 0
printf("%s", response.c_str());
fflush(stdout);
#endif
// Save the token to our prompt context
if (s_ctx.tokens.size() == s_ctx.n_ctx)
s_ctx.tokens.erase(s_ctx.tokens.begin());
s_ctx.tokens.push_back(token);
// m_responseTokens and m_responseLogits track only the most recent prompt/response,
// not the entire context window, so they can be reset when the response is regenerated
++m_responseTokens;
if (!response.empty()) {
m_response.append(response);
@@ -271,10 +280,15 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
QString instructPrompt = prompt_template.arg(prompt);
m_stopGenerating = false;
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
emit responseStarted();
qint32 logitsBefore = s_ctx.logits.size();
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx, n_predict, top_k, top_p, temp, n_batch);
s_ctx.n_predict = n_predict;
s_ctx.top_k = top_k;
s_ctx.top_p = top_p;
s_ctx.temp = temp;
s_ctx.n_batch = n_batch;
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx);
m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {