python: embedding cancel callback for nomic client dynamic mode (#2214)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Authored by Jared Van Bortel on 2024-04-12 16:00:39 -04:00; committed by GitHub.
parent 459289b94c
commit 46818e466e
11 changed files with 95 additions and 28 deletions

View File

@@ -158,7 +158,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
 struct LLamaPrivate {
     const std::string modelPath;
-    bool modelLoaded;
+    bool modelLoaded = false;
     int device = -1;
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
@@ -166,12 +166,11 @@ struct LLamaPrivate {
     llama_context_params ctx_params;
     int64_t n_threads = 0;
     std::vector<LLModel::Token> end_tokens;
+    const char *backend_name = nullptr;
 };
 
 LLamaModel::LLamaModel()
-    : d_ptr(new LLamaPrivate) {
-    d_ptr->modelLoaded = false;
-}
+    : d_ptr(new LLamaPrivate) {}
 
 // default hparams (LLaMA 7B)
 struct llama_file_hparams {
@@ -291,6 +290,8 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     d_ptr->model_params.progress_callback = &LLModel::staticProgressCallback;
     d_ptr->model_params.progress_callback_user_data = this;
+    d_ptr->backend_name = "cpu"; // default
 
 #ifdef GGML_USE_KOMPUTE
     if (d_ptr->device != -1) {
         d_ptr->model_params.main_gpu = d_ptr->device;
@@ -301,6 +302,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     if (llama_verbose()) {
         std::cerr << "llama.cpp: using Metal" << std::endl;
+        d_ptr->backend_name = "metal";
     }
 
     // always fully offload on Metal
@@ -364,6 +366,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
 #ifdef GGML_USE_KOMPUTE
     if (usingGPUDevice() && ggml_vk_has_device()) {
         std::cerr << "llama.cpp: using Vulkan on " << ggml_vk_current_device().name << std::endl;
+        d_ptr->backend_name = "kompute";
     }
 #endif
@@ -674,7 +677,7 @@ void LLamaModel::embed(
 void LLamaModel::embed(
     const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
-    size_t *tokenCount, bool doMean, bool atlas
+    size_t *tokenCount, bool doMean, bool atlas, LLModel::EmbedCancelCallback *cancelCb
 ) {
     if (!d_ptr->model)
         throw std::logic_error("no model is loaded");
@@ -712,7 +715,7 @@ void LLamaModel::embed(
         throw std::invalid_argument(ss.str());
     }
-    embedInternal(texts, embeddings, *prefix, dimensionality, tokenCount, doMean, atlas, spec);
+    embedInternal(texts, embeddings, *prefix, dimensionality, tokenCount, doMean, atlas, cancelCb, spec);
 }
 
 // MD5 hash of "nomic empty"
@@ -730,7 +733,7 @@ double getL2NormScale(T *start, T *end) {
 void LLamaModel::embedInternal(
     const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
-    size_t *tokenCount, bool doMean, bool atlas, const EmbModelSpec *spec
+    size_t *tokenCount, bool doMean, bool atlas, LLModel::EmbedCancelCallback *cancelCb, const EmbModelSpec *spec
 ) {
     typedef std::vector<LLModel::Token> TokenString;
     static constexpr int32_t atlasMaxLength = 8192;
@@ -822,6 +825,23 @@ void LLamaModel::embedInternal(
     }
     inputs.clear();
 
+    if (cancelCb) {
+        // copy of batching code below, but just count tokens instead of running inference
+        unsigned nBatchTokens = 0;
+        std::vector<unsigned> batchSizes;
+        for (const auto &inp: batches) {
+            if (nBatchTokens + inp.batch.size() > n_batch) {
+                batchSizes.push_back(nBatchTokens);
+                nBatchTokens = 0;
+            }
+            nBatchTokens += inp.batch.size();
+        }
+        batchSizes.push_back(nBatchTokens);
+        if (cancelCb(batchSizes.data(), batchSizes.size(), d_ptr->backend_name)) {
+            throw std::runtime_error("operation was canceled");
+        }
+    }
+
     // initialize batch
     struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
@@ -871,7 +891,7 @@ void LLamaModel::embedInternal(
     };
 
     // break into batches
-    for (auto &inp: batches) {
+    for (const auto &inp: batches) {
         // encode if at capacity
         if (batch.n_tokens + inp.batch.size() > n_batch) {
             decode();
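For orientation, the token accounting that the new cancelCb block performs before any inference runs can be sketched in Python. This is an illustration only (the helper name and the example token counts are made up); the batch sizes it returns are what the cancel callback receives, computed with the same greedy rule as the batching loop above.

```python
# Illustrative sketch (not part of this diff): how the batch sizes handed to the
# cancel callback are derived from per-chunk token counts and n_batch.
def count_batch_sizes(token_counts: list[int], n_batch: int) -> list[int]:
    batch_sizes: list[int] = []
    n_batch_tokens = 0
    for count in token_counts:
        # start a new batch once the current one would overflow n_batch
        if n_batch_tokens + count > n_batch:
            batch_sizes.append(n_batch_tokens)
            n_batch_tokens = 0
        n_batch_tokens += count
    batch_sizes.append(n_batch_tokens)
    return batch_sizes

# e.g. chunks of 100, 900, and 300 tokens with n_batch=1024 -> [1000, 300]
print(count_batch_sizes([100, 900, 300], 1024))
```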

View File

@@ -39,7 +39,8 @@ public:
     size_t embeddingSize() const override;
     // user-specified prefix
     void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
-               int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
+               int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false,
+               EmbedCancelCallback *cancelCb = nullptr) override;
     // automatic prefix
     void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
                size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
@@ -61,7 +62,8 @@ protected:
     int32_t layerCount(std::string const &modelPath) const override;
     void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
-                       size_t *tokenCount, bool doMean, bool atlas, const EmbModelSpec *spec);
+                       size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
+                       const EmbModelSpec *spec);
 };
 
 #endif // LLAMAMODEL_H

View File

@@ -105,12 +105,15 @@ public:
                         bool special = false,
                         std::string *fakeReply = nullptr);
 
+    using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
+
     virtual size_t embeddingSize() const {
         throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
     }
     // user-specified prefix
     virtual void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
-                       int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false);
+                       int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false,
+                       EmbedCancelCallback *cancelCb = nullptr);
     // automatic prefix
     virtual void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval,
                        int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false);

View File

@@ -159,7 +159,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
 float *llmodel_embed(
     llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix, int dimensionality,
-    size_t *token_count, bool do_mean, bool atlas, const char **error
+    size_t *token_count, bool do_mean, bool atlas, llmodel_emb_cancel_callback cancel_cb, const char **error
 ) {
     auto *wrapper = static_cast<LLModelWrapper *>(model);
@@ -185,7 +185,7 @@ float *llmodel_embed(
         if (prefix) { prefixStr = prefix; }
 
         embedding = new float[embd_size];
-        wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, token_count, do_mean, atlas);
+        wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, token_count, do_mean, atlas, cancel_cb);
     } catch (std::exception const &e) {
         llmodel_set_error(error, e.what());
         return nullptr;

View File

@@ -82,6 +82,15 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
  */
 typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);
 
+/**
+ * Embedding cancellation callback for use with llmodel_embed.
+ * @param batch_sizes The number of tokens in each batch that will be embedded.
+ * @param n_batch The number of batches that will be embedded.
+ * @param backend The backend that will be used for embedding. One of "cpu", "kompute", or "metal".
+ * @return True to cancel llmodel_embed, false to continue.
+ */
+typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend);
+
 /**
  * Create a llmodel instance.
  * Recognises correct model type from file at model_path
@@ -198,12 +207,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
  *             truncate.
  * @param atlas Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens with
  *              long_text_mode="mean" will raise an error. Disabled by default.
+ * @param cancel_cb Cancellation callback, or NULL. See the documentation of llmodel_emb_cancel_callback.
  * @param error Return location for a malloc()ed string that will be set on error, or NULL.
  * @return A pointer to an array of floating point values passed to the calling method which then will
  *         be responsible for lifetime of this memory. NULL if an error occurred.
  */
 float *llmodel_embed(llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix,
-                     int dimensionality, size_t *token_count, bool do_mean, bool atlas, const char **error);
+                     int dimensionality, size_t *token_count, bool do_mean, bool atlas,
+                     llmodel_emb_cancel_callback cancel_cb, const char **error);
 
 /**
  * Frees the memory allocated by the llmodel_embedding function.
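As a rough ctypes-level illustration of the new C contract (not part of this diff; it mirrors what the Python bindings below register, and the cancellation policy shown is made up), a callback compatible with llmodel_emb_cancel_callback could look like this:

```python
import ctypes

# ctypes equivalent of llmodel_emb_cancel_callback:
#     bool (*)(unsigned *batch_sizes, unsigned n_batch, const char *backend)
EmbCancelCallback = ctypes.CFUNCTYPE(
    ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p,
)

@EmbCancelCallback
def cancel_cb(batch_sizes, n_batch, backend):
    sizes = batch_sizes[:n_batch]  # copy the C array into a Python list
    # illustrative policy: cancel a large embedding job that would run on the CPU backend
    return backend == b"cpu" and sum(sizes) > 16384

# cancel_cb would then be passed as the cancel_cb argument of llmodel_embed,
# where returning True makes the call fail with "operation was canceled".
```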

View File

@@ -270,7 +270,7 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
 void LLModel::embed(
     const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
-    size_t *tokenCount, bool doMean, bool atlas
+    size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb
 ) {
     (void)texts;
     (void)embeddings;
@@ -279,6 +279,7 @@ void LLModel::embed(
     (void)tokenCount;
     (void)doMean;
     (void)atlas;
+    (void)cancelCb;
     throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
 }

View File

@@ -1 +1 @@
-from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
+from .gpt4all import CancellationError as CancellationError, Embed4All as Embed4All, GPT4All as GPT4All

View File

@@ -9,7 +9,7 @@ import sys
 import threading
 from enum import Enum
 from queue import Queue
-from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
@@ -22,6 +22,9 @@ if (3, 9) <= sys.version_info < (3, 11):
 else:
     from typing import TypedDict
 
+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
 EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
@@ -95,6 +98,7 @@ llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
 PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
 ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
 RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
+EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
 
 llmodel.llmodel_prompt.argtypes = [
     ctypes.c_void_p,
@@ -119,6 +123,7 @@ llmodel.llmodel_embed.argtypes = [
     ctypes.POINTER(ctypes.c_size_t),
     ctypes.c_bool,
     ctypes.c_bool,
+    EmbCancelCallback,
     ctypes.POINTER(ctypes.c_char_p),
 ]
@@ -155,6 +160,7 @@ llmodel.llmodel_has_gpu_device.restype = ctypes.c_bool
 ResponseCallbackType = Callable[[int, str], bool]
 RawResponseCallbackType = Callable[[int, bytes], bool]
+EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
 
 def empty_response_callback(token_id: int, response: str) -> bool:
@@ -171,6 +177,10 @@ class EmbedResult(Generic[EmbeddingsType], TypedDict):
     n_prompt_tokens: int
 
+class CancellationError(Exception):
+    """raised when embedding is canceled"""
+
 class LLModel:
     """
     Base class and universal wrapper for GPT4All language models
@@ -323,19 +333,22 @@ class LLModel:
     @overload
     def generate_embeddings(
-        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
+        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool, cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[float]]: ...
     @overload
     def generate_embeddings(
         self, text: list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
+        cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[list[float]]]: ...
     @overload
     def generate_embeddings(
         self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
+        cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[Any]]: ...
 
     def generate_embeddings(
         self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
+        cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[Any]]:
         if not text:
             raise ValueError("text must not be None or empty")
@@ -343,7 +356,7 @@ class LLModel:
         if self.model is None:
             self._raise_closed()
 
-        if (single_text := isinstance(text, str)):
+        if single_text := isinstance(text, str):
             text = [text]
 
         # prepare input
@@ -355,14 +368,22 @@ class LLModel:
         for i, t in enumerate(text):
             c_texts[i] = t.encode()
 
+        def wrap_cancel_cb(batch_sizes: ctypes.POINTER(ctypes.c_uint), n_batch: int, backend: bytes) -> bool:
+            assert cancel_cb is not None
+            return cancel_cb(batch_sizes[:n_batch], backend.decode())
+
+        cancel_cb_wrapper = EmbCancelCallback(0x0 if cancel_cb is None else wrap_cancel_cb)
+
         # generate the embeddings
         embedding_ptr = llmodel.llmodel_embed(
             self.model, c_texts, ctypes.byref(embedding_size), c_prefix, dimensionality, ctypes.byref(token_count),
-            do_mean, atlas, ctypes.byref(error),
+            do_mean, atlas, cancel_cb_wrapper, ctypes.byref(error),
         )
 
         if not embedding_ptr:
             msg = "(unknown error)" if error.value is None else error.value.decode()
+            if msg == "operation was canceled":
+                raise CancellationError(msg)
             raise RuntimeError(f'Failed to generate embeddings: {msg}')
 
         # extract output

View File

@@ -19,7 +19,8 @@ from requests.exceptions import ChunkedEncodingError
 from tqdm import tqdm
 from urllib3.exceptions import IncompleteRead, ProtocolError
 
-from ._pyllmodel import EmbedResult as EmbedResult, LLModel, ResponseCallbackType, empty_response_callback
+from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult,
+                         LLModel, ResponseCallbackType, empty_response_callback)
 
 if TYPE_CHECKING:
     from typing_extensions import Self, TypeAlias
@@ -72,34 +73,36 @@ class Embed4All:
     @overload
     def embed(
         self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[False] = ..., atlas: bool = ...,
+        return_dict: Literal[False] = ..., atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
    ) -> list[float]: ...
     @overload
     def embed(
         self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[False] = ..., atlas: bool = ...,
+        return_dict: Literal[False] = ..., atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> list[list[float]]: ...
     @overload
     def embed(
         self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
         long_text_mode: str = ..., return_dict: Literal[False] = ..., atlas: bool = ...,
+        cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> list[Any]: ...
 
     # return_dict=True
     @overload
     def embed(
         self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[True], atlas: bool = ...,
+        return_dict: Literal[True], atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> EmbedResult[list[float]]: ...
     @overload
     def embed(
         self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[True], atlas: bool = ...,
+        return_dict: Literal[True], atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> EmbedResult[list[list[float]]]: ...
     @overload
     def embed(
         self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
         long_text_mode: str = ..., return_dict: Literal[True], atlas: bool = ...,
+        cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> EmbedResult[list[Any]]: ...
 
     # return type unknown
@@ -107,11 +110,13 @@ class Embed4All:
     def embed(
         self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
         long_text_mode: str = ..., return_dict: bool = ..., atlas: bool = ...,
+        cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> Any: ...
 
     def embed(
         self, text: str | list[str], *, prefix: str | None = None, dimensionality: int | None = None,
         long_text_mode: str = "mean", return_dict: bool = False, atlas: bool = False,
+        cancel_cb: EmbCancelCallbackType | None = None,
     ) -> Any:
         """
         Generate one or more embeddings.
@@ -127,10 +132,14 @@ class Embed4All:
             return_dict: Return the result as a dict that includes the number of prompt tokens processed.
             atlas: Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens
                 with long_text_mode="mean" will raise an error. Disabled by default.
+            cancel_cb: Called with arguments (batch_sizes, backend_name). Return true to cancel embedding.
 
         Returns:
             With return_dict=False, an embedding or list of embeddings of your text(s).
             With return_dict=True, a dict with keys 'embeddings' and 'n_prompt_tokens'.
+
+        Raises:
+            CancellationError: If cancel_cb returned True and embedding was canceled.
         """
         if dimensionality is None:
             dimensionality = -1
@@ -146,7 +155,7 @@ class Embed4All:
             do_mean = {"mean": True, "truncate": False}[long_text_mode]
         except KeyError:
             raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
-        result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas)
+        result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb)
         return result if return_dict else result['embeddings']
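A minimal usage sketch of the new parameter (the cancellation policy and the 16384-token threshold are illustrative, not part of this commit): the callback receives the planned batch sizes and the backend name before any inference runs, and returning True makes embed() raise CancellationError.

```python
from gpt4all import CancellationError, Embed4All

def cancel_cb(batch_sizes: list[int], backend: str) -> bool:
    # Illustrative policy: skip local embedding if a large job would run on the CPU backend.
    return backend == "cpu" and sum(batch_sizes) > 16384

embedder = Embed4All()  # default embedding model
try:
    embeddings = embedder.embed(["first document", "second document"], cancel_cb=cancel_cb)
except CancellationError:
    embeddings = None  # caller decides what to do instead, e.g. use a remote embedding API
```

This is the hook that lets a caller such as the Nomic client's dynamic mode inspect the planned workload and back out of local inference before it starts.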

View File

@@ -68,7 +68,7 @@ def get_long_description():
 setup(
     name=package_name,
-    version="2.4.1",
+    version="2.5.0",
     description="Python bindings for GPT4All",
     long_description=get_long_description(),
     long_description_content_type="text/markdown",

View File

@@ -258,7 +258,7 @@ Napi::Value NodeModelWrapper::GenerateEmbedding(const Napi::CallbackInfo &info)
     const char *_err = nullptr;
     float *embeds = llmodel_embed(GetInference(), str_ptrs.data(), &embedding_size,
                                   prefix.IsUndefined() ? nullptr : prefix.As<Napi::String>().Utf8Value().c_str(),
-                                  dimensionality, &token_count, do_mean, atlas, &_err);
+                                  dimensionality, &token_count, do_mean, atlas, nullptr, &_err);
     if (!embeds)
     {
         // i dont wanna deal with c strings lol