diff --git a/gpt4all-bindings/python/docs/gpt4all_python.md b/gpt4all-bindings/python/docs/gpt4all_python.md
index 1863e8a2..73f3402b 100644
--- a/gpt4all-bindings/python/docs/gpt4all_python.md
+++ b/gpt4all-bindings/python/docs/gpt4all_python.md
@@ -64,7 +64,6 @@ Use the GPT4All `chat_session` context manager to hold chat conversations with t
     }
 ]
 ```
-
 When using GPT4All models in the chat_session context:
 
 - The model is given a prompt template which makes it chatty.
@@ -79,7 +78,7 @@ When using GPT4All models in the chat_session context:
 ### Streaming Generations
 To interact with GPT4All responses as the model generates, use the `streaming = True` flag during generation.
 
-=== "GPT4All Example"
+=== "GPT4All Streaming Example"
     ``` py
     from gpt4all import GPT4All
     model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
@@ -93,4 +92,22 @@ To interact with GPT4All responses as the model generates, use the `streaming =
     [' Paris', ' is', ' a', ' city', ' that', ' has', ' been', ' a', ' major', ' cultural', ' and', ' economic', ' center', ' for', ' over', ' ', '2', ',', '0', '0']
     ```
 
+#### Streaming and Chat Sessions
+When streaming tokens in a chat session, you must manually handle collection and updating of the chat history.
+
+```python
+from gpt4all import GPT4All
+model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
+
+with model.chat_session():
+    tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
+    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
+
+    tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
+    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
+
+    print(model.current_chat_session)
+```
+
+
 ::: gpt4all.gpt4all.GPT4All
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 91a1169c..ca22d18b 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -210,9 +210,6 @@ class GPT4All:
         if n_predict is not None:
             generate_kwargs['n_predict'] = n_predict
 
-        if streaming and self._is_chat_session_activated:
-            raise NotImplementedError("Streaming tokens in a chat session is not currently supported.")
-
         if self._is_chat_session_activated:
             self.current_chat_session.append({"role": "user", "content": prompt})
             generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session)
diff --git a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
index 10ac88fc..df382ed5 100644
--- a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
@@ -25,11 +25,13 @@ def test_inference():
         assert len(tokens) > 0
 
     with model.chat_session():
-        try:
-            response = model.generate(prompt='hello', top_k=1, streaming=True)
-            assert False
-        except NotImplementedError:
-            assert True
+        tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
+        model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
+
+        tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
+        model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
+
+        print(model.current_chat_session)
 
 def do_long_input(model):
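The new docs deliberately leave history bookkeeping to the caller: `generate()` appends the user turn itself, but the streamed assistant reply must be collected and appended by hand. A minimal sketch of how a caller might wrap that pattern in a reusable generator (`stream_with_history` is a hypothetical helper, not part of this patch; only `generate`, `chat_session`, and `current_chat_session` come from the bindings):

```python
from gpt4all import GPT4All

def stream_with_history(model, prompt, **generate_kwargs):
    """Yield tokens as they stream, then record the assembled reply as the
    assistant turn, mirroring the pattern in the docs added by this diff."""
    tokens = []
    for token in model.generate(prompt=prompt, streaming=True, **generate_kwargs):
        tokens.append(token)
        yield token
    # generate() already appended the user turn to current_chat_session;
    # the caller is responsible for the assistant turn.
    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})

model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
with model.chat_session():
    for token in stream_with_history(model, 'hello', top_k=1):
        print(token, end='', flush=True)
    print(model.current_chat_session)
```

Note that, being a generator, the helper only updates the history once the caller has consumed the stream to exhaustion; a partially consumed stream leaves the assistant turn unrecorded.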