Move the promptCallback to own function.

This commit is contained in:
Adam Treat 2023-04-27 11:08:15 -04:00
parent 0e9f85bcda
commit ba4b28fcd5
9 changed files with 81 additions and 47 deletions

36
llm.cpp
View File

@ -38,7 +38,7 @@ static QString modelFilePath(const QString &modelName)
LLMObject::LLMObject() LLMObject::LLMObject()
: QObject{nullptr} : QObject{nullptr}
, m_llmodel(nullptr) , m_llmodel(nullptr)
, m_responseTokens(0) , m_promptResponseTokens(0)
, m_responseLogits(0) , m_responseLogits(0)
, m_isRecalc(false) , m_isRecalc(false)
{ {
@ -133,12 +133,12 @@ bool LLMObject::isModelLoaded() const
void LLMObject::regenerateResponse() void LLMObject::regenerateResponse()
{ {
s_ctx.n_past -= m_responseTokens; s_ctx.n_past -= m_promptResponseTokens;
s_ctx.n_past = std::max(0, s_ctx.n_past); s_ctx.n_past = std::max(0, s_ctx.n_past);
// FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove? // FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove?
s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end()); s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end()); s_ctx.tokens.erase(s_ctx.tokens.end() -= m_promptResponseTokens, s_ctx.tokens.end());
m_responseTokens = 0; m_promptResponseTokens = 0;
m_responseLogits = 0; m_responseLogits = 0;
m_response = std::string(); m_response = std::string();
emit responseChanged(); emit responseChanged();
@ -146,7 +146,7 @@ void LLMObject::regenerateResponse()
void LLMObject::resetResponse() void LLMObject::resetResponse()
{ {
m_responseTokens = 0; m_promptResponseTokens = 0;
m_responseLogits = 0; m_responseLogits = 0;
m_response = std::string(); m_response = std::string();
emit responseChanged(); emit responseChanged();
@ -263,6 +263,18 @@ QList<QString> LLMObject::modelList() const
return list; return list;
} }
bool LLMObject::handlePrompt(int32_t token)
{
if (s_ctx.tokens.size() == s_ctx.n_ctx)
s_ctx.tokens.erase(s_ctx.tokens.begin());
s_ctx.tokens.push_back(token);
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens;
return !m_stopGenerating;
}
bool LLMObject::handleResponse(int32_t token, const std::string &response) bool LLMObject::handleResponse(int32_t token, const std::string &response)
{ {
#if 0 #if 0
@ -282,13 +294,12 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
s_ctx.tokens.erase(s_ctx.tokens.begin()); s_ctx.tokens.erase(s_ctx.tokens.begin());
s_ctx.tokens.push_back(token); s_ctx.tokens.push_back(token);
// m_responseTokens and m_responseLogits are related to last prompt/response not // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt // the entire context window which we can reset on regenerate prompt
++m_responseTokens; ++m_promptResponseTokens;
if (!response.empty()) { Q_ASSERT(!response.empty());
m_response.append(response); m_response.append(response);
emit responseChanged(); emit responseChanged();
}
// Stop generation if we encounter prompt or response tokens // Stop generation if we encounter prompt or response tokens
QString r = QString::fromStdString(m_response); QString r = QString::fromStdString(m_response);
@ -315,6 +326,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
QString instructPrompt = prompt_template.arg(prompt); QString instructPrompt = prompt_template.arg(prompt);
m_stopGenerating = false; m_stopGenerating = false;
auto promptFunc = std::bind(&LLMObject::handlePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
std::placeholders::_2); std::placeholders::_2);
auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1); auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
@ -327,7 +339,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
s_ctx.n_batch = n_batch; s_ctx.n_batch = n_batch;
s_ctx.repeat_penalty = repeat_penalty; s_ctx.repeat_penalty = repeat_penalty;
s_ctx.repeat_last_n = repeat_penalty_tokens; s_ctx.repeat_last_n = repeat_penalty_tokens;
m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx); m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, s_ctx);
m_responseLogits += s_ctx.logits.size() - logitsBefore; m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response); std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) { if (trimmed != m_response) {

3
llm.h
View File

@ -58,13 +58,14 @@ Q_SIGNALS:
private: private:
void resetContextPrivate(); void resetContextPrivate();
bool loadModelPrivate(const QString &modelName); bool loadModelPrivate(const QString &modelName);
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response); bool handleResponse(int32_t token, const std::string &response);
bool handleRecalculate(bool isRecalc); bool handleRecalculate(bool isRecalc);
private: private:
LLModel *m_llmodel; LLModel *m_llmodel;
std::string m_response; std::string m_response;
quint32 m_responseTokens; quint32 m_promptResponseTokens;
quint32 m_responseLogits; quint32 m_responseLogits;
QString m_modelName; QString m_modelName;
QThread m_llmThread; QThread m_llmThread;

View File

@ -686,8 +686,9 @@ bool GPTJ::isModelLoaded() const
} }
void GPTJ::prompt(const std::string &prompt, void GPTJ::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response, std::function<bool(int32_t)> promptCallback,
std::function<bool(bool)> recalculate, std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx) { PromptContext &promptCtx) {
if (!isModelLoaded()) { if (!isModelLoaded()) {
@ -708,7 +709,7 @@ void GPTJ::prompt(const std::string &prompt,
promptCtx.n_ctx = d_ptr->model.hparams.n_ctx; promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
response(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() << std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() <<
"tokens and the context window is" << promptCtx.n_ctx << "!\n"; "tokens and the context window is" << promptCtx.n_ctx << "!\n";
return; return;
@ -741,7 +742,7 @@ void GPTJ::prompt(const std::string &prompt,
std::cerr << "GPTJ: reached the end of the context window so resizing\n"; std::cerr << "GPTJ: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size(); promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate); recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
} }
@ -750,10 +751,10 @@ void GPTJ::prompt(const std::string &prompt,
std::cerr << "GPT-J ERROR: Failed to process prompt\n"; std::cerr << "GPT-J ERROR: Failed to process prompt\n";
return; return;
} }
// We pass a null string for each token to see if the user has asked us to stop...
size_t tokens = batch_end - i; size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) for (size_t t = 0; t < tokens; ++t)
if (!response(batch.at(t), "")) if (!promptCallback(batch.at(t)))
return; return;
promptCtx.n_past += batch.size(); promptCtx.n_past += batch.size();
i = batch_end; i = batch_end;
@ -790,8 +791,8 @@ void GPTJ::prompt(const std::string &prompt,
std::cerr << "GPTJ: reached the end of the context window so resizing\n"; std::cerr << "GPTJ: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size(); promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate); recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
} }
const int64_t t_start_predict_us = ggml_time_us(); const int64_t t_start_predict_us = ggml_time_us();
@ -805,7 +806,7 @@ void GPTJ::prompt(const std::string &prompt,
promptCtx.n_past += 1; promptCtx.n_past += 1;
// display text // display text
++totalPredictions; ++totalPredictions;
if (id == 50256 /*end of text*/ || !response(id, d_ptr->vocab.id_to_token[id])) if (id == 50256 /*end of text*/ || !responseCallback(id, d_ptr->vocab.id_to_token[id]))
goto stop_generating; goto stop_generating;
} }

View File

@ -16,8 +16,9 @@ public:
bool loadModel(const std::string &modelPath, std::istream &fin) override; bool loadModel(const std::string &modelPath, std::istream &fin) override;
bool isModelLoaded() const override; bool isModelLoaded() const override;
void prompt(const std::string &prompt, void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response, std::function<bool(int32_t)> promptCallback,
std::function<bool(bool)> recalculate, std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) override; PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override; void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override; int32_t threadCount() override;

View File

@ -80,8 +80,9 @@ bool LLamaModel::isModelLoaded() const
} }
void LLamaModel::prompt(const std::string &prompt, void LLamaModel::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response, std::function<bool(int32_t)> promptCallback,
std::function<bool(bool)> recalculate, std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx) { PromptContext &promptCtx) {
if (!isModelLoaded()) { if (!isModelLoaded()) {
@ -102,7 +103,7 @@ void LLamaModel::prompt(const std::string &prompt,
promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx); promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
response(-1, "The prompt size exceeds the context window size and cannot be processed."); responseCallback(-1, "The prompt size exceeds the context window size and cannot be processed.");
std::cerr << "LLAMA ERROR: The prompt is" << embd_inp.size() << std::cerr << "LLAMA ERROR: The prompt is" << embd_inp.size() <<
"tokens and the context window is" << promptCtx.n_ctx << "!\n"; "tokens and the context window is" << promptCtx.n_ctx << "!\n";
return; return;
@ -128,7 +129,7 @@ void LLamaModel::prompt(const std::string &prompt,
std::cerr << "LLAMA: reached the end of the context window so resizing\n"; std::cerr << "LLAMA: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size(); promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate); recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
} }
@ -137,10 +138,9 @@ void LLamaModel::prompt(const std::string &prompt,
return; return;
} }
// We pass a null string for each token to see if the user has asked us to stop...
size_t tokens = batch_end - i; size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) for (size_t t = 0; t < tokens; ++t)
if (!response(batch.at(t), "")) if (!promptCallback(batch.at(t)))
return; return;
promptCtx.n_past += batch.size(); promptCtx.n_past += batch.size();
i = batch_end; i = batch_end;
@ -162,8 +162,8 @@ void LLamaModel::prompt(const std::string &prompt,
std::cerr << "LLAMA: reached the end of the context window so resizing\n"; std::cerr << "LLAMA: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint); promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size(); promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate); recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx); assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
} }
if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) { if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
@ -174,7 +174,7 @@ void LLamaModel::prompt(const std::string &prompt,
promptCtx.n_past += 1; promptCtx.n_past += 1;
// display text // display text
++totalPredictions; ++totalPredictions;
if (id == llama_token_eos() || !response(id, llama_token_to_str(d_ptr->ctx, id))) if (id == llama_token_eos() || !responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
return; return;
} }
} }

View File

@ -16,8 +16,9 @@ public:
bool loadModel(const std::string &modelPath, std::istream &fin) override; bool loadModel(const std::string &modelPath, std::istream &fin) override;
bool isModelLoaded() const override; bool isModelLoaded() const override;
void prompt(const std::string &prompt, void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response, std::function<bool(int32_t)> promptCallback,
std::function<bool(bool)> recalculate, std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) override; PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override; void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override; int32_t threadCount() override;

View File

@ -29,8 +29,9 @@ public:
// window // window
}; };
virtual void prompt(const std::string &prompt, virtual void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response, std::function<bool(int32_t)> promptCallback,
std::function<bool(bool)> recalculate, std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) = 0; PromptContext &ctx) = 0;
virtual void setThreadCount(int32_t n_threads) {} virtual void setThreadCount(int32_t n_threads) {}
virtual int32_t threadCount() { return 1; } virtual int32_t threadCount() { return 1; }

View File

@ -49,6 +49,11 @@ bool llmodel_isModelLoaded(llmodel_model model)
} }
// Wrapper functions for the C callbacks // Wrapper functions for the C callbacks
bool prompt_wrapper(int32_t token_id, void *user_data) {
llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);
return callback(token_id);
}
bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) { bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) {
llmodel_response_callback callback = reinterpret_cast<llmodel_response_callback>(user_data); llmodel_response_callback callback = reinterpret_cast<llmodel_response_callback>(user_data);
return callback(token_id, response.c_str()); return callback(token_id, response.c_str());
@ -60,17 +65,20 @@ bool recalculate_wrapper(bool is_recalculating, void *user_data) {
} }
void llmodel_prompt(llmodel_model model, const char *prompt, void llmodel_prompt(llmodel_model model, const char *prompt,
llmodel_response_callback response, llmodel_response_callback prompt_callback,
llmodel_recalculate_callback recalculate, llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
llmodel_prompt_context *ctx) llmodel_prompt_context *ctx)
{ {
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model); LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
// Create std::function wrappers that call the C function pointers // Create std::function wrappers that call the C function pointers
std::function<bool(int32_t)> prompt_func =
std::bind(&prompt_wrapper, std::placeholders::_1, reinterpret_cast<void*>(prompt_callback));
std::function<bool(int32_t, const std::string&)> response_func = std::function<bool(int32_t, const std::string&)> response_func =
std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response)); std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response_callback));
std::function<bool(bool)> recalc_func = std::function<bool(bool)> recalc_func =
std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate)); std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate_callback));
// Copy the C prompt context // Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past; wrapper->promptContext.n_past = ctx->n_past;
@ -85,7 +93,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext.contextErase = ctx->context_erase; wrapper->promptContext.contextErase = ctx->context_erase;
// Call the C++ prompt method // Call the C++ prompt method
wrapper->llModel->prompt(prompt, response_func, recalc_func, wrapper->promptContext); wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext);
// Update the C context by giving access to the wrappers raw pointers to std::vector data // Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies // which involves no copies

View File

@ -37,10 +37,17 @@ typedef struct {
float context_erase; // percent of context to erase if we exceed the context window float context_erase; // percent of context to erase if we exceed the context window
} llmodel_prompt_context; } llmodel_prompt_context;
/**
* Callback type for prompt processing.
* @param token_id The token id of the prompt.
* @return a bool indicating whether the model should keep processing.
*/
typedef bool (*llmodel_prompt_callback)(int32_t token_id);
/** /**
* Callback type for response. * Callback type for response.
* @param token_id The token id of the response. * @param token_id The token id of the response.
* @param response The response string. * @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
* @return a bool indicating whether the model should keep generating. * @return a bool indicating whether the model should keep generating.
*/ */
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response); typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
@ -95,13 +102,15 @@ bool llmodel_isModelLoaded(llmodel_model model);
* Generate a response using the model. * Generate a response using the model.
* @param model A pointer to the llmodel_model instance. * @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt. * @param prompt A string representing the input prompt.
* @param response A callback function for handling the generated response. * @param prompt_callback A callback function for handling the processing of prompt.
* @param recalculate A callback function for handling recalculation requests. * @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param ctx A pointer to the llmodel_prompt_context structure. * @param ctx A pointer to the llmodel_prompt_context structure.
*/ */
void llmodel_prompt(llmodel_model model, const char *prompt, void llmodel_prompt(llmodel_model model, const char *prompt,
llmodel_response_callback response, llmodel_response_callback prompt_callback,
llmodel_recalculate_callback recalculate, llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
llmodel_prompt_context *ctx); llmodel_prompt_context *ctx);
/** /**