mirror of https://github.com/nomic-ai/gpt4all.git
Prelim support for past context.
parent 91a2602d93
commit 6ce4089c4f
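In short, GPTJ::prompt() now threads a PromptContext through the generation loop: the struct carries the logits vector and an n_past count of tokens from the earlier conversation, so a later prompt continues from where the previous one stopped instead of starting fresh. The eval scratch buffer in gptj_eval also grows from 256 MiB to 1 GiB.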
gptj.cpp (25 lines changed)
@@ -419,7 +419,7 @@ bool gptj_eval(
 
     const int d_key = n_embd/n_head;
 
-    static size_t buf_size = 256u*1024*1024;
+    static size_t buf_size = 1024u*1024*1024;
     static void * buf = malloc(buf_size);
 
     if (mem_per_token > 0 && mem_per_token*N > buf_size) {
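The hunk stops at the headroom check, so the diff does not show what happens when a batch outgrows the buffer. In the ggml example code this eval function follows, that check typically reallocates the scratch buffer with some slack; here is a self-contained sketch of that pattern (the ensure_eval_buffer helper is invented for illustration and is not in gptj.cpp, and the body is an assumption about unchanged surrounding code, not part of this diff):

    #include <cstdio>
    #include <cstdlib>

    static size_t buf_size = 1024u*1024*1024;
    static void * buf = malloc(buf_size);

    // Hypothetical helper mirroring the buffer-growth guard in gptj_eval.
    static bool ensure_eval_buffer(size_t mem_per_token, int N) {
        if (mem_per_token > 0 && mem_per_token*N > buf_size) {
            // grow with ~10% slack to cover ggml object overhead
            const size_t buf_size_new = 1.1*(mem_per_token*N);
            buf_size = buf_size_new;
            buf = realloc(buf, buf_size);
            if (buf == nullptr) {
                fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
                return false;
            }
        }
        return true;
    }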
@@ -670,8 +670,7 @@ bool GPTJ::isModelLoaded() const
 }
 
 void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-    int32_t n_predict, int32_t top_k, float top_p, float temp,
-    int32_t n_batch) {
+    PromptContext &ctx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
 
     if (!isModelLoaded()) {
         std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n";
@@ -679,32 +678,38 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
     }
 
     const int64_t t_main_start_us = ggml_time_us();
-    int n_past = 0;
 
     int64_t t_sample_us = 0;
     int64_t t_predict_us = 0;
+    int64_t t_prompt_us = 0;
 
-    std::vector<float> logits;
 
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
 
     n_predict = std::min(n_predict, d_ptr->model.hparams.n_ctx - (int) embd_inp.size());
+    ctx.n_past = std::min(ctx.n_past, 1024);
+    // n_batch = embd_inp.size();
+
+    std::cout << "The past was: " << ctx.n_past;
+    fflush(stdout);
 
     std::vector<gpt_vocab::id> embd;
+    std::vector<gpt_vocab::id> resp;
 
     // determine the required inference memory per token:
+    static bool initialized = false;
     size_t mem_per_token = 0;
-    gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+    if (!initialized) {
+        gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, ctx.logits, mem_per_token);
+        initialized = true;
+    }
 
     for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!gptj_eval(d_ptr->model, d_ptr->n_threads, n_past, embd, logits, mem_per_token)) {
+            if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, embd, ctx.logits, mem_per_token)) {
                 std::cerr << "GPT-J ERROR: Failed to predict\n";
                 return;
             }
@@ -712,7 +717,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
             t_predict_us += ggml_time_us() - t_start_us;
         }
 
-        n_past += embd.size();
+        ctx.n_past += embd.size();
        embd.clear();
        resp.clear();
 
@@ -728,7 +733,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
        {
            const int64_t t_start_sample_us = ggml_time_us();
 
-            id = gpt_sample_top_k_top_p(d_ptr->vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, d_ptr->rng);
+            id = gpt_sample_top_k_top_p(d_ptr->vocab, ctx.logits.data() + (ctx.logits.size() - n_vocab), top_k, top_p, temp, d_ptr->rng);
 
            t_sample_us += ggml_time_us() - t_start_sample_us;
        }
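Taken together, these hunks make ctx.n_past the running token position: each batch that gptj_eval consumes advances it, and because the caller keeps the same PromptContext, the next call to prompt() resumes at that position rather than at zero. A toy illustration of that bookkeeping, using a stand-in struct rather than the real GPTJ class:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Stand-in for GPTJ::PromptContext from gptj.h.
    struct PromptContext {
        std::vector<float> logits;
        int32_t n_past = 0; // number of tokens in past conversation
    };

    int main() {
        PromptContext ctx;

        // First prompt: five tokens are evaluated starting at position 0.
        std::vector<int32_t> first_batch{12, 7, 99, 3, 41};
        ctx.n_past += first_batch.size();

        // Second prompt: only the two new tokens need to be fed to the model;
        // the first five positions are assumed to still sit in its KV cache.
        std::vector<int32_t> second_batch{8, 15};
        ctx.n_past += second_batch.size();

        std::printf("n_past after two prompts: %d\n", ctx.n_past); // prints 7
        return 0;
    }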
gptj.h (8 lines changed)
@@ -13,9 +13,13 @@ public:
 
     bool loadModel(const std::string &modelPath, std::istream &fin);
     bool isModelLoaded() const;
+    struct PromptContext {
+        std::vector<float> logits;
+        int32_t n_past = 0; // number of tokens in past conversation
+    };
     void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f, float temp = 0.9f,
-        int32_t n_batch = 9);
+        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
+        float temp = 0.9f, int32_t n_batch = 9);
 
 private:
     GPTJPrivate *d_ptr;
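From the caller's side, the new parameter means the conversation state lives in whatever PromptContext object the caller keeps around. A minimal usage sketch against this header (the model filename and prompts are invented, and it assumes GPTJ is default-constructible):

    #include <fstream>
    #include <iostream>
    #include <string>
    #include "gptj.h"

    int main() {
        GPTJ gptj;
        std::ifstream fin("ggml-gpt4all-j.bin", std::ios::binary); // hypothetical model file
        if (!gptj.loadModel("ggml-gpt4all-j.bin", fin) || !gptj.isModelLoaded())
            return 1;

        auto print_token = [](const std::string &token) {
            std::cout << token << std::flush;
            return true; // keep going; the bool return presumably lets the caller cut generation short
        };

        GPTJ::PromptContext ctx; // holds logits and n_past across prompts
        gptj.prompt("Hello, who are you?", print_token, ctx);
        gptj.prompt("What did I just ask you?", print_token, ctx); // continues from the past context
        return 0;
    }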
llm.cpp (3 lines changed)
@@ -75,7 +75,8 @@ bool GPTJObject::prompt(const QString &prompt)
     m_stopGenerating = false;
     auto func = std::bind(&GPTJObject::handleResponse, this, std::placeholders::_1);
     emit responseStarted();
-    m_gptj->prompt(prompt.toStdString(), func);
+    static GPTJ::PromptContext ctx;
+    m_gptj->prompt(prompt.toStdString(), func, ctx, 4096 /*number of chars to predict*/);
     emit responseStopped();
     return true;
 }
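Two details worth noting: ctx is function-static, so a single shared PromptContext (one running conversation) persists for the lifetime of the process and is reused by every call to GPTJObject::prompt; and the 4096 passed as n_predict raises the response cap well above the header's default of 200 (the inline comment says chars, though gptj.cpp treats the parameter as a token count).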