Mirror of https://github.com/nomic-ai/gpt4all.git
python bindings: typing fixes, misc fixes (#1131)
* python: do not mutate locals()
* python: fix (some) typing complaints
* python: queue sentinel need not be a str
* python: make long inference tests opt in
commit 6987910668 (parent 01bd3d6802)
@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, Iterable, List, Union
+from typing import Dict, Iterable, List, Union, Optional

 import requests
 from tqdm import tqdm
@@ -22,7 +22,12 @@ class GPT4All:
     """

     def __init__(
-        self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True, n_threads=None
+        self,
+        model_name: str,
+        model_path: Optional[str] = None,
+        model_type: Optional[str] = None,
+        allow_download: bool = True,
+        n_threads: Optional[int] = None,
     ):
         """
         Constructor
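The signature change above is the core of the typing fix: under PEP 484, a default of None does not by itself make a parameter optional, so type checkers such as mypy complain about "model_path: str = None". A minimal sketch of the distinction (the function name and default below are invented for illustration, not taken from this repo):

from typing import Optional

def load_model(model_path: Optional[str] = None) -> str:
    # With a plain "model_path: str = None" annotation, a checker flags the
    # None default; Optional[str] makes the None case explicit and accepted.
    return model_path if model_path is not None else "default-model.bin"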
@@ -60,7 +65,7 @@ class GPT4All:

     @staticmethod
     def retrieve_model(
-        model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True
+        model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
     ) -> str:
         """
         Find model file, and if it doesn't exist, download the model.
@@ -120,7 +125,7 @@ class GPT4All:
         raise ValueError("Failed to retrieve model")

     @staticmethod
-    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: str = None) -> str:
+    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
         """
         Download model from https://gpt4all.io.

@@ -181,7 +186,7 @@ class GPT4All:
         repeat_penalty: float = 1.18,
         repeat_last_n: int = 64,
         n_batch: int = 8,
-        n_predict: int = None,
+        n_predict: Optional[int] = None,
         streaming: bool = False,
     ) -> Union[str, Iterable]:
         """
@@ -202,13 +207,16 @@ class GPT4All:
         Returns:
             Either the entire completion or a generator that yields the completion token by token.
         """
-        generate_kwargs = locals()
-        generate_kwargs.pop('self')
-        generate_kwargs.pop('max_tokens')
-        generate_kwargs.pop('streaming')
-        generate_kwargs['n_predict'] = max_tokens
-        if n_predict is not None:
-            generate_kwargs['n_predict'] = n_predict
+        generate_kwargs = dict(
+            prompt=prompt,
+            temp=temp,
+            top_k=top_k,
+            top_p=top_p,
+            repeat_penalty=repeat_penalty,
+            repeat_last_n=repeat_last_n,
+            n_batch=n_batch,
+            n_predict=n_predict if n_predict is not None else max_tokens,
+        )

         if self._is_chat_session_activated:
             self.current_chat_session.append({"role": "user", "content": prompt})
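This hunk is the "do not mutate locals()" item. The old code took the dict returned by locals(), popped entries, and rewrote a key; CPython happens to hand back a fresh snapshot at function scope, but the language does not guarantee how that dict tracks the real locals, and the trick silently forwards any parameter or helper variable added later. Building the kwargs explicitly keeps the mapping obvious. A rough standalone sketch of the hazard (names invented, not repository code):

def generate(prompt, temp=0.7, streaming=False):
    debug_note = "not meant to be forwarded"   # any helper local defined first...
    kwargs = locals()                          # ...ends up in the snapshot
    kwargs.pop('streaming')
    return kwargs

print(generate("hi"))
# includes debug_note alongside prompt and temp, which is rarely what was intended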
@@ -262,8 +262,8 @@ class LLModel:
             Model response str
         """

-        prompt = prompt.encode('utf-8')
-        prompt = ctypes.c_char_p(prompt)
+        prompt_bytes = prompt.encode('utf-8')
+        prompt_ptr = ctypes.c_char_p(prompt_bytes)

         old_stdout = sys.stdout

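The rename above (and the matching ones below) keeps the original str prompt, the encoded bytes, and the ctypes pointer as three distinct names instead of rebinding prompt twice, which is what tripped the type checker. For reference, this is the usual way a Python string is handed to a C char* argument through ctypes; the prompt text here is just an example:

import ctypes

prompt = "Hello, world"
prompt_bytes = prompt.encode('utf-8')       # bytes object owned by Python
prompt_ptr = ctypes.c_char_p(prompt_bytes)  # char* view of those bytes for the C call
print(prompt_ptr.value)                     # b'Hello, world'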
@@ -288,7 +288,7 @@ class LLModel:

         llmodel.llmodel_prompt(
             self.model,
-            prompt,
+            prompt_ptr,
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._response_callback),
             RecalculateCallback(self._recalculate_callback),
@@ -314,12 +314,12 @@ class LLModel:
         reset_context: bool = False,
     ) -> Iterable:
         # Symbol to terminate from generator
-        TERMINATING_SYMBOL = "#TERMINATE#"
+        TERMINATING_SYMBOL = object()

         output_queue = queue.Queue()

-        prompt = prompt.encode('utf-8')
-        prompt = ctypes.c_char_p(prompt)
+        prompt_bytes = prompt.encode('utf-8')
+        prompt_ptr = ctypes.c_char_p(prompt_bytes)

         self._set_context(
             n_predict=n_predict,
@@ -348,7 +348,7 @@ class LLModel:
             target=run_llmodel_prompt,
             args=(
                 self.model,
-                prompt,
+                prompt_ptr,
                 PromptCallback(self._prompt_callback),
                 ResponseCallback(_generator_response_callback),
                 RecalculateCallback(self._recalculate_callback),
@@ -360,7 +360,7 @@ class LLModel:
         # Generator
         while True:
             response = output_queue.get()
-            if response == TERMINATING_SYMBOL:
+            if response is TERMINATING_SYMBOL:
                 break
             yield response

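These two hunks are the "queue sentinel need not be a str" item: a unique object() compared by identity can never collide with real model output, whereas the magic string "#TERMINATE#" would end the stream early if the model ever emitted that exact text. A condensed sketch of the sentinel-object pattern, independent of this codebase:

import queue

DONE = object()                 # unique sentinel; no token compares identical to it
q = queue.Queue()

for token in ("Hello", " ", "world", "#TERMINATE#"):
    q.put(token)
q.put(DONE)

while True:
    item = q.get()
    if item is DONE:            # identity check, so the literal string above passes through
        break
    print(item, end="")
print()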
@@ -25,4 +25,7 @@ isort:
 	source env/bin/activate; isort --ignore-whitespace --atomic -w 120 gpt4all

 test:
+	source env/bin/activate; pytest -s gpt4all/tests -k "not test_inference_long"
+
+test_all:
 	source env/bin/activate; pytest -s gpt4all/tests
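This makes the long inference tests opt-in: make test deselects anything whose name matches test_inference_long via pytest's -k expression, while make test_all still runs the full suite. A tiny illustration of how -k name matching behaves (test names below are invented, not from the repo's test suite):

# test_example.py (illustrative only)
def test_inference_long():
    # would load a model and generate text; excluded by -k "not test_inference_long"
    assert True

def test_prompt_formatting():
    # fast unit test; runs under both make test and make test_all
    assert "hello".upper() == "HELLO"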