diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index da2bc2b4..6cc630a4 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -113,7 +113,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
                     llmodel_response_callback response_callback,
                     llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx,
-                    bool special)
+                    bool special,
+                    const char *fake_reply)
 {
     LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(model);
 
@@ -141,8 +142,13 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
     wrapper->promptContext.contextErase = ctx->context_erase;
 
+    std::string fake_reply_str;
+    if (fake_reply) { fake_reply_str = fake_reply; }
+    auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
+
     // Call the C++ prompt method
-    wrapper->llModel->prompt(prompt, prompt_template, prompt_func, response_func, recalc_func, wrapper->promptContext, special);
+    wrapper->llModel->prompt(prompt, prompt_template, prompt_func, response_func, recalc_func, wrapper->promptContext,
+                             special, fake_reply_p);
 
     // Update the C context by giving access to the wrappers raw pointers to std::vector data
     // which involves no copies
diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h
index a19bd083..3ae0db22 100644
--- a/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/llmodel_c.h
@@ -169,6 +169,7 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
  * @param response_callback A callback function for handling the generated response.
  * @param recalculate_callback A callback function for handling recalculation requests.
  * @param special True if special tokens in the prompt should be processed, false otherwise.
+ * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
  * @param ctx A pointer to the llmodel_prompt_context structure.
  */
 void llmodel_prompt(llmodel_model model, const char *prompt,
@@ -177,7 +178,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
                     llmodel_response_callback response_callback,
                     llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx,
-                    bool special);
+                    bool special,
+                    const char *fake_reply);
 
 /**
  * Generate an embedding using the model.
diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index fd9ff4b0..c623beda 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -100,6 +100,7 @@ llmodel.llmodel_prompt.argtypes = [
     RecalculateCallback,
     ctypes.POINTER(LLModelPromptContext),
     ctypes.c_bool,
+    ctypes.c_char_p,
 ]
 llmodel.llmodel_prompt.restype = None
 
@@ -361,6 +362,7 @@ class LLModel:
             RecalculateCallback(self._recalculate_callback),
             self.context,
             special,
+            ctypes.c_char_p(),
         )
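
Side note (not part of the patch): per the new doc comment, a NULL fake_reply keeps normal generation, while a non-NULL string is inserted into the context as if the model had produced it. The Python call site above still passes a bare ctypes.c_char_p() (NULL); a minimal sketch of how the binding might convert an optional reply before that call follows — the helper name _as_c_reply is hypothetical:

import ctypes

def _as_c_reply(fake_reply):
    # Hypothetical helper: a NULL pointer tells the backend to generate a
    # reply as usual; a non-NULL UTF-8 string is inserted into the context
    # as the model's reply (see the fake_reply doc comment above).
    if fake_reply is None:
        return ctypes.c_char_p()  # NULL pointer
    return ctypes.c_char_p(fake_reply.encode("utf-8"))

# Usage: would replace the literal ctypes.c_char_p() at the call site above.
#   _as_c_reply(None)              -> NULL, normal generation
#   _as_c_reply("Canned answer.")  -> reply injected without running inference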