(ChatOpenAI) Add model_name to LLMResult.llm_output (#1960)

This makes sure OpenAI and ChatOpenAI have the same llm_output and allows
tracking usage per model. The same work for OpenAI was done in
https://github.com/hwchase17/langchain/pull/1713.
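With model_name in llm_output for both wrappers, usage can be aggregated per
model from the returned LLMResult objects. A minimal sketch of such an
aggregation, assuming llm_output carries the keys "model_name" and
"token_usage" (the helper itself is illustrative and not part of this PR):

from collections import defaultdict
from typing import Dict, Iterable

from langchain.schema import LLMResult


def usage_per_model(results: Iterable[LLMResult]) -> Dict[str, int]:
    """Sum total_tokens per model, read from each result's llm_output."""
    totals: Dict[str, int] = defaultdict(int)
    for result in results:
        output = result.llm_output or {}
        model = output.get("model_name", "unknown")
        token_usage = output.get("token_usage", {})
        totals[model] += token_usage.get("total_tokens", 0)
    return dict(totals)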
Author: Mario Kostelac
Date: 2023-03-24 16:51:16 +01:00
Committed by: GitHub
Parent: 6e0d3880df
Commit: e7d6de6b1c
2 changed files with 31 additions and 13 deletions
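Only the test file's diff is reproduced below; the second changed file is the
ChatOpenAI wrapper itself. A minimal sketch of the behaviour that file has to
provide, assuming it mirrors the OpenAI change in #1713 (the function below is
hypothetical, since that part of the diff is not shown here):

from typing import Dict, List, Optional


def combine_llm_outputs(llm_outputs: List[Optional[dict]], model_name: str) -> dict:
    """Hypothetical: merge per-call token usage and attach the model name,
    which is what the new tests assert on."""
    overall_token_usage: Dict[str, int] = {}
    for output in llm_outputs:
        if output is None:
            # Streaming calls may not report token usage.
            continue
        for key, value in output.get("token_usage", {}).items():
            overall_token_usage[key] = overall_token_usage.get(key, 0) + value
    return {"token_usage": overall_token_usage, "model_name": model_name}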


@@ -1,5 +1,6 @@
"""Test ChatOpenAI wrapper."""

import pytest

from langchain.callbacks.base import CallbackManager
@@ -78,6 +79,24 @@ def test_chat_openai_streaming() -> None:
    assert isinstance(response, BaseMessage)


def test_chat_openai_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name."""
    chat = ChatOpenAI(max_tokens=10)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model_name


def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name when streaming is enabled."""
    chat = ChatOpenAI(max_tokens=10, streaming=True)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model_name


def test_chat_openai_invalid_streaming_params() -> None:
    """Test that invalid streaming params raise an error."""
    with pytest.raises(ValueError):