llmodel_c: expose fakeReply to the bindings (#2061)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel authored on 2024-03-06 13:32:24 -05:00; committed by GitHub
parent be6d3bf9dc
commit c19b763e03
3 changed files with 13 additions and 3 deletions

gpt4all-backend/llmodel_c.cpp

@@ -113,7 +113,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
                     llmodel_response_callback response_callback,
                     llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx,
-                    bool special)
+                    bool special,
+                    const char *fake_reply)
 {
     LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
@@ -141,8 +142,13 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
     wrapper->promptContext.contextErase = ctx->context_erase;
 
+    std::string fake_reply_str;
+    if (fake_reply) { fake_reply_str = fake_reply; }
+    auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
+
     // Call the C++ prompt method
-    wrapper->llModel->prompt(prompt, prompt_template, prompt_func, response_func, recalc_func, wrapper->promptContext, special);
+    wrapper->llModel->prompt(prompt, prompt_template, prompt_func, response_func, recalc_func, wrapper->promptContext,
+                             special, fake_reply_p);
 
     // Update the C context by giving access to the wrappers raw pointers to std::vector data
     // which involves no copies
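The wrapper bridges the nullable C string onto the optional pointer the C++ prompt method expects: a null fake_reply stays null, anything else is copied into a std::string that outlives the call. A binding performs the mirror-image mapping; a minimal Python sketch (the helper name to_fake_reply is illustrative, not part of this commit):

import ctypes

def to_fake_reply(reply):
    """Map an optional Python str onto the new const char *fake_reply
    argument: None becomes NULL (the model generates a reply as before),
    a str is inserted into context verbatim as the model's reply."""
    return ctypes.c_char_p(None if reply is None else reply.encode("utf-8"))

assert to_fake_reply(None).value is None  # NULL pointer -> generate
assert to_fake_reply("4").value == b"4"   # canned reply, no generation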

gpt4all-backend/llmodel_c.h

@@ -169,6 +169,7 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
  * @param response_callback A callback function for handling the generated response.
  * @param recalculate_callback A callback function for handling recalculation requests.
  * @param special True if special tokens in the prompt should be processed, false otherwise.
+ * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
  * @param ctx A pointer to the llmodel_prompt_context structure.
  */
 void llmodel_prompt(llmodel_model model, const char *prompt,
@@ -177,7 +178,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
                     llmodel_response_callback response_callback,
                     llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx,
-                    bool special);
+                    bool special,
+                    const char *fake_reply);
 
 /**
  * Generate an embedding using the model.
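For callers, the only change is one trailing argument. A hedged ctypes sketch of both modes, assuming the shared library is already loaded as llmodel with argtypes set as in the bindings below, that model and ctx are a loaded model handle and prompt context, and that the callback signatures match the bindings' definitions:

import ctypes

# Callback types (assumed to match the bindings' CFUNCTYPE definitions).
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)

on_prompt = PromptCallback(lambda token_id: True)
on_response = ResponseCallback(lambda token_id, text: True)
on_recalc = RecalculateCallback(lambda is_recalculating: True)

# Insert a canned reply into context instead of generating one:
llmodel.llmodel_prompt(model, b"What is 2+2?", b"%1",
                       on_prompt, on_response, on_recalc,
                       ctx, False, b"4")

# Pass NULL to keep the old behavior (the model generates the reply):
llmodel.llmodel_prompt(model, b"And 3+3?", b"%1",
                       on_prompt, on_response, on_recalc,
                       ctx, False, None)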

Python bindings (gpt4all-bindings/python)

@@ -100,6 +100,7 @@ llmodel.llmodel_prompt.argtypes = [
     RecalculateCallback,
     ctypes.POINTER(LLModelPromptContext),
     ctypes.c_bool,
+    ctypes.c_char_p,
 ]
 
 llmodel.llmodel_prompt.restype = None
@@ -361,6 +362,7 @@ class LLModel:
             RecalculateCallback(self._recalculate_callback),
             self.context,
             special,
+            ctypes.c_char_p(),
         )
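Note that the binding passes ctypes.c_char_p(), a NULL pointer, so Python-side behavior is unchanged by this commit; the argument is merely plumbed through. What it presumably enables is rebuilding a conversation's context from a saved transcript without re-generating each reply. A self-contained sketch with illustrative names (replay_history and prompt_fn stand in for a binding method that forwards fake_reply to llmodel_prompt):

def replay_history(prompt_fn, history):
    """Prime the model's context from saved (user, assistant) turns:
    each past assistant reply is inserted verbatim rather than generated."""
    for user_text, assistant_text in history:
        prompt_fn(user_text, fake_reply=assistant_text)

# Usage with a stub in place of a real binding method:
log = []
replay_history(lambda text, fake_reply: log.append((text, fake_reply)),
               [("What is 2+2?", "4"), ("And 3+3?", "6")])
print(log)  # [('What is 2+2?', '4'), ('And 3+3?', '6')]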