python: embedding cancel callback for nomic client dynamic mode (#2214)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
@@ -158,7 +158,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
 struct LLamaPrivate {
     const std::string modelPath;
-    bool modelLoaded;
+    bool modelLoaded = false;
     int device = -1;
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
@@ -166,12 +166,11 @@ struct LLamaPrivate {
     llama_context_params ctx_params;
     int64_t n_threads = 0;
     std::vector<LLModel::Token> end_tokens;
+    const char *backend_name = nullptr;
 };
 
 LLamaModel::LLamaModel()
-    : d_ptr(new LLamaPrivate) {
-    d_ptr->modelLoaded = false;
-}
+    : d_ptr(new LLamaPrivate) {}
 
 // default hparams (LLaMA 7B)
 struct llama_file_hparams {
@@ -291,6 +290,8 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     d_ptr->model_params.progress_callback = &LLModel::staticProgressCallback;
     d_ptr->model_params.progress_callback_user_data = this;
 
+    d_ptr->backend_name = "cpu"; // default
+
 #ifdef GGML_USE_KOMPUTE
     if (d_ptr->device != -1) {
         d_ptr->model_params.main_gpu = d_ptr->device;
@@ -301,6 +302,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
 
     if (llama_verbose()) {
         std::cerr << "llama.cpp: using Metal" << std::endl;
+        d_ptr->backend_name = "metal";
     }
 
     // always fully offload on Metal
@@ -364,6 +366,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
 #ifdef GGML_USE_KOMPUTE
     if (usingGPUDevice() && ggml_vk_has_device()) {
         std::cerr << "llama.cpp: using Vulkan on " << ggml_vk_current_device().name << std::endl;
+        d_ptr->backend_name = "kompute";
     }
 #endif
 
@@ -674,7 +677,7 @@ void LLamaModel::embed(
 
 void LLamaModel::embed(
     const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
-    size_t *tokenCount, bool doMean, bool atlas
+    size_t *tokenCount, bool doMean, bool atlas, LLModel::EmbedCancelCallback *cancelCb
 ) {
     if (!d_ptr->model)
         throw std::logic_error("no model is loaded");
@@ -712,7 +715,7 @@ void LLamaModel::embed(
         throw std::invalid_argument(ss.str());
     }
 
-    embedInternal(texts, embeddings, *prefix, dimensionality, tokenCount, doMean, atlas, spec);
+    embedInternal(texts, embeddings, *prefix, dimensionality, tokenCount, doMean, atlas, cancelCb, spec);
 }
 
 // MD5 hash of "nomic empty"
@@ -730,7 +733,7 @@ double getL2NormScale(T *start, T *end) {
 
 void LLamaModel::embedInternal(
     const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
-    size_t *tokenCount, bool doMean, bool atlas, const EmbModelSpec *spec
+    size_t *tokenCount, bool doMean, bool atlas, LLModel::EmbedCancelCallback *cancelCb, const EmbModelSpec *spec
 ) {
     typedef std::vector<LLModel::Token> TokenString;
     static constexpr int32_t atlasMaxLength = 8192;
@@ -822,6 +825,23 @@ void LLamaModel::embedInternal(
     }
     inputs.clear();
 
+    if (cancelCb) {
+        // copy of batching code below, but just count tokens instead of running inference
+        unsigned nBatchTokens = 0;
+        std::vector<unsigned> batchSizes;
+        for (const auto &inp: batches) {
+            if (nBatchTokens + inp.batch.size() > n_batch) {
+                batchSizes.push_back(nBatchTokens);
+                nBatchTokens = 0;
+            }
+            nBatchTokens += inp.batch.size();
+        }
+        batchSizes.push_back(nBatchTokens);
+        if (cancelCb(batchSizes.data(), batchSizes.size(), d_ptr->backend_name)) {
+            throw std::runtime_error("operation was canceled");
+        }
+    }
+
     // initialize batch
     struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
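For intuition, the counting pass above can be sketched in Python. This is an illustration of the same accounting, not code from the commit; `batches` and `n_batch` are hypothetical stand-ins for the tokenized inputs and the batch capacity:

# Illustrative sketch of the batch-size accounting (hypothetical inputs).
def count_batch_sizes(batches: list[list[int]], n_batch: int) -> list[int]:
    batch_sizes: list[int] = []
    n_batch_tokens = 0
    for tokens in batches:
        # flush the current batch if this input would overflow it
        if n_batch_tokens + len(tokens) > n_batch:
            batch_sizes.append(n_batch_tokens)
            n_batch_tokens = 0
        n_batch_tokens += len(tokens)
    batch_sizes.append(n_batch_tokens)
    return batch_sizes

# Three inputs of 3, 4, and 2 tokens with a capacity of 6 yield batches of 3 and 6.
assert count_batch_sizes([[0] * 3, [0] * 4, [0] * 2], n_batch=6) == [3, 6]

These per-batch totals are exactly what the cancellation callback receives, before any inference runs.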
@@ -871,7 +891,7 @@
     };
 
     // break into batches
-    for (auto &inp: batches) {
+    for (const auto &inp: batches) {
         // encode if at capacity
         if (batch.n_tokens + inp.batch.size() > n_batch) {
             decode();
@@ -39,7 +39,8 @@ public:
     size_t embeddingSize() const override;
     // user-specified prefix
     void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
-               int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
+               int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false,
+               EmbedCancelCallback *cancelCb = nullptr) override;
     // automatic prefix
     void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
                size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
@@ -61,7 +62,8 @@ protected:
     int32_t layerCount(std::string const &modelPath) const override;
 
     void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
-                       size_t *tokenCount, bool doMean, bool atlas, const EmbModelSpec *spec);
+                       size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
+                       const EmbModelSpec *spec);
 };
 
 #endif // LLAMAMODEL_H
@@ -105,12 +105,15 @@ public:
                          bool special = false,
                          std::string *fakeReply = nullptr);
 
+    using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
+
     virtual size_t embeddingSize() const {
         throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
     }
     // user-specified prefix
     virtual void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
-                       int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false);
+                       int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false,
+                       EmbedCancelCallback *cancelCb = nullptr);
     // automatic prefix
     virtual void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval,
                        int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false);
@@ -159,7 +159,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
 
 float *llmodel_embed(
     llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix, int dimensionality,
-    size_t *token_count, bool do_mean, bool atlas, const char **error
+    size_t *token_count, bool do_mean, bool atlas, llmodel_emb_cancel_callback cancel_cb, const char **error
 ) {
     auto *wrapper = static_cast<LLModelWrapper *>(model);
 
@@ -185,7 +185,7 @@ float *llmodel_embed(
         if (prefix) { prefixStr = prefix; }
 
         embedding = new float[embd_size];
-        wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, token_count, do_mean, atlas);
+        wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, token_count, do_mean, atlas, cancel_cb);
     } catch (std::exception const &e) {
         llmodel_set_error(error, e.what());
         return nullptr;
@@ -82,6 +82,15 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
  */
 typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);
 
+/**
+ * Embedding cancellation callback for use with llmodel_embed.
+ * @param batch_sizes The number of tokens in each batch that will be embedded.
+ * @param n_batch The number of batches that will be embedded.
+ * @param backend The backend that will be used for embedding. One of "cpu", "kompute", or "metal".
+ * @return True to cancel llmodel_embed, false to continue.
+ */
+typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend);
+
 /**
  * Create a llmodel instance.
  * Recognises correct model type from file at model_path
@@ -198,12 +207,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
  * truncate.
  * @param atlas Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens with
  * long_text_mode="mean" will raise an error. Disabled by default.
+ * @param cancel_cb Cancellation callback, or NULL. See the documentation of llmodel_emb_cancel_callback.
  * @param error Return location for a malloc()ed string that will be set on error, or NULL.
  * @return A pointer to an array of floating point values passed to the calling method which then will
  * be responsible for lifetime of this memory. NULL if an error occurred.
  */
 float *llmodel_embed(llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix,
-                     int dimensionality, size_t *token_count, bool do_mean, bool atlas, const char **error);
+                     int dimensionality, size_t *token_count, bool do_mean, bool atlas,
+                     llmodel_emb_cancel_callback cancel_cb, const char **error);
 
 /**
  * Frees the memory allocated by the llmodel_embedding function.
@@ -270,7 +270,7 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
 
 void LLModel::embed(
     const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
-    size_t *tokenCount, bool doMean, bool atlas
+    size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb
 ) {
     (void)texts;
     (void)embeddings;
@@ -279,6 +279,7 @@ void LLModel::embed(
     (void)tokenCount;
     (void)doMean;
     (void)atlas;
+    (void)cancelCb;
     throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
 }
 
@@ -1 +1 @@
-from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
+from .gpt4all import CancellationError as CancellationError, Embed4All as Embed4All, GPT4All as GPT4All
@@ -9,7 +9,7 @@ import sys
 import threading
 from enum import Enum
 from queue import Queue
-from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
@@ -22,6 +22,9 @@ if (3, 9) <= sys.version_info < (3, 11):
 else:
     from typing import TypedDict
 
+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
 EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
 
 
@@ -95,6 +98,7 @@ llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
 PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
 ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
 RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
+EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
 
 llmodel.llmodel_prompt.argtypes = [
     ctypes.c_void_p,
@@ -119,6 +123,7 @@ llmodel.llmodel_embed.argtypes = [
     ctypes.POINTER(ctypes.c_size_t),
     ctypes.c_bool,
     ctypes.c_bool,
+    EmbCancelCallback,
    ctypes.POINTER(ctypes.c_char_p),
 ]
 
@@ -155,6 +160,7 @@ llmodel.llmodel_has_gpu_device.restype = ctypes.c_bool
 
 ResponseCallbackType = Callable[[int, str], bool]
 RawResponseCallbackType = Callable[[int, bytes], bool]
+EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
 
 
 def empty_response_callback(token_id: int, response: str) -> bool:
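Note that the Python-level type drops the C callback's n_batch parameter: the ctypes wrapper further down slices the batch_sizes pointer into a plain list before invoking the user callback. A conforming callback is just a function of (batch_sizes, backend); a minimal illustrative example, not from the commit:

def log_and_continue(batch_sizes: list[int], backend: str) -> bool:
    # inspect the planned workload, but never cancel
    print(f"embedding {sum(batch_sizes)} tokens in {len(batch_sizes)} batches on {backend}")
    return False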
@@ -171,6 +177,10 @@ class EmbedResult(Generic[EmbeddingsType], TypedDict):
     n_prompt_tokens: int
 
 
+class CancellationError(Exception):
+    """raised when embedding is canceled"""
+
+
 class LLModel:
     """
     Base class and universal wrapper for GPT4All language models
@@ -323,19 +333,22 @@ class LLModel:
 
     @overload
     def generate_embeddings(
-        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
+        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool, cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[float]]: ...
     @overload
     def generate_embeddings(
         self, text: list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
+        cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[list[float]]]: ...
     @overload
     def generate_embeddings(
         self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
+        cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[Any]]: ...
 
     def generate_embeddings(
         self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
+        cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[Any]]:
         if not text:
             raise ValueError("text must not be None or empty")
@@ -343,7 +356,7 @@ class LLModel:
         if self.model is None:
             self._raise_closed()
 
-        if (single_text := isinstance(text, str)):
+        if single_text := isinstance(text, str):
             text = [text]
 
         # prepare input
@@ -355,14 +368,22 @@ class LLModel:
         for i, t in enumerate(text):
             c_texts[i] = t.encode()
 
+        def wrap_cancel_cb(batch_sizes: ctypes.POINTER(ctypes.c_uint), n_batch: int, backend: bytes) -> bool:
+            assert cancel_cb is not None
+            return cancel_cb(batch_sizes[:n_batch], backend.decode())
+
+        cancel_cb_wrapper = EmbCancelCallback(0x0 if cancel_cb is None else wrap_cancel_cb)
+
         # generate the embeddings
         embedding_ptr = llmodel.llmodel_embed(
             self.model, c_texts, ctypes.byref(embedding_size), c_prefix, dimensionality, ctypes.byref(token_count),
-            do_mean, atlas, ctypes.byref(error),
+            do_mean, atlas, cancel_cb_wrapper, ctypes.byref(error),
         )
 
         if not embedding_ptr:
             msg = "(unknown error)" if error.value is None else error.value.decode()
+            if msg == "operation was canceled":
+                raise CancellationError(msg)
             raise RuntimeError(f'Failed to generate embeddings: {msg}')
 
         # extract output
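Two ctypes details are worth noting here. Constructing the CFUNCTYPE from address 0x0 yields a NULL function pointer, which the C side treats as "no callback"; and binding the wrapper to a local name keeps it referenced for the duration of the synchronous llmodel_embed call, so it cannot be garbage-collected while C code might still invoke it. A standalone sketch of the NULL-pointer behavior, assuming only standard ctypes semantics (not code from the commit):

import ctypes

# Same CFUNCTYPE shape as the EmbCancelCallback binding above.
EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)

null_cb = EmbCancelCallback(0x0)  # a NULL function pointer
assert ctypes.cast(null_cb, ctypes.c_void_p).value is None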
@@ -19,7 +19,8 @@ from requests.exceptions import ChunkedEncodingError
 from tqdm import tqdm
 from urllib3.exceptions import IncompleteRead, ProtocolError
 
-from ._pyllmodel import EmbedResult as EmbedResult, LLModel, ResponseCallbackType, empty_response_callback
+from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult,
+                         LLModel, ResponseCallbackType, empty_response_callback)
 
 if TYPE_CHECKING:
     from typing_extensions import Self, TypeAlias
@@ -72,34 +73,36 @@ class Embed4All:
     @overload
     def embed(
         self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[False] = ..., atlas: bool = ...,
+        return_dict: Literal[False] = ..., atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> list[float]: ...
     @overload
     def embed(
         self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[False] = ..., atlas: bool = ...,
+        return_dict: Literal[False] = ..., atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> list[list[float]]: ...
     @overload
     def embed(
         self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
         long_text_mode: str = ..., return_dict: Literal[False] = ..., atlas: bool = ...,
+        cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> list[Any]: ...
 
     # return_dict=True
     @overload
     def embed(
         self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[True], atlas: bool = ...,
+        return_dict: Literal[True], atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> EmbedResult[list[float]]: ...
     @overload
     def embed(
         self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[True], atlas: bool = ...,
+        return_dict: Literal[True], atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> EmbedResult[list[list[float]]]: ...
     @overload
     def embed(
         self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
         long_text_mode: str = ..., return_dict: Literal[True], atlas: bool = ...,
+        cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> EmbedResult[list[Any]]: ...
 
     # return type unknown
@@ -107,11 +110,13 @@ class Embed4All:
     def embed(
         self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
         long_text_mode: str = ..., return_dict: bool = ..., atlas: bool = ...,
+        cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> Any: ...
 
     def embed(
         self, text: str | list[str], *, prefix: str | None = None, dimensionality: int | None = None,
         long_text_mode: str = "mean", return_dict: bool = False, atlas: bool = False,
+        cancel_cb: EmbCancelCallbackType | None = None,
     ) -> Any:
         """
         Generate one or more embeddings.
@@ -127,10 +132,14 @@ class Embed4All:
             return_dict: Return the result as a dict that includes the number of prompt tokens processed.
             atlas: Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens
                 with long_text_mode="mean" will raise an error. Disabled by default.
+            cancel_cb: Called with arguments (batch_sizes, backend_name). Return true to cancel embedding.
 
         Returns:
             With return_dict=False, an embedding or list of embeddings of your text(s).
             With return_dict=True, a dict with keys 'embeddings' and 'n_prompt_tokens'.
+
+        Raises:
+            CancellationError: If cancel_cb returned True and embedding was canceled.
         """
         if dimensionality is None:
             dimensionality = -1
@@ -146,7 +155,7 @@ class Embed4All:
             do_mean = {"mean": True, "truncate": False}[long_text_mode]
         except KeyError:
             raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
-        result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas)
+        result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb)
         return result if return_dict else result['embeddings']
 
 
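Taken together, the new Python surface lets a caller veto an embedding job based on its size and backend placement before any inference runs. A hypothetical end-to-end use (the token threshold is arbitrary, and Embed4All() is assumed to load its default embedding model):

from gpt4all import CancellationError, Embed4All

embedder = Embed4All()

def cancel_cb(batch_sizes: list[int], backend: str) -> bool:
    # refuse large jobs that would run on the CPU backend
    return backend == "cpu" and sum(batch_sizes) > 4096

try:
    embedding = embedder.embed("hello world", cancel_cb=cancel_cb)
except CancellationError:
    print("embedding was canceled by the callback")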
@@ -68,7 +68,7 @@ def get_long_description():
 
 setup(
     name=package_name,
-    version="2.4.1",
+    version="2.5.0",
     description="Python bindings for GPT4All",
     long_description=get_long_description(),
     long_description_content_type="text/markdown",
@@ -258,7 +258,7 @@ Napi::Value NodeModelWrapper::GenerateEmbedding(const Napi::CallbackInfo &info)
     const char *_err = nullptr;
     float *embeds = llmodel_embed(GetInference(), str_ptrs.data(), &embedding_size,
                                   prefix.IsUndefined() ? nullptr : prefix.As<Napi::String>().Utf8Value().c_str(),
-                                  dimensionality, &token_count, do_mean, atlas, &_err);
+                                  dimensionality, &token_count, do_mean, atlas, nullptr, &_err);
     if (!embeds)
     {
         // i dont wanna deal with c strings lol