Implement configurable context length (#1749)

This commit is contained in:
Jared Van Bortel
2023-12-16 17:58:15 -05:00
committed by GitHub
parent 7aa0f779de
commit d1c56b8b28
31 changed files with 291 additions and 135 deletions

View File

@@ -676,7 +676,8 @@ GPTJ::GPTJ()
d_ptr->modelLoaded = false;
}
size_t GPTJ::requiredMem(const std::string &modelPath) {
size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
gptj_model dummy_model;
gpt_vocab dummy_vocab;
size_t mem_req;
@@ -684,7 +685,8 @@ size_t GPTJ::requiredMem(const std::string &modelPath) {
return mem_req;
}
bool GPTJ::loadModel(const std::string &modelPath) {
bool GPTJ::loadModel(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
std::mt19937 rng(time(NULL));
d_ptr->rng = rng;