fix chat model case

Ankush Gola
2023-06-07 18:01:39 -07:00
parent ab7a8e14df
commit 0197abf2b3
2 changed files with 17 additions and 2 deletions
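
In short: BaseChatModel.generate and agenerate now build the per-prompt LLMResult objects handed to each run manager's on_llm_end directly from each ChatResult, rather than calling flatten() on the combined output, and an integration test for multi-prompt completion is added. A sketch of the resulting flow follows the first file's diff below.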


@@ -93,10 +93,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
             for run_manager in run_managers:
                 run_manager.on_llm_error(e)
             raise e
+        flattened_outputs = [
+            LLMResult(generations=[res.generations], llm_output=res.llm_output)
+            for res in results
+        ]
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        flattened_outputs = output.flatten()
         for manager, flattened_output in zip(run_managers, flattened_outputs):
             manager.on_llm_end(flattened_output)
         if run_managers:
@@ -143,10 +146,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
                 *[run_manager.on_llm_error(e) for run_manager in run_managers]
             )
             raise e
+        flattened_outputs = [
+            LLMResult(generations=[res.generations], llm_output=res.llm_output)
+            for res in results
+        ]
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        flattened_outputs = output.flatten()
         await asyncio.gather(
             *[
                 run_manager.on_llm_end(flattened_output)
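
For orientation, the new path constructs one single-prompt LLMResult per ChatResult, so each prompt's run manager receives only that prompt's generations and llm_output, while the caller still gets a single combined LLMResult; the removed output.flatten() call instead worked from the already-combined llm_output. Below is a minimal, self-contained sketch of that pattern, using simplified stand-in dataclasses (not the real langchain.schema types) and a made-up token-count merge in place of _combine_llm_outputs:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ChatResult:
    # Stand-in for langchain's per-prompt chat result.
    generations: list
    llm_output: Optional[dict] = None


@dataclass
class LLMResult:
    # Stand-in for langchain's aggregate result (one inner list per prompt).
    generations: list
    llm_output: Optional[dict] = None


def combine_llm_outputs(outputs: list) -> dict:
    # Stand-in for BaseChatModel._combine_llm_outputs: just sum token counts.
    return {"token_usage": sum((o or {}).get("token_usage", 0) for o in outputs)}


results = [
    ChatResult(generations=["Hi there!"], llm_output={"token_usage": 3}),
    ChatResult(generations=["Hello!"], llm_output={"token_usage": 2}),
]

# One single-prompt result per ChatResult, each keeping its own llm_output,
# so a per-prompt callback sees only that prompt's usage.
flattened_outputs = [
    LLMResult(generations=[res.generations], llm_output=res.llm_output)
    for res in results
]

# Single combined result returned to the caller.
output = LLMResult(
    generations=[res.generations for res in results],
    llm_output=combine_llm_outputs([res.llm_output for res in results]),
)

for flat in flattened_outputs:
    print(flat.llm_output)  # {'token_usage': 3}, then {'token_usage': 2}
print(output.llm_output)    # {'token_usage': 5}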


@@ -96,6 +96,15 @@ def test_openai_streaming() -> None:
         assert isinstance(token["choices"][0]["text"], str)


+def test_openai_multiple_prompts() -> None:
+    """Test completion with multiple prompts."""
+    llm = OpenAI(max_tokens=10)
+    output = llm.generate(["I'm Pickle Rick", "I'm Pickle Rick"])
+    assert isinstance(output, LLMResult)
+    assert isinstance(output.generations, list)
+    assert len(output.generations) == 2
+
+
 def test_openai_streaming_error() -> None:
     """Test error handling in stream."""
     llm = OpenAI(best_of=2)
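
The new test exercises the completion model; a corresponding check for the chat-model path that this commit fixes might look like the following hypothetical test, which is not part of the commit (it assumes the 2023-era ChatOpenAI and HumanMessage imports and a configured OpenAI API key):

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, LLMResult


def test_chat_openai_multiple_prompts() -> None:
    """Hypothetical test: chat generation over multiple message lists."""
    chat = ChatOpenAI(max_tokens=10)
    output = chat.generate(
        [
            [HumanMessage(content="I'm Pickle Rick")],
            [HumanMessage(content="I'm Pickle Rick")],
        ]
    )
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
    assert len(output.generations) == 2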