From 0197abf2b38cfd2c39f5570349e11dc191b6898d Mon Sep 17 00:00:00 2001
From: Ankush Gola
Date: Wed, 7 Jun 2023 18:01:39 -0700
Subject: [PATCH] fix chat model case

---
 langchain/chat_models/base.py               | 10 ++++++++--
 tests/integration_tests/llms/test_openai.py |  9 +++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index 744037944ee..c009f4853a4 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -93,10 +93,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
             for run_manager in run_managers:
                 run_manager.on_llm_error(e)
             raise e
+        flattened_outputs = [
+            LLMResult(generations=[res.generations], llm_output=res.llm_output)
+            for res in results
+        ]
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        flattened_outputs = output.flatten()
         for manager, flattened_output in zip(run_managers, flattened_outputs):
             manager.on_llm_end(flattened_output)
         if run_managers:
@@ -143,10 +146,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
                 *[run_manager.on_llm_error(e) for run_manager in run_managers]
             )
             raise e
+        flattened_outputs = [
+            LLMResult(generations=[res.generations], llm_output=res.llm_output)
+            for res in results
+        ]
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        flattened_outputs = output.flatten()
         await asyncio.gather(
             *[
                 run_manager.on_llm_end(flattened_output)
diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py
index 7b7253ab357..f1a146da44f 100644
--- a/tests/integration_tests/llms/test_openai.py
+++ b/tests/integration_tests/llms/test_openai.py
@@ -96,6 +96,15 @@ def test_openai_streaming() -> None:
         assert isinstance(token["choices"][0]["text"], str)
 
 
+def test_openai_multiple_prompts() -> None:
+    """Test completion with multiple prompts."""
+    llm = OpenAI(max_tokens=10)
+    output = llm.generate(["I'm Pickle Rick", "I'm Pickle Rick"])
+    assert isinstance(output, LLMResult)
+    assert isinstance(output.generations, list)
+    assert len(output.generations) == 2
+
+
 def test_openai_streaming_error() -> None:
     """Test error handling in stream."""
     llm = OpenAI(best_of=2)
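
For context on the change above, here is a minimal standalone sketch (not part of the patch) of what the new per-result flattening produces. Instead of deriving flattened outputs from the combined LLMResult via output.flatten() — which the commit message identifies as broken for the chat model case — the patch builds one LLMResult per ChatResult up front, so each callback run manager's on_llm_end receives only its own prompt's generations and llm_output. The snippet assumes the langchain.schema types as they existed around this commit (ChatResult, ChatGeneration, AIMessage, LLMResult), and the fake token counts are illustrative only.

    # Illustrative sketch of the per-result flattening added in both hunks.
    # Assumes langchain.schema types of this commit's era; the contents of
    # `results` below are made up for demonstration.
    from langchain.schema import AIMessage, ChatGeneration, ChatResult, LLMResult

    # Pretend the model returned one ChatResult per prompt, each carrying
    # its own llm_output (e.g. token usage).
    results = [
        ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=f"answer {i}"))],
            llm_output={"token_usage": {"total_tokens": 10 + i}},
        )
        for i in range(2)
    ]

    # The list comprehension from the patch: one LLMResult per prompt, so
    # each run manager sees only its own prompt's output.
    flattened_outputs = [
        LLMResult(generations=[res.generations], llm_output=res.llm_output)
        for res in results
    ]

    assert len(flattened_outputs) == 2
    assert flattened_outputs[0].llm_output == {"token_usage": {"total_tokens": 10}}
    assert flattened_outputs[1].llm_output == {"token_usage": {"total_tokens": 11}}

Building the flattened results directly from each ChatResult, before _combine_llm_outputs merges them, preserves the per-prompt llm_output that the combined-then-flattened path did not.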