python: embedding cancel callback for nomic client dynamic mode (#2214)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel
2024-04-12 16:00:39 -04:00
committed by GitHub
parent 459289b94c
commit 46818e466e
11 changed files with 95 additions and 28 deletions

View File

@@ -1 +1 @@
from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
from .gpt4all import CancellationError as CancellationError, Embed4All as Embed4All, GPT4All as GPT4All

View File

@@ -9,7 +9,7 @@ import sys
import threading
from enum import Enum
from queue import Queue
from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
if sys.version_info >= (3, 9):
import importlib.resources as importlib_resources
@@ -22,6 +22,9 @@ if (3, 9) <= sys.version_info < (3, 11):
else:
from typing import TypedDict
if TYPE_CHECKING:
from typing_extensions import TypeAlias
EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
@@ -95,6 +98,7 @@ llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
llmodel.llmodel_prompt.argtypes = [
ctypes.c_void_p,
@@ -119,6 +123,7 @@ llmodel.llmodel_embed.argtypes = [
ctypes.POINTER(ctypes.c_size_t),
ctypes.c_bool,
ctypes.c_bool,
EmbCancelCallback,
ctypes.POINTER(ctypes.c_char_p),
]
@@ -155,6 +160,7 @@ llmodel.llmodel_has_gpu_device.restype = ctypes.c_bool
ResponseCallbackType = Callable[[int, str], bool]
RawResponseCallbackType = Callable[[int, bytes], bool]
EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
def empty_response_callback(token_id: int, response: str) -> bool:
@@ -171,6 +177,10 @@ class EmbedResult(Generic[EmbeddingsType], TypedDict):
n_prompt_tokens: int
class CancellationError(Exception):
"""raised when embedding is canceled"""
class LLModel:
"""
Base class and universal wrapper for GPT4All language models
@@ -323,19 +333,22 @@ class LLModel:
@overload
def generate_embeddings(
self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool, cancel_cb: EmbCancelCallbackType,
) -> EmbedResult[list[float]]: ...
@overload
def generate_embeddings(
self, text: list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
cancel_cb: EmbCancelCallbackType,
) -> EmbedResult[list[list[float]]]: ...
@overload
def generate_embeddings(
self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
cancel_cb: EmbCancelCallbackType,
) -> EmbedResult[list[Any]]: ...
def generate_embeddings(
self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
cancel_cb: EmbCancelCallbackType,
) -> EmbedResult[list[Any]]:
if not text:
raise ValueError("text must not be None or empty")
@@ -343,7 +356,7 @@ class LLModel:
if self.model is None:
self._raise_closed()
if (single_text := isinstance(text, str)):
if single_text := isinstance(text, str):
text = [text]
# prepare input
@@ -355,14 +368,22 @@ class LLModel:
for i, t in enumerate(text):
c_texts[i] = t.encode()
def wrap_cancel_cb(batch_sizes: ctypes.POINTER(ctypes.c_uint), n_batch: int, backend: bytes) -> bool:
assert cancel_cb is not None
return cancel_cb(batch_sizes[:n_batch], backend.decode())
cancel_cb_wrapper = EmbCancelCallback(0x0 if cancel_cb is None else wrap_cancel_cb)
# generate the embeddings
embedding_ptr = llmodel.llmodel_embed(
self.model, c_texts, ctypes.byref(embedding_size), c_prefix, dimensionality, ctypes.byref(token_count),
do_mean, atlas, ctypes.byref(error),
do_mean, atlas, cancel_cb_wrapper, ctypes.byref(error),
)
if not embedding_ptr:
msg = "(unknown error)" if error.value is None else error.value.decode()
if msg == "operation was canceled":
raise CancellationError(msg)
raise RuntimeError(f'Failed to generate embeddings: {msg}')
# extract output

View File

@@ -19,7 +19,8 @@ from requests.exceptions import ChunkedEncodingError
from tqdm import tqdm
from urllib3.exceptions import IncompleteRead, ProtocolError
from ._pyllmodel import EmbedResult as EmbedResult, LLModel, ResponseCallbackType, empty_response_callback
from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult,
LLModel, ResponseCallbackType, empty_response_callback)
if TYPE_CHECKING:
from typing_extensions import Self, TypeAlias
@@ -72,34 +73,36 @@ class Embed4All:
@overload
def embed(
self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
return_dict: Literal[False] = ..., atlas: bool = ...,
return_dict: Literal[False] = ..., atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
) -> list[float]: ...
@overload
def embed(
self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
return_dict: Literal[False] = ..., atlas: bool = ...,
return_dict: Literal[False] = ..., atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
) -> list[list[float]]: ...
@overload
def embed(
self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
long_text_mode: str = ..., return_dict: Literal[False] = ..., atlas: bool = ...,
cancel_cb: EmbCancelCallbackType | None = ...,
) -> list[Any]: ...
# return_dict=True
@overload
def embed(
self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
return_dict: Literal[True], atlas: bool = ...,
return_dict: Literal[True], atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
) -> EmbedResult[list[float]]: ...
@overload
def embed(
self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
return_dict: Literal[True], atlas: bool = ...,
return_dict: Literal[True], atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
) -> EmbedResult[list[list[float]]]: ...
@overload
def embed(
self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
long_text_mode: str = ..., return_dict: Literal[True], atlas: bool = ...,
cancel_cb: EmbCancelCallbackType | None = ...,
) -> EmbedResult[list[Any]]: ...
# return type unknown
@@ -107,11 +110,13 @@ class Embed4All:
def embed(
self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
long_text_mode: str = ..., return_dict: bool = ..., atlas: bool = ...,
cancel_cb: EmbCancelCallbackType | None = ...,
) -> Any: ...
def embed(
self, text: str | list[str], *, prefix: str | None = None, dimensionality: int | None = None,
long_text_mode: str = "mean", return_dict: bool = False, atlas: bool = False,
cancel_cb: EmbCancelCallbackType | None = None,
) -> Any:
"""
Generate one or more embeddings.
@@ -127,10 +132,14 @@ class Embed4All:
return_dict: Return the result as a dict that includes the number of prompt tokens processed.
atlas: Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens
with long_text_mode="mean" will raise an error. Disabled by default.
cancel_cb: Called with arguments (batch_sizes, backend_name). Return true to cancel embedding.
Returns:
With return_dict=False, an embedding or list of embeddings of your text(s).
With return_dict=True, a dict with keys 'embeddings' and 'n_prompt_tokens'.
Raises:
CancellationError: If cancel_cb returned True and embedding was canceled.
"""
if dimensionality is None:
dimensionality = -1
@@ -146,7 +155,7 @@ class Embed4All:
do_mean = {"mean": True, "truncate": False}[long_text_mode]
except KeyError:
raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas)
result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb)
return result if return_dict else result['embeddings']