Mirror of https://github.com/nomic-ai/gpt4all.git
	Python Bindings: Improved unit tests, documentation and unification of API (#1090)
* Makefiles, black, isort
* Black and isort
* unit tests and generation method
* chat context provider
* context does not reset
* Current state
* Fixup
* Python bindings with unit tests
* GPT4All Python Bindings: chat contexts, tests
* New python bindings and backend fixes
* Black and Isort
* Documentation error
* preserved n_predict for backwards compat with langchain

---------

Co-authored-by: Adam Treat <treat.adam@gmail.com>
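A minimal usage sketch of the unified Python API this PR introduces (the model filename is borrowed from the new test suite; any entry from https://gpt4all.io/models/models.json should behave the same way):

    from gpt4all import GPT4All

    # Downloads to ~/.cache/gpt4all/ on first use unless model_path is given.
    model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')

    # One-shot completion: returns the full string.
    print(model.generate('Name three colors.', max_tokens=64))

    # streaming=True returns a generator that yields tokens as they are produced.
    for token in model.generate('Name three colors.', streaming=True):
        print(token, end='')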
		| @@ -1,2 +1,2 @@ | ||||
| from .pyllmodel import LLModel # noqa | ||||
| from .gpt4all import GPT4All # noqa | ||||
| from .gpt4all import GPT4All  # noqa | ||||
| from .pyllmodel import LLModel  # noqa | ||||
|   | ||||
| @@ -2,9 +2,10 @@ | ||||
| Python only API for running all GPT4All models. | ||||
| """ | ||||
| import os | ||||
| from pathlib import Path | ||||
| import time | ||||
| from typing import Dict, List | ||||
| from contextlib import contextmanager | ||||
| from pathlib import Path | ||||
| from typing import Dict, Iterable, List, Union | ||||
|  | ||||
| import requests | ||||
| from tqdm import tqdm | ||||
| @@ -15,14 +16,14 @@ from . import pyllmodel | ||||
| DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\") | ||||
|  | ||||
|  | ||||
| class GPT4All(): | ||||
|     """Python API for retrieving and interacting with GPT4All models. | ||||
|      | ||||
|     Attributes: | ||||
|         model: Pointer to underlying C model. | ||||
| class GPT4All: | ||||
|     """ | ||||
|     Python class that handles instantiation, downloading, generation and chat with GPT4All models. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download = True, n_threads = None): | ||||
|     def __init__( | ||||
|         self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True, n_threads=None | ||||
|     ): | ||||
|         """ | ||||
|         Constructor | ||||
|  | ||||
| @@ -32,7 +33,7 @@ class GPT4All(): | ||||
|                 Default is None, in which case models will be stored in `~/.cache/gpt4all/`. | ||||
|             model_type: Model architecture. This argument currently does not have any functionality and is just used as | ||||
|                 a descriptive identifier for the user. Default is None. | ||||
|             allow_download: Allow API to download models from gpt4all.io. Default is True.  | ||||
|             allow_download: Allow API to download models from gpt4all.io. Default is True. | ||||
|             n_threads: Number of CPU threads used by GPT4All. Default is None, in which case the number of threads is determined automatically. | ||||
|         """ | ||||
|         self.model_type = model_type | ||||
| @@ -41,11 +42,14 @@ class GPT4All(): | ||||
|         model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download) | ||||
|         self.model.load_model(model_dest) | ||||
|         # Set n_threads | ||||
|         if n_threads != None: | ||||
|         if n_threads is not None: | ||||
|             self.model.set_thread_count(n_threads) | ||||
|  | ||||
|         self._is_chat_session_activated = False | ||||
|         self.current_chat_session = [] | ||||
|  | ||||
|     @staticmethod | ||||
|     def list_models(): | ||||
|     def list_models() -> Dict: | ||||
|         """ | ||||
|         Fetch model list from https://gpt4all.io/models/models.json. | ||||
|  | ||||
| @@ -55,8 +59,9 @@ class GPT4All(): | ||||
|         return requests.get("https://gpt4all.io/models/models.json").json() | ||||
|  | ||||
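As a quick sketch, the registry returned by list_models() can be inspected directly; the 'filename' and 'url' keys below are assumptions inferred from how retrieve_model() uses the entries, not a documented schema:

    from gpt4all import GPT4All

    models = GPT4All.list_models()
    # Print the downloadable filenames and their URLs (keys inferred from retrieve_model).
    for entry in models:
        print(entry.get('filename'), entry.get('url'))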
|     @staticmethod | ||||
|     def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True, | ||||
|                        verbose: bool = True) -> str: | ||||
|     def retrieve_model( | ||||
|         model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True | ||||
|     ) -> str: | ||||
|         """ | ||||
|         Find model file, and if it doesn't exist, download the model. | ||||
|  | ||||
| @@ -78,8 +83,10 @@ class GPT4All(): | ||||
|             try: | ||||
|                 os.makedirs(DEFAULT_MODEL_DIRECTORY, exist_ok=True) | ||||
|             except OSError as exc: | ||||
|                 raise ValueError(f"Failed to create model download directory at {DEFAULT_MODEL_DIRECTORY}: {exc}. " | ||||
|                                  "Please specify model_path.") | ||||
|                 raise ValueError( | ||||
|                     f"Failed to create model download directory at {DEFAULT_MODEL_DIRECTORY}: {exc}. " | ||||
|                     "Please specify model_path." | ||||
|                 ) | ||||
|             model_path = DEFAULT_MODEL_DIRECTORY | ||||
|         else: | ||||
|             model_path = model_path.replace("\\", "\\\\") | ||||
| @@ -108,7 +115,7 @@ class GPT4All(): | ||||
|                 raise ValueError(f"Model filename not in model list: {model_filename}") | ||||
|             url = selected_model.pop('url', None) | ||||
|  | ||||
|             return GPT4All.download_model(model_filename, model_path, verbose = verbose, url=url) | ||||
|             return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url) | ||||
|         else: | ||||
|             raise ValueError("Failed to retrieve model") | ||||
|  | ||||
| @@ -126,6 +133,7 @@ class GPT4All(): | ||||
|         Returns: | ||||
|             Model file destination. | ||||
|         """ | ||||
|  | ||||
|         def get_download_url(model_filename): | ||||
|             if url: | ||||
|                 return url | ||||
| @@ -137,7 +145,7 @@ class GPT4All(): | ||||
|  | ||||
|         response = requests.get(download_url, stream=True) | ||||
|         total_size_in_bytes = int(response.headers.get("content-length", 0)) | ||||
|         block_size = 2 ** 20  # 1 MB | ||||
|         block_size = 2**20  # 1 MB | ||||
|  | ||||
|         with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar: | ||||
|             try: | ||||
| @@ -154,9 +162,7 @@ class GPT4All(): | ||||
|  | ||||
|         # Validate download was successful | ||||
|         if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: | ||||
|             raise RuntimeError( | ||||
|                 "An error occurred during download. Downloaded file may not work." | ||||
|             ) | ||||
|             raise RuntimeError("An error occurred during download. Downloaded file may not work.") | ||||
|  | ||||
|         # Sleep for a little bit so OS can remove file lock | ||||
|         time.sleep(2) | ||||
| @@ -165,101 +171,83 @@ class GPT4All(): | ||||
|             print("Model downloaded at: ", download_path) | ||||
|         return download_path | ||||
|  | ||||
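For illustration, retrieve_model() can also be called on its own to pre-fetch a model before constructing GPT4All (filename taken from the new tests; the returned path is what the constructor later loads):

    from gpt4all import GPT4All

    # Downloads into ~/.cache/gpt4all/ if the file is not already present.
    local_path = GPT4All.retrieve_model('orca-mini-3b.ggmlv3.q4_0.bin', allow_download=True)
    print(local_path)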
|     # TODO: this naming is just confusing now and needs to be deprecated now that we have generator | ||||
|     # Need to better consolidate all these different model response methods | ||||
|     def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str: | ||||
|     def generate( | ||||
|         self, | ||||
|         prompt: str, | ||||
|         max_tokens: int = 200, | ||||
|         temp: float = 0.7, | ||||
|         top_k: int = 40, | ||||
|         top_p: float = 0.1, | ||||
|         repeat_penalty: float = 1.18, | ||||
|         repeat_last_n: int = 64, | ||||
|         n_batch: int = 8, | ||||
|         n_predict: int = None, | ||||
|         streaming: bool = False, | ||||
|     ) -> Union[str, Iterable]: | ||||
|         """ | ||||
|         Surfaced method of running generate without accessing model object. | ||||
|         Generate outputs from any GPT4All model. | ||||
|  | ||||
|         Args: | ||||
|             prompt: Raw string to be passed to model. | ||||
|             streaming: True if want output streamed to stdout. | ||||
|             **generate_kwargs: Optional kwargs to pass to prompt context. | ||||
|          | ||||
|         Returns: | ||||
|             Raw string of generated model response. | ||||
|         """ | ||||
|         return self.model.prompt_model(prompt, streaming=streaming, **generate_kwargs) | ||||
|  | ||||
|     def generator(self, prompt: str, **generate_kwargs) -> str: | ||||
|         """ | ||||
|         Surfaced method of running generate without accessing model object. | ||||
|  | ||||
|         Args: | ||||
|             prompt: Raw string to be passed to model. | ||||
|             streaming: True if want output streamed to stdout. | ||||
|             **generate_kwargs: Optional kwargs to pass to prompt context. | ||||
|          | ||||
|         Returns: | ||||
|             Raw string of generated model response. | ||||
|         """ | ||||
|         return self.model.generator(prompt, **generate_kwargs) | ||||
|  | ||||
|     def chat_completion(self, | ||||
|                         messages: List[Dict], | ||||
|                         default_prompt_header: bool = True, | ||||
|                         default_prompt_footer: bool = True, | ||||
|                         verbose: bool = True, | ||||
|                         streaming: bool = True, | ||||
|                         **generate_kwargs) -> dict: | ||||
|         """ | ||||
|         Format list of message dictionaries into a prompt and call model | ||||
|         generate on prompt. Returns a response dictionary with metadata and | ||||
|         generated content. | ||||
|  | ||||
|         Args: | ||||
|             messages: List of dictionaries. Each dictionary should have a "role" key | ||||
|                 with value of "system", "assistant", or "user" and a "content" key with a | ||||
|                 string value. Messages are organized such that "system" messages are at top of prompt, | ||||
|                 and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as | ||||
|                 "Response: {content}".  | ||||
|             default_prompt_header: If True (default), add default prompt header after any system role messages and | ||||
|                 before user/assistant role messages. | ||||
|             default_prompt_footer: If True (default), add default footer at end of prompt. | ||||
|             verbose: If True (default), print full prompt and generated response. | ||||
|             streaming: True if want output streamed to stdout. | ||||
|             **generate_kwargs: Optional kwargs to pass to prompt context. | ||||
|             prompt: The prompt for the model to complete. | ||||
|             max_tokens: The maximum number of tokens to generate. | ||||
|             temp: The model temperature. Larger values increase creativity but decrease factuality. | ||||
|             top_k: Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding. | ||||
|             top_p: Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p. | ||||
|             repeat_penalty: Penalize the model for repetition. Higher values result in less repetition. | ||||
|             repeat_last_n: How far in the model's generation history to apply the repeat penalty. | ||||
|             n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements. | ||||
|             n_predict: Equivalent to max_tokens, exists for backwards compatibility. | ||||
|             streaming: If True, this method will instead return a generator that yields tokens as the model generates them. | ||||
|  | ||||
|         Returns: | ||||
|             Response dictionary with:   | ||||
|                 "model": name of model.   | ||||
|                 "usage": a dictionary with number of full prompt tokens, number of  | ||||
|                     generated tokens in response, and total tokens.   | ||||
|                 "choices": List of message dictionary where "content" is generated response and "role" is set | ||||
|                 as "assistant". Right now, only one choice is returned by model. | ||||
|             Either the entire completion or a generator that yields the completion token by token. | ||||
|         """ | ||||
|         full_prompt = self._build_prompt(messages, | ||||
|                                          default_prompt_header=default_prompt_header, | ||||
|                                          default_prompt_footer=default_prompt_footer) | ||||
|         if verbose: | ||||
|             print(full_prompt) | ||||
|         generate_kwargs = locals() | ||||
|         generate_kwargs.pop('self') | ||||
|         generate_kwargs.pop('max_tokens') | ||||
|         generate_kwargs.pop('streaming') | ||||
|         generate_kwargs['n_predict'] = max_tokens | ||||
|         if n_predict is not None: | ||||
|             generate_kwargs['n_predict'] = n_predict | ||||
|  | ||||
|         response = self.model.prompt_model(full_prompt, streaming=streaming, **generate_kwargs) | ||||
|         if streaming and self._is_chat_session_activated: | ||||
|             raise NotImplementedError("Streaming tokens in a chat session is not currently supported.") | ||||
|  | ||||
|         if verbose and not streaming: | ||||
|             print(response) | ||||
|         if self._is_chat_session_activated: | ||||
|             self.current_chat_session.append({"role": "user", "content": prompt}) | ||||
|             generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session) | ||||
|             generate_kwargs['reset_context'] = len(self.current_chat_session) == 1 | ||||
|         else: | ||||
|             generate_kwargs['reset_context'] = True | ||||
|  | ||||
|         response_dict = { | ||||
|             "model": self.model.model_name, | ||||
|             "usage": {"prompt_tokens": len(full_prompt), | ||||
|                       "completion_tokens": len(response), | ||||
|                       "total_tokens": len(full_prompt) + len(response)}, | ||||
|             "choices": [ | ||||
|                 { | ||||
|                     "message": { | ||||
|                         "role": "assistant", | ||||
|                         "content": response | ||||
|                     } | ||||
|                 } | ||||
|             ] | ||||
|         } | ||||
|         if streaming: | ||||
|             return self.model.prompt_model_streaming(**generate_kwargs) | ||||
|  | ||||
|         return response_dict | ||||
|         output = self.model.prompt_model(**generate_kwargs) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _build_prompt(messages: List[Dict], | ||||
|                       default_prompt_header=True, | ||||
|                       default_prompt_footer=True) -> str: | ||||
|         if self._is_chat_session_activated: | ||||
|             self.current_chat_session.append({"role": "assistant", "content": output}) | ||||
|  | ||||
|         return output | ||||
|  | ||||
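A short, illustrative sketch of the sampling parameters documented above (the values are arbitrary; top_k=1 gives greedy decoding, as the docstring notes):

    from gpt4all import GPT4All

    model = GPT4All('orca-mini-3b.ggmlv3.q4_0.bin')  # filename from the tests

    text = model.generate(
        'The capital of France is',
        max_tokens=3,         # hard cap on generated tokens (aliased by n_predict)
        top_k=1,              # greedy decoding
        temp=0.7,
        repeat_penalty=1.18,
    )
    print(text)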
|     @contextmanager | ||||
|     def chat_session(self): | ||||
|         ''' | ||||
|         Context manager to hold an inference optimized chat session with a GPT4All model. | ||||
|         ''' | ||||
|         # Code to acquire resource, e.g.: | ||||
|         self._is_chat_session_activated = True | ||||
|         self._current_chat_session = [] | ||||
|         try: | ||||
|             yield self | ||||
|         finally: | ||||
|             # Code to release resource, e.g.: | ||||
|             self._is_chat_session_activated = False | ||||
|             self._current_chat_session = [] | ||||
|  | ||||
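Sketch of a multi-turn exchange with the new context manager (it mirrors the unit tests added below; note that streaming inside a chat session raises NotImplementedError in this version):

    from gpt4all import GPT4All

    model = GPT4All('orca-mini-3b.ggmlv3.q4_0.bin')

    with model.chat_session():
        model.generate('hello', top_k=1)
        model.generate('write me a short poem', top_k=1)
        # Prompts and replies accumulate here and are re-fed through the
        # chat prompt template on every turn, so the context is preserved.
        print(model.current_chat_session)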
|     def _format_chat_prompt_template( | ||||
|         self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True | ||||
|     ) -> str: | ||||
|         """ | ||||
|         Helper method for building a prompt using template from list of messages. | ||||
|  | ||||
| @@ -269,37 +257,20 @@ class GPT4All(): | ||||
|                 string value. Messages are organized such that "system" messages are at top of prompt, | ||||
|                 and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as | ||||
|                 "Response: {content}". | ||||
|             default_prompt_header: If True (default), add default prompt header after any system role messages and | ||||
|                 before user/assistant role messages. | ||||
|             default_prompt_footer: If True (default), add default footer at end of prompt. | ||||
|          | ||||
|  | ||||
|         Returns: | ||||
|             Formatted prompt. | ||||
|         """ | ||||
|         full_prompt = "" | ||||
|  | ||||
|         for message in messages: | ||||
|             if message["role"] == "system": | ||||
|                 system_message = message["content"] + "\n" | ||||
|                 full_prompt += system_message | ||||
|  | ||||
|         if default_prompt_header: | ||||
|             full_prompt += """### Instruction:  | ||||
|             The prompt below is a question to answer, a task to complete, or a conversation  | ||||
|             to respond to; decide which and write an appropriate response. | ||||
|             \n### Prompt: """ | ||||
|  | ||||
|         for message in messages: | ||||
|             if message["role"] == "user": | ||||
|                 user_message = "\n" + message["content"] | ||||
|                 user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n" | ||||
|                 full_prompt += user_message | ||||
|             if message["role"] == "assistant": | ||||
|                 assistant_message = "\n### Response: " + message["content"] | ||||
|                 assistant_message = message["content"] + '\n' | ||||
|                 full_prompt += assistant_message | ||||
|  | ||||
|         if default_prompt_footer: | ||||
|             full_prompt += "\n### Response:" | ||||
|  | ||||
|         return full_prompt | ||||
|  | ||||
|  | ||||
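For reference, a sketch of the string the template above would produce for a short exchange (derived from the loop above; the trailing space after "### Human:" is literal):

    messages = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi there"},
        {"role": "user", "content": "how are you?"},
    ]
    # Each user turn becomes "### Human: \n{content}\n### Assistant:\n" and each
    # assistant turn is appended as "{content}\n", so the formatted prompt is:
    expected = (
        "### Human: \nhello\n### Assistant:\n"
        "hi there\n"
        "### Human: \nhow are you?\n### Assistant:\n"
    )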
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
| import pkg_resources | ||||
| import ctypes | ||||
| import os | ||||
| import platform | ||||
| @@ -7,6 +6,10 @@ import re | ||||
| import subprocess | ||||
| import sys | ||||
| import threading | ||||
| from typing import Iterable | ||||
|  | ||||
| import pkg_resources | ||||
|  | ||||
|  | ||||
| class DualStreamProcessor: | ||||
|     def __init__(self, stream=None): | ||||
| @@ -19,10 +22,12 @@ class DualStreamProcessor: | ||||
|             self.stream.flush() | ||||
|         self.output += text | ||||
|  | ||||
|  | ||||
| # TODO: provide a config file to make this more robust | ||||
| LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\") | ||||
| MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\") | ||||
|  | ||||
|  | ||||
| def load_llmodel_library(): | ||||
|     system = platform.system() | ||||
|  | ||||
| @@ -40,34 +45,40 @@ def load_llmodel_library(): | ||||
|  | ||||
|     llmodel_file = "libllmodel" + '.' + c_lib_ext | ||||
|  | ||||
|     llmodel_dir = str(pkg_resources.resource_filename('gpt4all', \ | ||||
|         os.path.join(LLMODEL_PATH, llmodel_file))).replace("\\", "\\\\") | ||||
|     llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file))).replace( | ||||
|         "\\", "\\\\" | ||||
|     ) | ||||
|  | ||||
|     llmodel_lib = ctypes.CDLL(llmodel_dir) | ||||
|  | ||||
|     return llmodel_lib | ||||
|  | ||||
|  | ||||
| llmodel = load_llmodel_library() | ||||
|  | ||||
|  | ||||
| class LLModelError(ctypes.Structure): | ||||
|     _fields_ = [("message", ctypes.c_char_p), | ||||
|                 ("code", ctypes.c_int32)] | ||||
|     _fields_ = [("message", ctypes.c_char_p), ("code", ctypes.c_int32)] | ||||
|  | ||||
|  | ||||
| class LLModelPromptContext(ctypes.Structure): | ||||
|     _fields_ = [("logits", ctypes.POINTER(ctypes.c_float)), | ||||
|                 ("logits_size", ctypes.c_size_t), | ||||
|                 ("tokens", ctypes.POINTER(ctypes.c_int32)), | ||||
|                 ("tokens_size", ctypes.c_size_t), | ||||
|                 ("n_past", ctypes.c_int32), | ||||
|                 ("n_ctx", ctypes.c_int32), | ||||
|                 ("n_predict", ctypes.c_int32), | ||||
|                 ("top_k", ctypes.c_int32), | ||||
|                 ("top_p", ctypes.c_float), | ||||
|                 ("temp", ctypes.c_float), | ||||
|                 ("n_batch", ctypes.c_int32), | ||||
|                 ("repeat_penalty", ctypes.c_float), | ||||
|                 ("repeat_last_n", ctypes.c_int32), | ||||
|                 ("context_erase", ctypes.c_float)] | ||||
|     _fields_ = [ | ||||
|         ("logits", ctypes.POINTER(ctypes.c_float)), | ||||
|         ("logits_size", ctypes.c_size_t), | ||||
|         ("tokens", ctypes.POINTER(ctypes.c_int32)), | ||||
|         ("tokens_size", ctypes.c_size_t), | ||||
|         ("n_past", ctypes.c_int32), | ||||
|         ("n_ctx", ctypes.c_int32), | ||||
|         ("n_predict", ctypes.c_int32), | ||||
|         ("top_k", ctypes.c_int32), | ||||
|         ("top_p", ctypes.c_float), | ||||
|         ("temp", ctypes.c_float), | ||||
|         ("n_batch", ctypes.c_int32), | ||||
|         ("repeat_penalty", ctypes.c_float), | ||||
|         ("repeat_last_n", ctypes.c_int32), | ||||
|         ("context_erase", ctypes.c_float), | ||||
|     ] | ||||
|  | ||||
|  | ||||
| # Define C function signatures using ctypes | ||||
| llmodel.llmodel_model_create.argtypes = [ctypes.c_char_p] | ||||
| @@ -90,12 +101,14 @@ PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32) | ||||
| ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p) | ||||
| RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool) | ||||
|  | ||||
| llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,  | ||||
|                                    ctypes.c_char_p,  | ||||
|                                    PromptCallback, | ||||
|                                    ResponseCallback,  | ||||
|                                    RecalculateCallback,  | ||||
|                                    ctypes.POINTER(LLModelPromptContext)] | ||||
| llmodel.llmodel_prompt.argtypes = [ | ||||
|     ctypes.c_void_p, | ||||
|     ctypes.c_char_p, | ||||
|     PromptCallback, | ||||
|     ResponseCallback, | ||||
|     RecalculateCallback, | ||||
|     ctypes.POINTER(LLModelPromptContext), | ||||
| ] | ||||
|  | ||||
| llmodel.llmodel_prompt.restype = None | ||||
|  | ||||
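A generic ctypes illustration (no real C library involved) of how the CFUNCTYPE prototypes above wrap Python callables so native code can call back into Python; returning False from the response callback is what asks generation to stop:

    import ctypes

    ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)

    def on_token(token_id, response_bytes):
        print(response_bytes.decode('utf-8', 'replace'), end='')
        return True  # keep generating; False would stop the prompt loop

    callback = ResponseCallback(on_token)  # must stay referenced while the C call runs
    callback(0, b'hello')                  # the C side invokes it like this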
| @@ -142,7 +155,6 @@ class LLModel: | ||||
|         else: | ||||
|             raise ValueError("Unable to instantiate model") | ||||
|  | ||||
|  | ||||
|     def load_model(self, model_path: str) -> bool: | ||||
|         """ | ||||
|         Load model from a file. | ||||
| @@ -182,21 +194,59 @@ class LLModel: | ||||
|             raise Exception("Model not loaded") | ||||
|         return llmodel.llmodel_threadCount(self.model) | ||||
|  | ||||
|     def prompt_model(self,  | ||||
|                      prompt: str, | ||||
|                      logits_size: int = 0,  | ||||
|                      tokens_size: int = 0,  | ||||
|                      n_past: int = 0,  | ||||
|                      n_ctx: int = 1024,  | ||||
|                      n_predict: int = 128,  | ||||
|                      top_k: int = 40,  | ||||
|                      top_p: float = .9,  | ||||
|                      temp: float = .1,  | ||||
|                      n_batch: int = 8,  | ||||
|                      repeat_penalty: float = 1.2,  | ||||
|                      repeat_last_n: int = 10,  | ||||
|                      context_erase: float = .5, | ||||
|                      streaming: bool = True) -> str: | ||||
|     def _set_context( | ||||
|         self, | ||||
|         n_predict: int = 4096, | ||||
|         top_k: int = 40, | ||||
|         top_p: float = 0.9, | ||||
|         temp: float = 0.1, | ||||
|         n_batch: int = 8, | ||||
|         repeat_penalty: float = 1.2, | ||||
|         repeat_last_n: int = 10, | ||||
|         context_erase: float = 0.75, | ||||
|         reset_context: bool = False, | ||||
|     ): | ||||
|         if self.context is None: | ||||
|             self.context = LLModelPromptContext( | ||||
|                 logits_size=0, | ||||
|                 tokens_size=0, | ||||
|                 n_past=0, | ||||
|                 n_ctx=0, | ||||
|                 n_predict=n_predict, | ||||
|                 top_k=top_k, | ||||
|                 top_p=top_p, | ||||
|                 temp=temp, | ||||
|                 n_batch=n_batch, | ||||
|                 repeat_penalty=repeat_penalty, | ||||
|                 repeat_last_n=repeat_last_n, | ||||
|                 context_erase=context_erase, | ||||
|             ) | ||||
|         elif reset_context: | ||||
|             self.context.n_past = 0 | ||||
|  | ||||
|         self.context.n_predict = n_predict | ||||
|         self.context.top_k = top_k | ||||
|         self.context.top_p = top_p | ||||
|         self.context.temp = temp | ||||
|         self.context.n_batch = n_batch | ||||
|         self.context.repeat_penalty = repeat_penalty | ||||
|         self.context.repeat_last_n = repeat_last_n | ||||
|         self.context.context_erase = context_erase | ||||
|  | ||||
|     def prompt_model( | ||||
|         self, | ||||
|         prompt: str, | ||||
|         n_predict: int = 4096, | ||||
|         top_k: int = 40, | ||||
|         top_p: float = 0.9, | ||||
|         temp: float = 0.1, | ||||
|         n_batch: int = 8, | ||||
|         repeat_penalty: float = 1.2, | ||||
|         repeat_last_n: int = 10, | ||||
|         context_erase: float = 0.75, | ||||
|         reset_context: bool = False, | ||||
|         streaming=False, | ||||
|     ) -> str: | ||||
|         """ | ||||
|         Generate response from model from a prompt. | ||||
|  | ||||
| @@ -211,117 +261,100 @@ class LLModel: | ||||
|         ------- | ||||
|         Model response str | ||||
|         """ | ||||
|          | ||||
|  | ||||
|         prompt = prompt.encode('utf-8') | ||||
|         prompt = ctypes.c_char_p(prompt) | ||||
|  | ||||
|         old_stdout = sys.stdout  | ||||
|         old_stdout = sys.stdout | ||||
|  | ||||
|         stream_processor = DualStreamProcessor() | ||||
|      | ||||
|  | ||||
|         if streaming: | ||||
|             stream_processor.stream = sys.stdout | ||||
|          | ||||
|  | ||||
|         sys.stdout = stream_processor | ||||
|  | ||||
|         self._set_context( | ||||
|             n_predict=n_predict, | ||||
|             top_k=top_k, | ||||
|             top_p=top_p, | ||||
|             temp=temp, | ||||
|             n_batch=n_batch, | ||||
|             repeat_penalty=repeat_penalty, | ||||
|             repeat_last_n=repeat_last_n, | ||||
|             context_erase=context_erase, | ||||
|             reset_context=reset_context, | ||||
|         ) | ||||
|  | ||||
|         if self.context is None: | ||||
|             self.context = LLModelPromptContext( | ||||
|                 logits_size=logits_size,  | ||||
|                 tokens_size=tokens_size,  | ||||
|                 n_past=n_past,  | ||||
|                 n_ctx=n_ctx,  | ||||
|                 n_predict=n_predict,  | ||||
|                 top_k=top_k,  | ||||
|                 top_p=top_p,  | ||||
|                 temp=temp,  | ||||
|                 n_batch=n_batch,  | ||||
|                 repeat_penalty=repeat_penalty,  | ||||
|                 repeat_last_n=repeat_last_n,  | ||||
|                 context_erase=context_erase | ||||
|             ) | ||||
|  | ||||
|         llmodel.llmodel_prompt(self.model,  | ||||
|                                prompt,  | ||||
|                                PromptCallback(self._prompt_callback), | ||||
|                                ResponseCallback(self._response_callback),  | ||||
|                                RecalculateCallback(self._recalculate_callback),  | ||||
|                                self.context) | ||||
|         llmodel.llmodel_prompt( | ||||
|             self.model, | ||||
|             prompt, | ||||
|             PromptCallback(self._prompt_callback), | ||||
|             ResponseCallback(self._response_callback), | ||||
|             RecalculateCallback(self._recalculate_callback), | ||||
|             self.context, | ||||
|         ) | ||||
|  | ||||
|         # Revert to old stdout | ||||
|         sys.stdout = old_stdout | ||||
|         # Force new line | ||||
|         print() | ||||
|         return stream_processor.output | ||||
|  | ||||
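A minimal re-creation of the DualStreamProcessor idea, showing why prompt_model swaps sys.stdout: the response callback prints tokens, and the processor both captures them and (when streaming) echoes them to the real stream. Names here are illustrative, not part of the bindings:

    import sys

    class TeeCapture:
        # Record every write and optionally echo it to the real stream.
        def __init__(self, echo_to=None):
            self.echo_to = echo_to
            self.output = ""

        def write(self, text):
            if self.echo_to is not None:
                self.echo_to.write(text)
                self.echo_to.flush()
            self.output += text

    old_stdout = sys.stdout
    sys.stdout = TeeCapture(echo_to=old_stdout)   # capture and echo (streaming=True)
    try:
        print("hello from the response callback", end="")
    finally:
        captured = sys.stdout.output
        sys.stdout = old_stdout                   # always restore stdout

    print()
    print("captured:", captured)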
|     def generator(self,  | ||||
|                   prompt: str, | ||||
|                   logits_size: int = 0,  | ||||
|                   tokens_size: int = 0,  | ||||
|                   n_past: int = 0,  | ||||
|                   n_ctx: int = 1024,  | ||||
|                   n_predict: int = 128,  | ||||
|                   top_k: int = 40,  | ||||
|                   top_p: float = .9,  | ||||
|                   temp: float = .1,  | ||||
|                   n_batch: int = 8,  | ||||
|                   repeat_penalty: float = 1.2,  | ||||
|                   repeat_last_n: int = 10,  | ||||
|                   context_erase: float = .5) -> str: | ||||
|  | ||||
|     def prompt_model_streaming( | ||||
|         self, | ||||
|         prompt: str, | ||||
|         n_predict: int = 4096, | ||||
|         top_k: int = 40, | ||||
|         top_p: float = 0.9, | ||||
|         temp: float = 0.1, | ||||
|         n_batch: int = 8, | ||||
|         repeat_penalty: float = 1.2, | ||||
|         repeat_last_n: int = 10, | ||||
|         context_erase: float = 0.75, | ||||
|         reset_context: bool = False, | ||||
|     ) -> Iterable: | ||||
|         # Symbol to terminate from generator | ||||
|         TERMINATING_SYMBOL = "#TERMINATE#" | ||||
|          | ||||
|  | ||||
|         output_queue = queue.Queue() | ||||
|  | ||||
|         prompt = prompt.encode('utf-8') | ||||
|         prompt = ctypes.c_char_p(prompt) | ||||
|  | ||||
|         if self.context is None: | ||||
|             self.context = LLModelPromptContext( | ||||
|                 logits_size=logits_size,  | ||||
|                 tokens_size=tokens_size,  | ||||
|                 n_past=n_past,  | ||||
|                 n_ctx=n_ctx,  | ||||
|                 n_predict=n_predict,  | ||||
|                 top_k=top_k,  | ||||
|                 top_p=top_p,  | ||||
|                 temp=temp,  | ||||
|                 n_batch=n_batch,  | ||||
|                 repeat_penalty=repeat_penalty,  | ||||
|                 repeat_last_n=repeat_last_n,  | ||||
|                 context_erase=context_erase | ||||
|             ) | ||||
|         self._set_context( | ||||
|             n_predict=n_predict, | ||||
|             top_k=top_k, | ||||
|             top_p=top_p, | ||||
|             temp=temp, | ||||
|             n_batch=n_batch, | ||||
|             repeat_penalty=repeat_penalty, | ||||
|             repeat_last_n=repeat_last_n, | ||||
|             context_erase=context_erase, | ||||
|             reset_context=reset_context, | ||||
|         ) | ||||
|  | ||||
|         # Put response tokens into an output queue | ||||
|         def _generator_response_callback(token_id, response): | ||||
|             output_queue.put(response.decode('utf-8', 'replace')) | ||||
|             return True | ||||
|  | ||||
|         def run_llmodel_prompt(model,  | ||||
|                                prompt, | ||||
|                                prompt_callback, | ||||
|                                response_callback, | ||||
|                                recalculate_callback, | ||||
|                                context): | ||||
|             llmodel.llmodel_prompt(model,  | ||||
|                                    prompt,  | ||||
|                                    prompt_callback, | ||||
|                                    response_callback,  | ||||
|                                    recalculate_callback,  | ||||
|                                    context) | ||||
|         def run_llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context): | ||||
|             llmodel.llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context) | ||||
|             output_queue.put(TERMINATING_SYMBOL) | ||||
|              | ||||
|  | ||||
|         # Kick off llmodel_prompt in separate thread so we can return generator | ||||
|         # immediately | ||||
|         thread = threading.Thread(target=run_llmodel_prompt, | ||||
|                                   args=(self.model,  | ||||
|                                         prompt,  | ||||
|                                         PromptCallback(self._prompt_callback), | ||||
|                                         ResponseCallback(_generator_response_callback),  | ||||
|                                         RecalculateCallback(self._recalculate_callback),  | ||||
|                                         self.context)) | ||||
|         thread = threading.Thread( | ||||
|             target=run_llmodel_prompt, | ||||
|             args=( | ||||
|                 self.model, | ||||
|                 prompt, | ||||
|                 PromptCallback(self._prompt_callback), | ||||
|                 ResponseCallback(_generator_response_callback), | ||||
|                 RecalculateCallback(self._recalculate_callback), | ||||
|                 self.context, | ||||
|             ), | ||||
|         ) | ||||
|         thread.start() | ||||
|  | ||||
|         # Generator | ||||
|   | ||||
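The hunk is cut off just before the generator loop; the pattern it relies on is a producer thread feeding a queue that the generator drains until a sentinel arrives. A self-contained sketch of that pattern (not the exact upstream code):

    import queue
    import threading

    TERMINATING_SYMBOL = "#TERMINATE#"

    def stream_tokens(produce):
        # produce(put) calls put(token) repeatedly and returns when finished,
        # standing in for llmodel_prompt driving the response callback.
        output_queue = queue.Queue()

        def worker():
            produce(output_queue.put)
            output_queue.put(TERMINATING_SYMBOL)  # sentinel: generation finished

        threading.Thread(target=worker, daemon=True).start()

        while True:
            token = output_queue.get()
            if token == TERMINATING_SYMBOL:
                break
            yield token

    # Example producer standing in for the C prompt call:
    for tok in stream_tokens(lambda put: [put(t) for t in ("he", "llo", "!")]):
        print(tok, end="")
    print()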
							
								
								
									
New files in this commit:
  gpt4all-bindings/python/gpt4all/tests/__init__.py (0 lines)
  gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py (54 lines)
							| @@ -0,0 +1,54 @@ | ||||
| import sys | ||||
| from io import StringIO | ||||
|  | ||||
| from gpt4all import GPT4All | ||||
|  | ||||
|  | ||||
| def test_inference(): | ||||
|     model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin') | ||||
|     output_1 = model.generate('hello', top_k=1) | ||||
|  | ||||
|     with model.chat_session(): | ||||
|         response = model.generate(prompt='hello', top_k=1) | ||||
|         response = model.generate(prompt='write me a short poem', top_k=1) | ||||
|         response = model.generate(prompt='thank you', top_k=1) | ||||
|         print(model.current_chat_session) | ||||
|  | ||||
|     output_2 = model.generate('hello', top_k=1) | ||||
|  | ||||
|     assert output_1 == output_2 | ||||
|  | ||||
|     tokens = [] | ||||
|     for token in model.generate('hello', streaming=True): | ||||
|         tokens.append(token) | ||||
|  | ||||
|     assert len(tokens) > 0 | ||||
|  | ||||
|     with model.chat_session(): | ||||
|         try: | ||||
|             response = model.generate(prompt='hello', top_k=1, streaming=True) | ||||
|             assert False | ||||
|         except NotImplementedError: | ||||
|             assert True | ||||
|  | ||||
|  | ||||
| def test_inference_hparams(): | ||||
|     model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin') | ||||
|  | ||||
|     output = model.generate("The capital of france is ", max_tokens=3) | ||||
|     assert 'Paris' in output | ||||
|  | ||||
|  | ||||
| def test_inference_falcon(): | ||||
|     model = GPT4All(model_name='ggml-model-gpt4all-falcon-q4_0.bin') | ||||
|     prompt = 'hello' | ||||
|     output = model.generate(prompt) | ||||
|  | ||||
|     assert len(output) > 0 | ||||
|  | ||||
|  | ||||
| def test_inference_mpt(): | ||||
|     model = GPT4All(model_name='ggml-mpt-7b-chat.bin') | ||||
|     prompt = 'hello' | ||||
|     output = model.generate(prompt) | ||||
|     assert len(output) > 0 | ||||
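One way to run a single test from the new suite, assuming the referenced model files can be downloaded or are already cached (the selection expression below is illustrative):

    import pytest

    # -k selects test_inference_hparams; -s keeps streamed stdout visible.
    pytest.main(['gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py', '-k', 'hparams', '-s'])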