final(?) unit tests

This commit is contained in:
Mason Daugherty 2025-08-04 17:57:45 -04:00
parent c9a38be6df
commit 2eca8240e2
No known key found for this signature in database
3 changed files with 75 additions and 47 deletions

View File

@ -6,6 +6,7 @@ from typing import Any, cast
from uuid import uuid4
from langchain_core.messages import content_blocks as types
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.content_blocks import (
ImageContentBlock,
ReasoningContentBlock,
@ -20,6 +21,21 @@ from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
def _get_usage_metadata_from_response(
response: dict[str, Any],
) -> UsageMetadata | None:
"""Extract usage metadata from Ollama response."""
input_tokens = response.get("prompt_eval_count")
output_tokens = response.get("eval_count")
if input_tokens is not None and output_tokens is not None:
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
)
return None
def _convert_from_v1_to_ollama_format(message: MessageV1) -> dict[str, Any]:
"""Convert v1 message to Ollama API format."""
if isinstance(message, HumanMessageV1):
@ -218,6 +234,7 @@ def _convert_to_v1_from_ollama_format(response: dict[str, Any]) -> AIMessageV1:
return AIMessageV1(
content=content,
response_metadata=response_metadata,
usage_metadata=_get_usage_metadata_from_response(response),
)
@ -280,7 +297,12 @@ def _convert_chunk_to_v1(chunk: dict[str, Any]) -> AIMessageChunkV1:
if "eval_duration" in chunk:
response_metadata["eval_duration"] = chunk["eval_duration"] # type: ignore[typeddict-unknown-key]
usage_metadata = None
if chunk.get("done") is True:
usage_metadata = _get_usage_metadata_from_response(chunk)
return AIMessageChunkV1(
content=content,
response_metadata=response_metadata or ResponseMetadata(),
usage_metadata=usage_metadata,
)

View File

@ -25,7 +25,6 @@ from langchain_core.language_models.v1.chat_models import (
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import MessageV1
@ -58,21 +57,6 @@ from ._utils import validate_model
log = logging.getLogger(__name__)
def _get_usage_metadata_from_response(
response: dict[str, Any],
) -> Optional[UsageMetadata]:
"""Extract usage metadata from Ollama response."""
input_tokens = response.get("prompt_eval_count")
output_tokens = response.get("eval_count")
if input_tokens is not None and output_tokens is not None:
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
)
return None
def _parse_json_string(
json_string: str,
*,
@ -578,12 +562,6 @@ class ChatOllamaV1(BaseChatModelV1):
chunk = _convert_chunk_to_v1(part)
# Add usage metadata for final chunks
if part.get("done") is True:
usage_metadata = _get_usage_metadata_from_response(part)
if usage_metadata:
chunk.usage_metadata = usage_metadata
if run_manager:
text_content = "".join(
str(block.get("text", ""))
@ -599,9 +577,6 @@ class ChatOllamaV1(BaseChatModelV1):
# Non-streaming case
response = self._client.chat(**chat_params)
ai_message = _convert_to_v1_from_ollama_format(response)
usage_metadata = _get_usage_metadata_from_response(response)
if usage_metadata:
ai_message.usage_metadata = usage_metadata
# Convert to chunk for yielding
chunk = AIMessageChunkV1(
content=ai_message.content,
@ -637,12 +612,6 @@ class ChatOllamaV1(BaseChatModelV1):
chunk = _convert_chunk_to_v1(part)
# Add usage metadata for final chunks
if part.get("done") is True:
usage_metadata = _get_usage_metadata_from_response(part)
if usage_metadata:
chunk.usage_metadata = usage_metadata
if run_manager:
text_content = "".join(
str(block.get("text", ""))
@ -658,9 +627,6 @@ class ChatOllamaV1(BaseChatModelV1):
# Non-streaming case
response = await self._async_client.chat(**chat_params)
ai_message = _convert_to_v1_from_ollama_format(response)
usage_metadata = _get_usage_metadata_from_response(response)
if usage_metadata:
ai_message.usage_metadata = usage_metadata
# Convert to chunk for yielding
chunk = AIMessageChunkV1(
content=ai_message.content,

View File

@ -211,23 +211,63 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
mock_sync_client_class.return_value = mock_sync_client
mock_async_client_class.return_value = mock_async_client
def mock_chat_response(*_args: Any, **_kwargs: Any) -> Iterator[dict[str, Any]]:
return iter(
[
{
"model": MODEL_NAME,
"created_at": "2024-01-01T00:00:00Z",
"message": {"role": "assistant", "content": "Test response"},
"done": True,
"done_reason": "stop",
}
]
def mock_chat_response(*args: Any, **kwargs: Any) -> Iterator[dict[str, Any]]:
# Check request characteristics
request_data = kwargs.get("messages", [])
has_tools = "tools" in kwargs
# Check if this is a reasoning request
is_reasoning_request = any(
isinstance(msg, dict)
and "Think step by step" in str(msg.get("content", ""))
for msg in request_data
)
# Basic response structure
base_response = {
"model": MODEL_NAME,
"created_at": "2024-01-01T00:00:00Z",
"done": True,
"done_reason": "stop",
"prompt_eval_count": 10,
"eval_count": 20,
}
# Generate appropriate response based on request type
if has_tools:
# Mock tool call response
base_response["message"] = {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "sample_tool",
"arguments": '{"query": "test"}',
}
}
],
}
elif is_reasoning_request:
# Mock response with reasoning content block
base_response["message"] = {
"role": "assistant",
"content": "The answer is 4.",
"thinking": "Let me think step by step: 2 + 2 = 4",
}
else:
# Regular text response
base_response["message"] = {
"role": "assistant",
"content": "Test response",
}
return iter([base_response])
async def mock_async_chat_iterator(
*_args: Any, **_kwargs: Any
*args: Any, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
for item in mock_chat_response(*_args, **_kwargs):
for item in mock_chat_response(*args, **kwargs):
yield item
mock_sync_client.chat.side_effect = mock_chat_response