mirror of https://github.com/nomic-ai/gpt4all.git
python: embedding cancel callback for nomic client dynamic mode (#2214)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
gpt4all-bindings/python/gpt4all/__init__.py
@@ -1 +1 @@
-from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
+from .gpt4all import CancellationError as CancellationError, Embed4All as Embed4All, GPT4All as GPT4All
gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -9,7 +9,7 @@ import sys
 import threading
 from enum import Enum
 from queue import Queue
-from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload

 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
@@ -22,6 +22,9 @@ if (3, 9) <= sys.version_info < (3, 11):
 else:
     from typing import TypedDict

+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
 EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')


@@ -95,6 +98,7 @@ llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
 PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
 ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
 RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
+EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)

 llmodel.llmodel_prompt.argtypes = [
     ctypes.c_void_p,
@@ -119,6 +123,7 @@ llmodel.llmodel_embed.argtypes = [
     ctypes.POINTER(ctypes.c_size_t),
     ctypes.c_bool,
     ctypes.c_bool,
+    EmbCancelCallback,
     ctypes.POINTER(ctypes.c_char_p),
 ]

@@ -155,6 +160,7 @@ llmodel.llmodel_has_gpu_device.restype = ctypes.c_bool

 ResponseCallbackType = Callable[[int, str], bool]
 RawResponseCallbackType = Callable[[int, bytes], bool]
+EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'


 def empty_response_callback(token_id: int, response: str) -> bool:
@@ -171,6 +177,10 @@ class EmbedResult(Generic[EmbeddingsType], TypedDict):
     n_prompt_tokens: int


+class CancellationError(Exception):
+    """raised when embedding is canceled"""
+
+
 class LLModel:
     """
     Base class and universal wrapper for GPT4All language models
@@ -323,19 +333,22 @@ class LLModel:

     @overload
     def generate_embeddings(
-        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
+        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool, cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[float]]: ...
     @overload
     def generate_embeddings(
         self, text: list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
+        cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[list[float]]]: ...
     @overload
     def generate_embeddings(
         self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
+        cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[Any]]: ...

     def generate_embeddings(
         self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
+        cancel_cb: EmbCancelCallbackType,
     ) -> EmbedResult[list[Any]]:
         if not text:
             raise ValueError("text must not be None or empty")
@@ -343,7 +356,7 @@ class LLModel:
         if self.model is None:
             self._raise_closed()

-        if (single_text := isinstance(text, str)):
+        if single_text := isinstance(text, str):
             text = [text]

         # prepare input
@@ -355,14 +368,22 @@ class LLModel:
         for i, t in enumerate(text):
             c_texts[i] = t.encode()

+        def wrap_cancel_cb(batch_sizes: ctypes.POINTER(ctypes.c_uint), n_batch: int, backend: bytes) -> bool:
+            assert cancel_cb is not None
+            return cancel_cb(batch_sizes[:n_batch], backend.decode())
+
+        cancel_cb_wrapper = EmbCancelCallback(0x0 if cancel_cb is None else wrap_cancel_cb)
+
         # generate the embeddings
         embedding_ptr = llmodel.llmodel_embed(
             self.model, c_texts, ctypes.byref(embedding_size), c_prefix, dimensionality, ctypes.byref(token_count),
-            do_mean, atlas, ctypes.byref(error),
+            do_mean, atlas, cancel_cb_wrapper, ctypes.byref(error),
         )

         if not embedding_ptr:
             msg = "(unknown error)" if error.value is None else error.value.decode()
+            if msg == "operation was canceled":
+                raise CancellationError(msg)
             raise RuntimeError(f'Failed to generate embeddings: {msg}')

         # extract output
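Note: at the Python level, cancel_cb matches EmbCancelCallbackType, i.e. Callable[[list[int], str], bool]. The wrap_cancel_cb shim above slices the C array of per-batch sizes into a plain list (batch_sizes[:n_batch]) and decodes the backend name before invoking the user callback. A minimal sketch of a compatible callback; the function name and token budget are illustrative, not part of this commit:

def cancel_over_budget(batch_sizes: list[int], backend: str) -> bool:
    # batch_sizes holds the token count of each batch in the current call;
    # backend is the name of the inference backend. Returning True cancels.
    return sum(batch_sizes) > 4096  # hypothetical token budget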
gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -19,7 +19,8 @@ from requests.exceptions import ChunkedEncodingError
 from tqdm import tqdm
 from urllib3.exceptions import IncompleteRead, ProtocolError

-from ._pyllmodel import EmbedResult as EmbedResult, LLModel, ResponseCallbackType, empty_response_callback
+from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult,
+                         LLModel, ResponseCallbackType, empty_response_callback)

 if TYPE_CHECKING:
     from typing_extensions import Self, TypeAlias
@@ -72,34 +73,36 @@ class Embed4All:
     @overload
     def embed(
         self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[False] = ..., atlas: bool = ...,
+        return_dict: Literal[False] = ..., atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> list[float]: ...
     @overload
     def embed(
         self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[False] = ..., atlas: bool = ...,
+        return_dict: Literal[False] = ..., atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> list[list[float]]: ...
     @overload
     def embed(
         self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
         long_text_mode: str = ..., return_dict: Literal[False] = ..., atlas: bool = ...,
+        cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> list[Any]: ...

     # return_dict=True
     @overload
     def embed(
         self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[True], atlas: bool = ...,
+        return_dict: Literal[True], atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> EmbedResult[list[float]]: ...
     @overload
     def embed(
         self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        return_dict: Literal[True], atlas: bool = ...,
+        return_dict: Literal[True], atlas: bool = ..., cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> EmbedResult[list[list[float]]]: ...
     @overload
     def embed(
         self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
         long_text_mode: str = ..., return_dict: Literal[True], atlas: bool = ...,
+        cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> EmbedResult[list[Any]]: ...

     # return type unknown
@@ -107,11 +110,13 @@ class Embed4All:
     def embed(
         self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
         long_text_mode: str = ..., return_dict: bool = ..., atlas: bool = ...,
+        cancel_cb: EmbCancelCallbackType | None = ...,
     ) -> Any: ...

     def embed(
         self, text: str | list[str], *, prefix: str | None = None, dimensionality: int | None = None,
         long_text_mode: str = "mean", return_dict: bool = False, atlas: bool = False,
+        cancel_cb: EmbCancelCallbackType | None = None,
     ) -> Any:
         """
         Generate one or more embeddings.
@@ -127,10 +132,14 @@ class Embed4All:
             return_dict: Return the result as a dict that includes the number of prompt tokens processed.
             atlas: Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens
                 with long_text_mode="mean" will raise an error. Disabled by default.
+            cancel_cb: Called with arguments (batch_sizes, backend_name). Return true to cancel embedding.

         Returns:
             With return_dict=False, an embedding or list of embeddings of your text(s).
             With return_dict=True, a dict with keys 'embeddings' and 'n_prompt_tokens'.
+
+        Raises:
+            CancellationError: If cancel_cb returned True and embedding was canceled.
         """
         if dimensionality is None:
             dimensionality = -1
@@ -146,7 +155,7 @@ class Embed4All:
             do_mean = {"mean": True, "truncate": False}[long_text_mode]
         except KeyError:
             raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
-        result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas)
+        result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb)
        return result if return_dict else result['embeddings']
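End to end, cancel_cb threads from Embed4All.embed through LLModel.generate_embeddings into llmodel_embed, and a cancellation surfaces as the new CancellationError. A hedged usage sketch; the backend name 'cpu', the threshold, and the sample text are illustrative assumptions, not part of this commit:

from gpt4all import CancellationError, Embed4All

def stop_on_cpu(batch_sizes: list[int], backend: str) -> bool:
    # Illustrative policy: cancel large embedding batches on the CPU backend.
    return backend == 'cpu' and sum(batch_sizes) > 1024

embedder = Embed4All()  # default embedding model
try:
    vector = embedder.embed('The quick brown fox', cancel_cb=stop_on_cpu)
except CancellationError:
    print('embedding was canceled by the callback')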