diff --git a/gptj.cpp b/gptj.cpp
index 832d2652..fb882f95 100644
--- a/gptj.cpp
+++ b/gptj.cpp
@@ -687,14 +687,7 @@ void GPTJ::prompt(const std::string &prompt, std::function
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
 
     n_predict = std::min(n_predict, d_ptr->model.hparams.n_ctx - (int) embd_inp.size());
-    ctx.n_past = std::min(ctx.n_past, 1024);
-//    n_batch = embd_inp.size();
-
-    std::cout << "The past was: " << ctx.n_past;
-    fflush(stdout);
-
-    std::vector<gpt_vocab::id> embd;
-    std::vector<gpt_vocab::id> resp;
+    ctx.n_past = std::min(ctx.n_past, d_ptr->model.hparams.n_ctx);
 
     // determine the required inference memory per token:
     static bool initialized = false;
@@ -704,69 +697,50 @@ void GPTJ::prompt(const std::string &prompt, std::function
-    for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) {
-        // predict
-        if (embd.size() > 0) {
-            const int64_t t_start_us = ggml_time_us();
+    // process the prompt in batches
+    size_t i = 0;
+    const int64_t t_start_prompt_us = ggml_time_us();
+    while (i < embd_inp.size()) {
+        size_t batch_end = std::min(i + n_batch, embd_inp.size());
+        std::vector<gpt_vocab::id> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, batch, ctx.logits, mem_per_token)) {
+            std::cerr << "GPT-J ERROR: Failed to process prompt\n";
+            return;
+        }
+        ctx.n_past += batch.size();
+        i = batch_end;
+    }
+    t_prompt_us += ggml_time_us() - t_start_prompt_us;
 
-            if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, embd, ctx.logits, mem_per_token)) {
-                std::cerr << "GPT-J ERROR: Failed to predict\n";
-                return;
-            }
+    // predict next tokens
+    int32_t totalPredictions = 0;
+    for (int i = 0; i < n_predict; i++) {
 
-            t_predict_us += ggml_time_us() - t_start_us;
+        // sample next token
+        const int n_vocab = d_ptr->model.hparams.n_vocab;
+        gpt_vocab::id id = 0;
+        {
+            const int64_t t_start_sample_us = ggml_time_us();
+            id = gpt_sample_top_k_top_p(d_ptr->vocab, ctx.logits.data() + (ctx.logits.size() - n_vocab),
+                top_k, top_p, temp, d_ptr->rng);
+            t_sample_us += ggml_time_us() - t_start_sample_us;
         }
 
-        ctx.n_past += embd.size();
-        embd.clear();
-        resp.clear();
-
-        if (i >= embd_inp.size()) {
-            t_prompt_us += ggml_time_us() - t_main_start_us;
-
-            // sample next token
-
-            const int n_vocab = d_ptr->model.hparams.n_vocab;
-
-            gpt_vocab::id id = 0;
-
-            {
-                const int64_t t_start_sample_us = ggml_time_us();
-
-                id = gpt_sample_top_k_top_p(d_ptr->vocab, ctx.logits.data() + (ctx.logits.size() - n_vocab), top_k, top_p, temp, d_ptr->rng);
-
-                t_sample_us += ggml_time_us() - t_start_sample_us;
-            }
-
-            // add it to the context
-            embd.push_back(id);
-            if (id != 50256)
-                resp.push_back(id);
-        } else {
-            // if here, it means we are still processing the input prompt
-            for (int k = i; k < embd_inp.size(); k++) {
-                embd.push_back(embd_inp[k]);
-                if (embd.size() > n_batch) {
-                    break;
-                }
-            }
-            i += embd.size() - 1;
+        const int64_t t_start_predict_us = ggml_time_us();
+        if (!gptj_eval(d_ptr->model, d_ptr->n_threads, ctx.n_past, { id }, ctx.logits, mem_per_token)) {
+            std::cerr << "GPT-J ERROR: Failed to predict next token\n";
+            return;
         }
+        t_predict_us += ggml_time_us() - t_start_predict_us;
+        ctx.n_past += 1;
 
         // display text
-        for (auto id : resp) {
-            if (!response(d_ptr->vocab.id_to_token[id]))
-                goto stop_generating;
-        }
-
-        // end of text token
-        if (embd.back() == 50256) {
+        ++totalPredictions;
+        if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[id]))
             break;
-        }
     }
 
-stop_generating:
-#if 0
+#if 1
     // report timing
     {
         const int64_t t_main_end_us = ggml_time_us();
@@ -774,7 +748,7 @@ stop_generating:
         std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n";
         std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n";
         std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n";
-        std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/n_past << " ms per token\n";
+        std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n";
        std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n";
         fflush(stdout);
     }
diff --git a/main.qml b/main.qml
index f106efe5..ef126429 100644
--- a/main.qml
+++ b/main.qml
@@ -210,11 +210,9 @@ Window {
                 chatModel.append({"name": qsTr("Prompt: "), "currentResponse": false, "value": textInput.text})
                 chatModel.append({"name": qsTr("Response: "), "currentResponse": true, "value": "", "prompt": prompt})
-//                var contextPrompt = ""
-//                for (var i = 0; i < chatModel.count; ++i) {
-//                    var listElement = chatModel.get(i)
-//                    contextPrompt += listElement.value + "\n";
-//                }
+//                var contextPrompt;
+//                for (var i = 0; i < chatModel.count; ++i)
+//                    contextPrompt += chatModel.get(i).value + "\n";
 //                prompt = contextPrompt + textInput.text + "\n"
                 LLM.resetResponse()