python bindings: typing fixes, misc fixes (#1131)

* python: do not mutate locals()

* python: fix (some) typing complaints

* python: queue sentinel need not be a str

* python: make long inference tests opt in
Aaron Miller 2023-07-03 18:30:24 -07:00 committed by GitHub
parent 01bd3d6802
commit 6987910668
3 changed files with 31 additions and 20 deletions

gpt4all/gpt4all.py

@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, Iterable, List, Union
+from typing import Dict, Iterable, List, Union, Optional
 import requests
 from tqdm import tqdm
@@ -22,7 +22,12 @@ class GPT4All:
     """
     def __init__(
-        self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True, n_threads=None
+        self,
+        model_name: str,
+        model_path: Optional[str] = None,
+        model_type: Optional[str] = None,
+        allow_download: bool = True,
+        n_threads: Optional[int] = None,
     ):
         """
         Constructor
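
A note on the `Optional` changes throughout this file: under PEP 484, a `None` default does not implicitly make a parameter optional, and checkers such as mypy reject implicit `Optional` by default in recent releases. A minimal sketch with hypothetical helpers (not from the diff):

    from typing import Optional

    def fetch_bad(model_path: str = None) -> str:  # flagged: None is not a str
        return model_path or '.'

    def fetch_ok(model_path: Optional[str] = None) -> str:  # accepted
        return model_path or '.'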
@@ -60,7 +65,7 @@ class GPT4All:
     @staticmethod
     def retrieve_model(
-        model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True
+        model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
     ) -> str:
         """
         Find model file, and if it doesn't exist, download the model.
@@ -120,7 +125,7 @@ class GPT4All:
             raise ValueError("Failed to retrieve model")

     @staticmethod
-    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: str = None) -> str:
+    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
         """
         Download model from https://gpt4all.io.
@@ -181,7 +186,7 @@ class GPT4All:
         repeat_penalty: float = 1.18,
         repeat_last_n: int = 64,
         n_batch: int = 8,
-        n_predict: int = None,
+        n_predict: Optional[int] = None,
         streaming: bool = False,
     ) -> Union[str, Iterable]:
         """
@@ -202,13 +207,16 @@ class GPT4All:
         Returns:
             Either the entire completion or a generator that yields the completion token by token.
         """
-        generate_kwargs = locals()
-        generate_kwargs.pop('self')
-        generate_kwargs.pop('max_tokens')
-        generate_kwargs.pop('streaming')
-        generate_kwargs['n_predict'] = max_tokens
-        if n_predict is not None:
-            generate_kwargs['n_predict'] = n_predict
+        generate_kwargs = dict(
+            prompt=prompt,
+            temp=temp,
+            top_k=top_k,
+            top_p=top_p,
+            repeat_penalty=repeat_penalty,
+            repeat_last_n=repeat_last_n,
+            n_batch=n_batch,
+            n_predict=n_predict if n_predict is not None else max_tokens,
+        )

         if self._is_chat_session_activated:
             self.current_chat_session.append({"role": "user", "content": prompt})
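
Why the `locals()` rewrite matters: in CPython, `locals()` inside a function returns a snapshot dict, so writes to it are not guaranteed to update the real variables, and the snapshot silently captures every local in scope at the call. The explicit `dict(...)` keeps the kwargs list visible to readers and type checkers. A hypothetical sketch of the trap:

    def fragile(prompt: str, temp: float = 0.7, streaming: bool = False):
        kwargs = locals()        # captures prompt, temp, AND streaming
        kwargs.pop('streaming')  # every non-generate local must be popped by hand
        kwargs['temp'] = 0.1     # mutates only the snapshot, not temp itself
        return kwargs, temp      # temp is still 0.7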

gpt4all/pyllmodel.py

@@ -262,8 +262,8 @@ class LLModel:
             Model response str
         """
-        prompt = prompt.encode('utf-8')
-        prompt = ctypes.c_char_p(prompt)
+        prompt_bytes = prompt.encode('utf-8')
+        prompt_ptr = ctypes.c_char_p(prompt_bytes)

         old_stdout = sys.stdout
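
Renaming instead of rebinding is what quiets the type checker here: reusing one `prompt` name for a `str`, then `bytes`, then a `ctypes.c_char_p` gives the variable three incompatible types. A minimal sketch of the same pattern as a hypothetical helper:

    import ctypes

    def to_c_string(prompt: str) -> ctypes.c_char_p:
        # one name per type: str -> bytes -> C pointer
        prompt_bytes = prompt.encode('utf-8')
        return ctypes.c_char_p(prompt_bytes)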
@@ -288,7 +288,7 @@ class LLModel:
         llmodel.llmodel_prompt(
             self.model,
-            prompt,
+            prompt_ptr,
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._response_callback),
             RecalculateCallback(self._recalculate_callback),
@@ -314,12 +314,12 @@ class LLModel:
         reset_context: bool = False,
     ) -> Iterable:
         # Symbol to terminate from generator
-        TERMINATING_SYMBOL = "#TERMINATE#"
+        TERMINATING_SYMBOL = object()

         output_queue = queue.Queue()

-        prompt = prompt.encode('utf-8')
-        prompt = ctypes.c_char_p(prompt)
+        prompt_bytes = prompt.encode('utf-8')
+        prompt_ptr = ctypes.c_char_p(prompt_bytes)

         self._set_context(
             n_predict=n_predict,
@@ -348,7 +348,7 @@ class LLModel:
             target=run_llmodel_prompt,
             args=(
                 self.model,
-                prompt,
+                prompt_ptr,
                 PromptCallback(self._prompt_callback),
                 ResponseCallback(_generator_response_callback),
                 RecalculateCallback(self._recalculate_callback),
@@ -360,7 +360,7 @@ class LLModel:
         # Generator
         while True:
             response = output_queue.get()
-            if response == TERMINATING_SYMBOL:
+            if response is TERMINATING_SYMBOL:
                 break
             yield response
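
The sentinel change fixes a real, if unlikely, bug: a response chunk that happened to equal the magic string "#TERMINATE#" would have ended the generator early. A fresh `object()` has a unique identity, so the `is` check can never match data from the model. A self-contained sketch of the pattern, with hypothetical names:

    import queue
    import threading

    _DONE = object()  # unique sentinel; no string can compare `is`-equal to it

    def produce(q: 'queue.Queue[object]') -> None:
        for chunk in ('hello', '#TERMINATE#', 'world'):  # data may be any string
            q.put(chunk)
        q.put(_DONE)  # identity, not equality, marks the end

    q: 'queue.Queue[object]' = queue.Queue()
    threading.Thread(target=produce, args=(q,)).start()
    while (item := q.get()) is not _DONE:
        print(item)  # prints all three chunks, including '#TERMINATE#'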

makefile

@@ -25,4 +25,7 @@ isort:
 	source env/bin/activate; isort --ignore-whitespace --atomic -w 120 gpt4all

 test:
-	source env/bin/activate; pytest -s gpt4all/tests
+	source env/bin/activate; pytest -s gpt4all/tests -k "not test_inference_long"
+
+test_all:
+	source env/bin/activate; pytest -s gpt4all/tests
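
With this split, `make test` uses pytest's `-k` name expression (substring match on test names) to deselect the long inference tests, while `make test_all` still runs everything. A hypothetical test module showing how a test opts in to the long suite purely by its name:

    # tests/test_example.py (hypothetical)
    def test_inference_short():            # runs under both make targets
        assert True

    def test_inference_long_generation():  # matches -k "not test_inference_long",
        assert True                        # so only `make test_all` runs it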