Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-08-09 20:07:19 +00:00)
Move the promptCallback to own function.

parent 0e9f85bcda
commit ba4b28fcd5
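In short, this commit splits the backend's single token callback into three separate callbacks: a prompt callback invoked once per prompt token, a response callback invoked once per generated token, and a recalculate callback invoked while the context window is rebuilt. Reconstructed from the abstract LLModel hunk further below, the interface changes as follows (this is a summary of the diff, not additional code):

    // Before this commit:
    virtual void prompt(const std::string &prompt,
            std::function<bool(int32_t, const std::string&)> response,
            std::function<bool(bool)> recalculate,
            PromptContext &ctx) = 0;

    // After this commit:
    virtual void prompt(const std::string &prompt,
            std::function<bool(int32_t)> promptCallback,
            std::function<bool(int32_t, const std::string&)> responseCallback,
            std::function<bool(bool)> recalculateCallback,
            PromptContext &ctx) = 0;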

llm.cpp (36 changed lines)
@@ -38,7 +38,7 @@ static QString modelFilePath(const QString &modelName)
 LLMObject::LLMObject()
     : QObject{nullptr}
     , m_llmodel(nullptr)
-    , m_responseTokens(0)
+    , m_promptResponseTokens(0)
     , m_responseLogits(0)
     , m_isRecalc(false)
 {
@@ -133,12 +133,12 @@ bool LLMObject::isModelLoaded() const
 
 void LLMObject::regenerateResponse()
 {
-    s_ctx.n_past -= m_responseTokens;
+    s_ctx.n_past -= m_promptResponseTokens;
     s_ctx.n_past = std::max(0, s_ctx.n_past);
     // FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove?
     s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
-    s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end());
-    m_responseTokens = 0;
+    s_ctx.tokens.erase(s_ctx.tokens.end() -= m_promptResponseTokens, s_ctx.tokens.end());
+    m_promptResponseTokens = 0;
     m_responseLogits = 0;
     m_response = std::string();
     emit responseChanged();
@@ -146,7 +146,7 @@ void LLMObject::regenerateResponse()
 
 void LLMObject::resetResponse()
 {
-    m_responseTokens = 0;
+    m_promptResponseTokens = 0;
     m_responseLogits = 0;
     m_response = std::string();
     emit responseChanged();
@@ -263,6 +263,18 @@ QList<QString> LLMObject::modelList() const
     return list;
 }
 
+bool LLMObject::handlePrompt(int32_t token)
+{
+    if (s_ctx.tokens.size() == s_ctx.n_ctx)
+        s_ctx.tokens.erase(s_ctx.tokens.begin());
+    s_ctx.tokens.push_back(token);
+
+    // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
+    // the entire context window which we can reset on regenerate prompt
+    ++m_promptResponseTokens;
+    return !m_stopGenerating;
+}
+
 bool LLMObject::handleResponse(int32_t token, const std::string &response)
 {
 #if 0
@@ -282,13 +294,12 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
         s_ctx.tokens.erase(s_ctx.tokens.begin());
     s_ctx.tokens.push_back(token);
 
-    // m_responseTokens and m_responseLogits are related to last prompt/response not
+    // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
-    ++m_responseTokens;
-    if (!response.empty()) {
-        m_response.append(response);
-        emit responseChanged();
-    }
+    ++m_promptResponseTokens;
+    Q_ASSERT(!response.empty());
+    m_response.append(response);
+    emit responseChanged();
 
     // Stop generation if we encounter prompt or response tokens
     QString r = QString::fromStdString(m_response);
@@ -315,6 +326,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     QString instructPrompt = prompt_template.arg(prompt);
 
     m_stopGenerating = false;
+    auto promptFunc = std::bind(&LLMObject::handlePrompt, this, std::placeholders::_1);
     auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
         std::placeholders::_2);
     auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
@@ -327,7 +339,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     s_ctx.n_batch = n_batch;
     s_ctx.repeat_penalty = repeat_penalty;
     s_ctx.repeat_last_n = repeat_penalty_tokens;
-    m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx);
+    m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, s_ctx);
     m_responseLogits += s_ctx.logits.size() - logitsBefore;
     std::string trimmed = trim_whitespace(m_response);
     if (trimmed != m_response) {
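For reference, the promptFunc/responseFunc/recalcFunc bindings added above are equivalent to the following lambdas inside LLMObject::prompt(); this is an illustrative fragment, not code from the commit:

    // Equivalent to the std::bind expressions added in LLMObject::prompt() above.
    auto promptFunc   = [this](int32_t token) { return handlePrompt(token); };
    auto responseFunc = [this](int32_t token, const std::string &response) {
        return handleResponse(token, response);
    };
    auto recalcFunc   = [this](bool isRecalc) { return handleRecalculate(isRecalc); };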

llm.h (3 changed lines)
@@ -58,13 +58,14 @@ Q_SIGNALS:
 private:
     void resetContextPrivate();
     bool loadModelPrivate(const QString &modelName);
+    bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
     bool handleRecalculate(bool isRecalc);
 
 private:
     LLModel *m_llmodel;
     std::string m_response;
-    quint32 m_responseTokens;
+    quint32 m_promptResponseTokens;
     quint32 m_responseLogits;
     QString m_modelName;
     QThread m_llmThread;

@@ -686,8 +686,9 @@ bool GPTJ::isModelLoaded() const
 }
 
 void GPTJ::prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
@@ -708,7 +709,7 @@ void GPTJ::prompt(const std::string &prompt,
     promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
 
     if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
-        response(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
+        responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
         std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() <<
             "tokens and the context window is" << promptCtx.n_ctx << "!\n";
         return;
@@ -741,7 +742,7 @@ void GPTJ::prompt(const std::string &prompt,
             std::cerr << "GPTJ: reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
-            recalculateContext(promptCtx, recalculate);
+            recalculateContext(promptCtx, recalculateCallback);
             assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
@@ -750,10 +751,10 @@ void GPTJ::prompt(const std::string &prompt,
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
-        // We pass a null string for each token to see if the user has asked us to stop...
+
         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t)
-            if (!response(batch.at(t), ""))
+            if (!promptCallback(batch.at(t)))
                 return;
         promptCtx.n_past += batch.size();
         i = batch_end;
@@ -790,8 +791,8 @@ void GPTJ::prompt(const std::string &prompt,
             std::cerr << "GPTJ: reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
-            recalculateContext(promptCtx, recalculate);
-            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
+            recalculateContext(promptCtx, recalculateCallback);
+            assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
         const int64_t t_start_predict_us = ggml_time_us();
@@ -805,7 +806,7 @@ void GPTJ::prompt(const std::string &prompt,
         promptCtx.n_past += 1;
         // display text
         ++totalPredictions;
-        if (id == 50256 /*end of text*/ || !response(id, d_ptr->vocab.id_to_token[id]))
+        if (id == 50256 /*end of text*/ || !responseCallback(id, d_ptr->vocab.id_to_token[id]))
             goto stop_generating;
     }
 
@@ -16,8 +16,9 @@ public:
     bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
     void prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;

@@ -80,8 +80,9 @@ bool LLamaModel::isModelLoaded() const
 }
 
 void LLamaModel::prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
@@ -102,7 +103,7 @@ void LLamaModel::prompt(const std::string &prompt,
     promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
 
     if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
-        response(-1, "The prompt size exceeds the context window size and cannot be processed.");
+        responseCallback(-1, "The prompt size exceeds the context window size and cannot be processed.");
         std::cerr << "LLAMA ERROR: The prompt is" << embd_inp.size() <<
             "tokens and the context window is" << promptCtx.n_ctx << "!\n";
         return;
@@ -128,7 +129,7 @@ void LLamaModel::prompt(const std::string &prompt,
             std::cerr << "LLAMA: reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
-            recalculateContext(promptCtx, recalculate);
+            recalculateContext(promptCtx, recalculateCallback);
             assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
@@ -137,10 +138,9 @@ void LLamaModel::prompt(const std::string &prompt,
             return;
         }
 
-        // We pass a null string for each token to see if the user has asked us to stop...
         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t)
-            if (!response(batch.at(t), ""))
+            if (!promptCallback(batch.at(t)))
                 return;
         promptCtx.n_past += batch.size();
         i = batch_end;
@@ -162,8 +162,8 @@ void LLamaModel::prompt(const std::string &prompt,
             std::cerr << "LLAMA: reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
-            recalculateContext(promptCtx, recalculate);
-            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
+            recalculateContext(promptCtx, recalculateCallback);
+            assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
         if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
@@ -174,7 +174,7 @@ void LLamaModel::prompt(const std::string &prompt,
         promptCtx.n_past += 1;
         // display text
         ++totalPredictions;
-        if (id == llama_token_eos() || !response(id, llama_token_to_str(d_ptr->ctx, id)))
+        if (id == llama_token_eos() || !responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
             return;
     }
 }
@@ -16,8 +16,9 @@ public:
     bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
    void prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;
@@ -29,8 +29,9 @@ public:
         // window
     };
     virtual void prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) = 0;
     virtual void setThreadCount(int32_t n_threads) {}
     virtual int32_t threadCount() { return 1; }
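With the updated pure virtual interface above, a backend caller supplies three callables instead of two. Below is a minimal sketch of driving an LLModel backend with lambdas; the helper name, header name, and output handling are assumptions for illustration, and only the prompt() signature comes from the diff:

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include "llmodel.h"   // assumed header name for the LLModel interface shown above

    // Hypothetical helper, not from the repository.
    void runPrompt(LLModel &model, const std::string &prompt, LLModel::PromptContext &ctx)
    {
        model.prompt(prompt,
            // promptCallback: called for each prompt token; return false to stop early.
            [](int32_t /*token*/) { return true; },
            // responseCallback: called with each generated token id and its text.
            [](int32_t /*token*/, const std::string &piece) {
                std::cout << piece << std::flush;
                return true;   // keep generating
            },
            // recalculateCallback: called while the context window is recalculated.
            [](bool /*isRecalculating*/) { return true; },
            ctx);
    }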

@@ -49,6 +49,11 @@ bool llmodel_isModelLoaded(llmodel_model model)
 }
 
 // Wrapper functions for the C callbacks
+bool prompt_wrapper(int32_t token_id, void *user_data) {
+    llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);
+    return callback(token_id);
+}
+
 bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) {
     llmodel_response_callback callback = reinterpret_cast<llmodel_response_callback>(user_data);
     return callback(token_id, response.c_str());
@@ -60,17 +65,20 @@ bool recalculate_wrapper(bool is_recalculating, void *user_data) {
 }
 
 void llmodel_prompt(llmodel_model model, const char *prompt,
-                    llmodel_response_callback response,
-                    llmodel_recalculate_callback recalculate,
+                    llmodel_response_callback prompt_callback,
+                    llmodel_response_callback response_callback,
+                    llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx)
 {
     LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
 
     // Create std::function wrappers that call the C function pointers
+    std::function<bool(int32_t)> prompt_func =
+        std::bind(&prompt_wrapper, std::placeholders::_1, reinterpret_cast<void*>(prompt_callback));
     std::function<bool(int32_t, const std::string&)> response_func =
-        std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response));
+        std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response_callback));
     std::function<bool(bool)> recalc_func =
-        std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate));
+        std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate_callback));
 
     // Copy the C prompt context
     wrapper->promptContext.n_past = ctx->n_past;
@@ -85,7 +93,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     wrapper->promptContext.contextErase = ctx->context_erase;
 
     // Call the C++ prompt method
-    wrapper->llModel->prompt(prompt, response_func, recalc_func, wrapper->promptContext);
+    wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext);
 
     // Update the C context by giving access to the wrappers raw pointers to std::vector data
     // which involves no copies
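The C bindings above bridge plain C function pointers to std::function by passing the pointer through a void *user_data argument and std::bind. Here is a self-contained sketch of the same pattern, with illustrative names that are not from the repository:

    #include <cstdint>
    #include <functional>
    #include <iostream>

    // C-style callback type, as a C header would declare it.
    typedef bool (*c_token_callback)(int32_t token_id);

    // Wrapper with an extra user_data slot so the C pointer can ride along.
    static bool token_wrapper(int32_t token_id, void *user_data) {
        auto callback = reinterpret_cast<c_token_callback>(user_data);
        return callback(token_id);
    }

    // A plain callback implementation a client might supply.
    static bool print_token(int32_t token_id) {
        std::cout << "token " << token_id << '\n';
        return true;
    }

    int main() {
        // Bind the wrapper and the raw pointer into a std::function, mirroring what the
        // prompt_wrapper/response_wrapper/recalculate_wrapper bindings above do.
        std::function<bool(int32_t)> func =
            std::bind(&token_wrapper, std::placeholders::_1,
                      reinterpret_cast<void*>(&print_token));
        return func(42) ? 0 : 1;
    }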

@@ -37,10 +37,17 @@ typedef struct {
     float context_erase; // percent of context to erase if we exceed the context window
 } llmodel_prompt_context;
 
+/**
+ * Callback type for prompt processing.
+ * @param token_id The token id of the prompt.
+ * @return a bool indicating whether the model should keep processing.
+ */
+typedef bool (*llmodel_prompt_callback)(int32_t token_id);
+
 /**
  * Callback type for response.
  * @param token_id The token id of the response.
- * @param response The response string.
+ * @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
  * @return a bool indicating whether the model should keep generating.
  */
 typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
@@ -95,13 +102,15 @@ bool llmodel_isModelLoaded(llmodel_model model);
  * Generate a response using the model.
  * @param model A pointer to the llmodel_model instance.
  * @param prompt A string representing the input prompt.
- * @param response A callback function for handling the generated response.
- * @param recalculate A callback function for handling recalculation requests.
+ * @param prompt_callback A callback function for handling the processing of prompt.
+ * @param response_callback A callback function for handling the generated response.
+ * @param recalculate_callback A callback function for handling recalculation requests.
 * @param ctx A pointer to the llmodel_prompt_context structure.
  */
 void llmodel_prompt(llmodel_model model, const char *prompt,
-                    llmodel_response_callback response,
-                    llmodel_recalculate_callback recalculate,
+                    llmodel_response_callback prompt_callback,
+                    llmodel_response_callback response_callback,
+                    llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx);
 
 /**
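Given the callback types documented above, a C client's callbacks might look like the sketch below. The function names are illustrative, and the bool(bool) shape of llmodel_recalculate_callback is inferred from the recalculate_wrapper shown earlier rather than stated in this diff:

    #include <cstdint>
    #include <cstdio>

    // Matches llmodel_prompt_callback: return true to keep processing prompt tokens.
    static bool on_prompt_token(int32_t token_id) {
        (void)token_id;      // prompt tokens are typically just counted
        return true;
    }

    // Matches llmodel_response_callback: per the header note, token_id == -1 marks an error string.
    static bool on_response_token(int32_t token_id, const char *response) {
        if (token_id == -1) {
            std::fprintf(stderr, "error: %s\n", response);
            return false;    // stop generating on error
        }
        std::fputs(response, stdout);
        return true;         // keep generating
    }

    // Assumed to match llmodel_recalculate_callback.
    static bool on_recalculate(bool is_recalculating) {
        (void)is_recalculating;
        return true;
    }

Note that in this revision the first callback parameter of llmodel_prompt is still declared as llmodel_response_callback, even though it is named prompt_callback and documented as the prompt callback, so a caller built against exactly this commit may need a cast when passing a function like on_prompt_token.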