standard-tests[patch]: fix oai usage metadata test (#27122)

This commit is contained in:
Bagatur 2024-10-04 13:00:48 -07:00 committed by GitHub
parent 827bdf4f51
commit bd5b335cb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 52 additions and 64 deletions

View File

@ -1,7 +1,7 @@
"""Standard LangChain interface tests"""
from pathlib import Path
from typing import List, Literal, Type, cast
from typing import Dict, List, Literal, Type, cast
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
@ -36,16 +36,22 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
@property
def supported_usage_metadata_details(
self,
) -> List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
) -> Dict[
Literal["invoke", "stream"],
List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
],
]:
return ["cache_read_input", "cache_creation_input"]
return {
"invoke": ["cache_read_input", "cache_creation_input"],
"stream": ["cache_read_input", "cache_creation_input"],
}
def invoke_with_cache_creation_input(self, *, stream: bool = False) -> AIMessage:
llm = ChatAnthropic(

View File

@ -238,30 +238,6 @@ async def test_async_chat_openai_bind_functions() -> None:
assert isinstance(generation, AIMessage)
def test_chat_openai_extra_kwargs() -> None:
"""Test extra kwargs to chat openai."""
# Check that foo is saved in extra_kwargs.
llm = ChatOpenAI(foo=3, max_tokens=10) # type: ignore[call-arg]
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}
# Test that if extra_kwargs are provided, they are added to it.
llm = ChatOpenAI(foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
assert llm.model_kwargs == {"foo": 3, "bar": 2}
# Test that if provided twice it errors
with pytest.raises(ValueError):
ChatOpenAI(foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
# Test that if explicit param is specified in kwargs it errors
with pytest.raises(ValueError):
ChatOpenAI(model_kwargs={"temperature": 0.2})
# Test that "model" cannot be specified in kwargs
with pytest.raises(ValueError):
ChatOpenAI(model_kwargs={"model": "gpt-3.5-turbo-instruct"})
@pytest.mark.scheduled
def test_openai_streaming() -> None:
"""Test streaming tokens from OpenAI."""

View File

@ -1,7 +1,7 @@
"""Standard LangChain interface tests"""
from pathlib import Path
from typing import List, Literal, Type, cast
from typing import Dict, List, Literal, Type, cast
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
@ -28,16 +28,19 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
@property
def supported_usage_metadata_details(
self,
) -> List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
) -> Dict[
Literal["invoke", "stream"],
List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
],
]:
return ["reasoning_output", "cache_read_input"]
return {"invoke": ["reasoning_output", "cache_read_input"], "stream": []}
def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
with open(REPO_ROOT_DIR / "README.md", "r") as f:

View File

@ -151,25 +151,25 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(result.usage_metadata["output_tokens"], int)
assert isinstance(result.usage_metadata["total_tokens"], int)
if "audio_input" in self.supported_usage_metadata_details:
if "audio_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_audio_input()
assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int) # type: ignore[index]
if "audio_output" in self.supported_usage_metadata_details:
if "audio_output" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_audio_output()
assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int) # type: ignore[index]
if "reasoning_output" in self.supported_usage_metadata_details:
if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_reasoning_output()
assert isinstance(
msg.usage_metadata["output_token_details"]["reasoning"], # type: ignore[index]
int,
)
if "cache_read_input" in self.supported_usage_metadata_details:
if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_cache_read_input()
assert isinstance(
msg.usage_metadata["input_token_details"]["cache_read"], # type: ignore[index]
int,
)
if "cache_creation_input" in self.supported_usage_metadata_details:
if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_cache_creation_input()
assert isinstance(
msg.usage_metadata["input_token_details"]["cache_creation"], # type: ignore[index]
@ -189,25 +189,25 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(full.usage_metadata["output_tokens"], int)
assert isinstance(full.usage_metadata["total_tokens"], int)
if "audio_input" in self.supported_usage_metadata_details:
if "audio_input" in self.supported_usage_metadata_details["stream"]:
msg = self.invoke_with_audio_input(stream=True)
assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int) # type: ignore[index]
if "audio_output" in self.supported_usage_metadata_details:
if "audio_output" in self.supported_usage_metadata_details["stream"]:
msg = self.invoke_with_audio_output(stream=True)
assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int) # type: ignore[index]
if "reasoning_output" in self.supported_usage_metadata_details:
if "reasoning_output" in self.supported_usage_metadata_details["stream"]:
msg = self.invoke_with_reasoning_output(stream=True)
assert isinstance(
msg.usage_metadata["output_token_details"]["reasoning"], # type: ignore[index]
int,
)
if "cache_read_input" in self.supported_usage_metadata_details:
if "cache_read_input" in self.supported_usage_metadata_details["stream"]:
msg = self.invoke_with_cache_read_input(stream=True)
assert isinstance(
msg.usage_metadata["input_token_details"]["cache_read"], # type: ignore[index]
int,
)
if "cache_creation_input" in self.supported_usage_metadata_details:
if "cache_creation_input" in self.supported_usage_metadata_details["stream"]:
msg = self.invoke_with_cache_creation_input(stream=True)
assert isinstance(
msg.usage_metadata["input_token_details"]["cache_creation"], # type: ignore[index]

View File

@ -2,7 +2,7 @@
import os
from abc import abstractmethod
from typing import Any, List, Literal, Optional, Tuple, Type
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from unittest import mock
import pytest
@ -141,16 +141,19 @@ class ChatModelTests(BaseStandardTests):
@property
def supported_usage_metadata_details(
self,
) -> List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
) -> Dict[
Literal["invoke", "stream"],
List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
],
]:
return []
return {"invoke": [], "stream": []}
class ChatModelUnitTests(ChatModelTests):