Generator in Python Bindings - streaming yields tokens one at a time (#895)
* generator method
* cleanup
* bump version number for clarity
* added replace in decode to avoid UnicodeDecodeError
* revert back to _build_prompt
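A minimal usage sketch of the new streaming interface surfaced as `GPT4All.generator` (the model name, prompt, and `n_predict` value are illustrative; keyword arguments are forwarded to the prompt context):

```python
from gpt4all import GPT4All

# Model name and prompt are illustrative; any model supported by the bindings works.
model = GPT4All("ggml-gpt4all-j-v1.3-groovy")

# generator() returns immediately; tokens are yielded as the model produces them.
for token in model.generator("Name three colors of the rainbow.", n_predict=64):
    print(token, end="", flush=True)
print()
```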
		| @@ -122,7 +122,6 @@ class GPT4All(): | ||||
|         Returns: | ||||
|             Model file destination. | ||||
|         """ | ||||
|  | ||||
|         def get_download_url(model_filename): | ||||
|             if url: | ||||
|                 return url | ||||
| @@ -162,6 +161,8 @@ class GPT4All(): | ||||
|             print("Model downloaded at: ", download_path) | ||||
|         return download_path | ||||
|  | ||||
|     # TODO: this naming is just confusing now and needs to be deprecated now that we have generator | ||||
|     # Need to better consolidate all these different model response methods | ||||
|     def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str: | ||||
|         """ | ||||
|         Surfaced method of running generate without accessing model object. | ||||
| @@ -174,7 +175,21 @@ class GPT4All(): | ||||
|         Returns: | ||||
|             Raw string of generated model response. | ||||
|         """ | ||||
|         return self.model.generate(prompt, streaming=streaming, **generate_kwargs) | ||||
|         return self.model.prompt_model(prompt, streaming=streaming, **generate_kwargs) | ||||
|  | ||||
|     def generator(self, prompt: str, **generate_kwargs) -> str: | ||||
|         """ | ||||
|         Surfaced method of running the model as a streaming generator without accessing model object. | ||||
|  | ||||
|         Args: | ||||
|             prompt: Raw string to be passed to model. | ||||
|             **generate_kwargs: Optional kwargs to pass to prompt context. | ||||
|  | ||||
|         Returns: | ||||
|             Generator that yields response tokens as the model produces them. | ||||
|         """ | ||||
|         return self.model.generator(prompt, **generate_kwargs) | ||||
|  | ||||
|     def chat_completion(self, | ||||
|                         messages: List[Dict], | ||||
| @@ -209,14 +224,13 @@ class GPT4All(): | ||||
|                 "choices": List of message dictionary where "content" is generated response and "role" is set | ||||
|                 as "assistant". Right now, only one choice is returned by model. | ||||
|         """ | ||||
|  | ||||
|         full_prompt = self._build_prompt(messages, | ||||
|                                          default_prompt_header=default_prompt_header, | ||||
|                                          default_prompt_footer=default_prompt_footer) | ||||
|         if verbose: | ||||
|             print(full_prompt) | ||||
|  | ||||
|         response = self.model.generate(full_prompt, streaming=streaming, **generate_kwargs) | ||||
|         response = self.model.prompt_model(full_prompt, streaming=streaming, **generate_kwargs) | ||||
|  | ||||
|         if verbose and not streaming: | ||||
|             print(response) | ||||
| @@ -241,8 +255,23 @@ class GPT4All(): | ||||
|     @staticmethod | ||||
|     def _build_prompt(messages: List[Dict], | ||||
|                       default_prompt_header=True, | ||||
|                       default_prompt_footer=False) -> str: | ||||
|         # Helper method to format messages into prompt. | ||||
|                       default_prompt_footer=True) -> str: | ||||
|         """ | ||||
|         Helper method for building a prompt from a list of messages. | ||||
|  | ||||
|         Args: | ||||
|             messages:  List of dictionaries. Each dictionary should have a "role" key | ||||
|                 with value of "system", "assistant", or "user" and a "content" key with a | ||||
|                 string value. Messages are organized such that "system" messages are at top of prompt, | ||||
|                 and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as | ||||
|                 "Response: {content}". | ||||
|             default_prompt_header: If True (default), add default prompt header after any system role messages and | ||||
|                 before user/assistant role messages. | ||||
|             default_prompt_footer: If True (default), add default footer at end of prompt. | ||||
|          | ||||
|         Returns: | ||||
|             Formatted prompt. | ||||
|         """ | ||||
|         full_prompt = "" | ||||
|  | ||||
|         for message in messages: | ||||
|   | ||||
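For reference, a hedged sketch of calling the chat-style API with the message format described in the `_build_prompt` docstring above (the model name and message contents are illustrative; the exact shape of the returned dict may vary between versions, so the sketch only indexes into "choices"):

```python
from gpt4all import GPT4All

model = GPT4All("ggml-gpt4all-j-v1.3-groovy")  # illustrative model name

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Explain what a Python generator is in one sentence."},
]

# chat_completion() builds the full prompt via _build_prompt() and returns a dict
# whose "choices" list holds message dicts with "role" and "content".
response = model.chat_completion(messages, streaming=False, verbose=False)
print(response["choices"][0])
```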
| @@ -2,9 +2,11 @@ import pkg_resources | ||||
| import ctypes | ||||
| import os | ||||
| import platform | ||||
| import queue | ||||
| import re | ||||
| import subprocess | ||||
| import sys | ||||
| import threading | ||||
|  | ||||
| class DualStreamProcessor: | ||||
|     def __init__(self, stream=None): | ||||
| @@ -167,21 +169,21 @@ class LLModel: | ||||
|             raise Exception("Model not loaded") | ||||
|         return llmodel.llmodel_threadCount(self.model) | ||||
|  | ||||
|     def generate(self,  | ||||
|                  prompt: str, | ||||
|                  logits_size: int = 0,  | ||||
|                  tokens_size: int = 0,  | ||||
|                  n_past: int = 0,  | ||||
|                  n_ctx: int = 1024,  | ||||
|                  n_predict: int = 128,  | ||||
|                  top_k: int = 40,  | ||||
|                  top_p: float = .9,  | ||||
|                  temp: float = .1,  | ||||
|                  n_batch: int = 8,  | ||||
|                  repeat_penalty: float = 1.2,  | ||||
|                  repeat_last_n: int = 10,  | ||||
|                  context_erase: float = .5, | ||||
|                  streaming: bool = False) -> str: | ||||
|     def prompt_model(self,  | ||||
|                      prompt: str, | ||||
|                      logits_size: int = 0,  | ||||
|                      tokens_size: int = 0,  | ||||
|                      n_past: int = 0,  | ||||
|                      n_ctx: int = 1024,  | ||||
|                      n_predict: int = 128,  | ||||
|                      top_k: int = 40,  | ||||
|                      top_p: float = .9,  | ||||
|                      temp: float = .1,  | ||||
|                      n_batch: int = 8,  | ||||
|                      repeat_penalty: float = 1.2,  | ||||
|                      repeat_last_n: int = 10,  | ||||
|                      context_erase: float = .5, | ||||
|                      streaming: bool = True) -> str: | ||||
|         """ | ||||
|         Generate response from model from a prompt. | ||||
|  | ||||
| @@ -237,6 +239,82 @@ class LLModel: | ||||
|         print() | ||||
|         return stream_processor.output | ||||
|  | ||||
|     def generator(self,  | ||||
|                   prompt: str, | ||||
|                   logits_size: int = 0,  | ||||
|                   tokens_size: int = 0,  | ||||
|                   n_past: int = 0,  | ||||
|                   n_ctx: int = 1024,  | ||||
|                   n_predict: int = 128,  | ||||
|                   top_k: int = 40,  | ||||
|                   top_p: float = .9,  | ||||
|                   temp: float = .1,  | ||||
|                   n_batch: int = 8,  | ||||
|                   repeat_penalty: float = 1.2,  | ||||
|                   repeat_last_n: int = 10,  | ||||
|                   context_erase: float = .5) -> str: | ||||
|  | ||||
|         # Sentinel used to signal the end of generation to the generator | ||||
|         TERMINATING_SYMBOL = "#TERMINATE#" | ||||
|          | ||||
|         output_queue = queue.Queue() | ||||
|  | ||||
|         prompt = prompt.encode('utf-8') | ||||
|         prompt = ctypes.c_char_p(prompt) | ||||
|  | ||||
|         context = LLModelPromptContext( | ||||
|             logits_size=logits_size,  | ||||
|             tokens_size=tokens_size,  | ||||
|             n_past=n_past,  | ||||
|             n_ctx=n_ctx,  | ||||
|             n_predict=n_predict,  | ||||
|             top_k=top_k,  | ||||
|             top_p=top_p,  | ||||
|             temp=temp,  | ||||
|             n_batch=n_batch,  | ||||
|             repeat_penalty=repeat_penalty,  | ||||
|             repeat_last_n=repeat_last_n,  | ||||
|             context_erase=context_erase | ||||
|         ) | ||||
|  | ||||
|         # Put response tokens into an output queue | ||||
|         def _generator_response_callback(token_id, response): | ||||
|             output_queue.put(response.decode('utf-8', 'replace')) | ||||
|             return True | ||||
|  | ||||
|         def run_llmodel_prompt(model,  | ||||
|                                prompt, | ||||
|                                prompt_callback, | ||||
|                                response_callback, | ||||
|                                recalculate_callback, | ||||
|                                context): | ||||
|             llmodel.llmodel_prompt(model,  | ||||
|                                    prompt,  | ||||
|                                    prompt_callback, | ||||
|                                    response_callback,  | ||||
|                                    recalculate_callback,  | ||||
|                                    context) | ||||
|             output_queue.put(TERMINATING_SYMBOL) | ||||
|              | ||||
|  | ||||
|         # Kick off llmodel_prompt in separate thread so we can return generator | ||||
|         # immediately | ||||
|         thread = threading.Thread(target=run_llmodel_prompt, | ||||
|                                   args=(self.model,  | ||||
|                                         prompt,  | ||||
|                                         PromptCallback(self._prompt_callback), | ||||
|                                         ResponseCallback(_generator_response_callback),  | ||||
|                                         RecalculateCallback(self._recalculate_callback),  | ||||
|                                         context)) | ||||
|         thread.start() | ||||
|  | ||||
|         # Generator | ||||
|         while True: | ||||
|             response = output_queue.get() | ||||
|             if response == TERMINATING_SYMBOL: | ||||
|                 break | ||||
|             yield response | ||||
|  | ||||
|     # Empty prompt callback | ||||
|     @staticmethod | ||||
|     def _prompt_callback(token_id): | ||||
| @@ -245,7 +323,7 @@ class LLModel: | ||||
|     # Empty response callback method that just prints response to be collected | ||||
|     @staticmethod | ||||
|     def _response_callback(token_id, response): | ||||
|         sys.stdout.write(response.decode('utf-8')) | ||||
|         sys.stdout.write(response.decode('utf-8', 'replace')) | ||||
|         return True | ||||
|  | ||||
|     # Empty recalculate callback | ||||
|   | ||||
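The new `LLModel.generator` is a standard producer-consumer arrangement: the response callback runs on a worker thread and pushes each decoded token onto a `queue.Queue`, while the generator on the calling thread drains the queue until it sees the terminating sentinel. A self-contained sketch of the same pattern, with a fake producer standing in for the blocking `llmodel_prompt` call (all names here are illustrative):

```python
import queue
import threading
from typing import Iterator

# Sentinel marking end of stream (the binding itself uses the string "#TERMINATE#").
_SENTINEL = object()

def _produce_tokens(out: queue.Queue) -> None:
    # Stand-in for the blocking llmodel_prompt() call; the real response
    # callback pushes each decoded token as the model emits it.
    for token in ["Hello", ",", " world", "!"]:
        out.put(token)
    out.put(_SENTINEL)

def token_generator() -> Iterator[str]:
    out: queue.Queue = queue.Queue()
    # Run the producer in a separate thread so the generator can be returned
    # and consumed immediately.
    threading.Thread(target=_produce_tokens, args=(out,), daemon=True).start()
    while True:
        item = out.get()
        if item is _SENTINEL:
            break
        yield item

if __name__ == "__main__":
    for tok in token_generator():
        print(tok, end="", flush=True)
    print()
```

Decoding tokens with `errors='replace'`, as this commit does in both response callbacks, trades exactness for robustness: if a token boundary splits a multi-byte UTF-8 sequence, the broken bytes are rendered as U+FFFD instead of raising UnicodeDecodeError.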