mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-21 11:29:48 +00:00
python: documentation update and typing improvements (#2129)
Key changes: * revert "python: tweak constructor docstrings" * docs: update python GPT4All and Embed4All documentation * breaking: require keyword args to GPT4All.generate Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
@@ -9,7 +9,7 @@ import sys
|
||||
import threading
|
||||
from enum import Enum
|
||||
from queue import Queue
|
||||
from typing import Callable, Iterable, overload
|
||||
from typing import Any, Callable, Iterable, overload
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
import importlib.resources as importlib_resources
|
||||
@@ -295,15 +295,20 @@ class LLModel:
|
||||
) -> list[float]: ...
|
||||
@overload
|
||||
def generate_embeddings(
|
||||
self, text: list[str], prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
|
||||
self, text: list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
|
||||
) -> list[list[float]]: ...
|
||||
@overload
|
||||
def generate_embeddings(
|
||||
self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
|
||||
) -> Any: ...
|
||||
|
||||
def generate_embeddings(self, text, prefix, dimensionality, do_mean, atlas):
|
||||
def generate_embeddings(
|
||||
self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
|
||||
) -> Any:
|
||||
if not text:
|
||||
raise ValueError("text must not be None or empty")
|
||||
|
||||
single_text = isinstance(text, str)
|
||||
if single_text:
|
||||
if (single_text := isinstance(text, str)):
|
||||
text = [text]
|
||||
|
||||
# prepare input
|
||||
|
@@ -10,7 +10,7 @@ import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Iterable, Literal, overload
|
||||
|
||||
import requests
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
@@ -19,31 +19,35 @@ from urllib3.exceptions import IncompleteRead, ProtocolError
|
||||
|
||||
from . import _pyllmodel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
|
||||
# TODO: move to config
|
||||
DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
|
||||
|
||||
ConfigType = Dict[str, str]
|
||||
MessageType = Dict[str, str]
|
||||
ConfigType: TypeAlias = 'dict[str, str]'
|
||||
MessageType: TypeAlias = 'dict[str, str]'
|
||||
|
||||
|
||||
class Embed4All:
|
||||
"""
|
||||
Python class that handles embeddings for GPT4All.
|
||||
|
||||
Args:
|
||||
model_name: The name of the embedding model to use. Defaults to `all-MiniLM-L6-v2.gguf2.f16.gguf`.
|
||||
|
||||
All other arguments are passed to the GPT4All constructor. See its documentation for more info.
|
||||
"""
|
||||
|
||||
MIN_DIMENSIONALITY = 64
|
||||
|
||||
def __init__(self, model_name: Optional[str] = None, **kwargs):
|
||||
def __init__(self, model_name: str | None = None, n_threads: int | None = None, **kwargs):
|
||||
"""
|
||||
Constructor
|
||||
|
||||
Args:
|
||||
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
|
||||
"""
|
||||
if model_name is None:
|
||||
model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
|
||||
self.gpt4all = GPT4All(model_name, **kwargs)
|
||||
self.gpt4all = GPT4All(model_name, n_threads=n_threads, **kwargs)
|
||||
|
||||
@overload
|
||||
def embed(
|
||||
@@ -56,18 +60,21 @@ class Embed4All:
|
||||
atlas: bool = ...,
|
||||
) -> list[list[float]]: ...
|
||||
|
||||
def embed(self, text, prefix=None, dimensionality=None, long_text_mode="mean", atlas=False):
|
||||
def embed(
|
||||
self, text: str | list[str], prefix: str | None = None, dimensionality: int | None = None,
|
||||
long_text_mode: str = "mean", atlas: bool = False,
|
||||
) -> list[Any]:
|
||||
"""
|
||||
Generate one or more embeddings.
|
||||
|
||||
Args:
|
||||
text: A text or list of texts to generate embeddings for.
|
||||
prefix: The model-specific prefix representing the embedding task, without the trailing colon. For Nomic
|
||||
Embed this can be `search_query`, `search_document`, `classification`, or `clustering`.
|
||||
Embed this can be `search_query`, `search_document`, `classification`, or `clustering`.
|
||||
dimensionality: The embedding dimension, for use with Matryoshka-capable models. Defaults to full-size.
|
||||
long_text_mode: How to handle texts longer than the model can accept. One of `mean` or `truncate`.
|
||||
atlas: Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens
|
||||
with long_text_mode="mean" will raise an error. Disabled by default.
|
||||
with long_text_mode="mean" will raise an error. Disabled by default.
|
||||
|
||||
Returns:
|
||||
An embedding or list of embeddings of your text(s).
|
||||
@@ -92,40 +99,43 @@ class Embed4All:
|
||||
class GPT4All:
|
||||
"""
|
||||
Python class that handles instantiation, downloading, generation and chat with GPT4All models.
|
||||
|
||||
Args:
|
||||
model_name: Name of GPT4All or custom model. Including ".gguf" file extension is optional but encouraged.
|
||||
model_path: Path to directory containing model file or, if file does not exist, where to download model.
|
||||
Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
|
||||
model_type: Model architecture. This argument currently does not have any functionality and is just used as
|
||||
descriptive identifier for user. Default is None.
|
||||
allow_download: Allow API to download models from gpt4all.io. Default is True.
|
||||
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
|
||||
device: The processing unit on which the GPT4All model will run. It can be set to:
|
||||
- "cpu": Model will run on the central processing unit.
|
||||
- "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor.
|
||||
- "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor.
|
||||
Alternatively, a specific GPU name can also be provided, and the model will run on the GPU that matches the name if it's available.
|
||||
Default is "cpu".
|
||||
|
||||
Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
|
||||
n_ctx: Maximum size of context window
|
||||
ngl: Number of GPU layers to use (Vulkan)
|
||||
verbose: If True, print debug messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
model_path: Optional[Union[str, os.PathLike[str]]] = None,
|
||||
model_type: Optional[str] = None,
|
||||
model_path: str | os.PathLike[str] | None = None,
|
||||
model_type: str | None = None,
|
||||
allow_download: bool = True,
|
||||
n_threads: Optional[int] = None,
|
||||
device: Optional[str] = "cpu",
|
||||
n_threads: int | None = None,
|
||||
device: str | None = "cpu",
|
||||
n_ctx: int = 2048,
|
||||
ngl: int = 100,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
Constructor
|
||||
|
||||
Args:
|
||||
model_name: Name of GPT4All or custom model. Including ".gguf" file extension is optional but encouraged.
|
||||
model_path: Path to directory containing model file or, if file does not exist, where to download model.
|
||||
Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
|
||||
model_type: Model architecture. This argument currently does not have any functionality and is just used as
|
||||
descriptive identifier for user. Default is None.
|
||||
allow_download: Allow API to download models from gpt4all.io. Default is True.
|
||||
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
|
||||
device: The processing unit on which the GPT4All model will run. It can be set to:
|
||||
- "cpu": Model will run on the central processing unit.
|
||||
- "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor.
|
||||
- "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor.
|
||||
Alternatively, a specific GPU name can also be provided, and the model will run on the GPU that matches the name if it's available.
|
||||
Default is "cpu".
|
||||
|
||||
Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
|
||||
n_ctx: Maximum size of context window
|
||||
ngl: Number of GPU layers to use (Vulkan)
|
||||
verbose: If True, print debug messages.
|
||||
"""
|
||||
self.model_type = model_type
|
||||
# Retrieve model and download if allowed
|
||||
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
|
||||
@@ -142,10 +152,10 @@ class GPT4All:
|
||||
|
||||
@property
|
||||
def current_chat_session(self) -> list[MessageType] | None:
|
||||
return self._history
|
||||
return None if self._history is None else list(self._history)
|
||||
|
||||
@staticmethod
|
||||
def list_models() -> List[ConfigType]:
|
||||
def list_models() -> list[ConfigType]:
|
||||
"""
|
||||
Fetch model list from https://gpt4all.io/models/models2.json.
|
||||
|
||||
@@ -161,7 +171,7 @@ class GPT4All:
|
||||
def retrieve_model(
|
||||
cls,
|
||||
model_name: str,
|
||||
model_path: Optional[Union[str, os.PathLike[str]]] = None,
|
||||
model_path: str | os.PathLike[str] | None = None,
|
||||
allow_download: bool = True,
|
||||
verbose: bool = False,
|
||||
) -> ConfigType:
|
||||
@@ -225,7 +235,7 @@ class GPT4All:
|
||||
model_filename: str,
|
||||
model_path: str | os.PathLike[str],
|
||||
verbose: bool = True,
|
||||
url: Optional[str] = None,
|
||||
url: str | None = None,
|
||||
) -> str | os.PathLike[str]:
|
||||
"""
|
||||
Download model from https://gpt4all.io.
|
||||
@@ -302,9 +312,29 @@ class GPT4All:
|
||||
print(f"Model downloaded to {str(download_path)!r}", file=sys.stderr)
|
||||
return download_path
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, prompt: str, *, max_tokens: int = ..., temp: float = ..., top_k: int = ..., top_p: float = ...,
|
||||
min_p: float = ..., repeat_penalty: float = ..., repeat_last_n: int = ..., n_batch: int = ...,
|
||||
n_predict: int | None = ..., streaming: Literal[False] = ..., callback: _pyllmodel.ResponseCallbackType = ...,
|
||||
) -> str: ...
|
||||
@overload
|
||||
def generate(
|
||||
self, prompt: str, *, max_tokens: int = ..., temp: float = ..., top_k: int = ..., top_p: float = ...,
|
||||
min_p: float = ..., repeat_penalty: float = ..., repeat_last_n: int = ..., n_batch: int = ...,
|
||||
n_predict: int | None = ..., streaming: Literal[True], callback: _pyllmodel.ResponseCallbackType = ...,
|
||||
) -> Iterable[str]: ...
|
||||
@overload
|
||||
def generate(
|
||||
self, prompt: str, *, max_tokens: int = ..., temp: float = ..., top_k: int = ..., top_p: float = ...,
|
||||
min_p: float = ..., repeat_penalty: float = ..., repeat_last_n: int = ..., n_batch: int = ...,
|
||||
n_predict: int | None = ..., streaming: bool, callback: _pyllmodel.ResponseCallbackType = ...,
|
||||
) -> Any: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
max_tokens: int = 200,
|
||||
temp: float = 0.7,
|
||||
top_k: int = 40,
|
||||
@@ -313,10 +343,10 @@ class GPT4All:
|
||||
repeat_penalty: float = 1.18,
|
||||
repeat_last_n: int = 64,
|
||||
n_batch: int = 8,
|
||||
n_predict: Optional[int] = None,
|
||||
n_predict: int | None = None,
|
||||
streaming: bool = False,
|
||||
callback: _pyllmodel.ResponseCallbackType = _pyllmodel.empty_response_callback,
|
||||
) -> Union[str, Iterable[str]]:
|
||||
) -> Any:
|
||||
"""
|
||||
Generate outputs from any GPT4All model.
|
||||
|
||||
@@ -339,7 +369,7 @@ class GPT4All:
|
||||
"""
|
||||
|
||||
# Preparing the model request
|
||||
generate_kwargs: Dict[str, Any] = dict(
|
||||
generate_kwargs: dict[str, Any] = dict(
|
||||
temp=temp,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
@@ -380,7 +410,7 @@ class GPT4All:
|
||||
generate_kwargs["reset_context"] = True
|
||||
|
||||
# Prepare the callback, process the model response
|
||||
output_collector: List[MessageType]
|
||||
output_collector: list[MessageType]
|
||||
output_collector = [
|
||||
{"content": ""}
|
||||
] # placeholder for the self._history if chat session is not activated
|
||||
@@ -391,7 +421,7 @@ class GPT4All:
|
||||
|
||||
def _callback_wrapper(
|
||||
callback: _pyllmodel.ResponseCallbackType,
|
||||
output_collector: List[MessageType],
|
||||
output_collector: list[MessageType],
|
||||
) -> _pyllmodel.ResponseCallbackType:
|
||||
def _callback(token_id: int, response: str) -> bool:
|
||||
nonlocal callback, output_collector
|
||||
@@ -458,7 +488,7 @@ class GPT4All:
|
||||
|
||||
def _format_chat_prompt_template(
|
||||
self,
|
||||
messages: List[MessageType],
|
||||
messages: list[MessageType],
|
||||
default_prompt_header: str = "",
|
||||
default_prompt_footer: str = "",
|
||||
) -> str:
|
||||
|
@@ -28,12 +28,8 @@ def test_inference():
|
||||
assert len(tokens) > 0
|
||||
|
||||
with model.chat_session():
|
||||
tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
|
||||
model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
|
||||
|
||||
tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
|
||||
model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
|
||||
|
||||
model.generate(prompt='hello', top_k=1, streaming=True)
|
||||
model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True)
|
||||
print(model.current_chat_session)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user