From 2eca8240e284c1a4621b5a85870c3022d0abed32 Mon Sep 17 00:00:00 2001
From: Mason Daugherty
Date: Mon, 4 Aug 2025 17:57:45 -0400
Subject: [PATCH] final(?) unit tests

---
 .../ollama/langchain_ollama/_compat.py        | 22 +++++++
 .../ollama/langchain_ollama/chat_models_v1.py | 34 ----------
 .../tests/unit_tests/test_chat_models_v1.py   | 66 +++++++++++++++----
 3 files changed, 75 insertions(+), 47 deletions(-)

diff --git a/libs/partners/ollama/langchain_ollama/_compat.py b/libs/partners/ollama/langchain_ollama/_compat.py
index 8c230bae0f3..d1ca38bfbc6 100644
--- a/libs/partners/ollama/langchain_ollama/_compat.py
+++ b/libs/partners/ollama/langchain_ollama/_compat.py
@@ -6,6 +6,7 @@ from typing import Any, cast
 from uuid import uuid4
 
 from langchain_core.messages import content_blocks as types
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.content_blocks import (
     ImageContentBlock,
     ReasoningContentBlock,
@@ -20,6 +21,21 @@ from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
 from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
 
 
+def _get_usage_metadata_from_response(
+    response: dict[str, Any],
+) -> UsageMetadata | None:
+    """Extract usage metadata from Ollama response."""
+    input_tokens = response.get("prompt_eval_count")
+    output_tokens = response.get("eval_count")
+    if input_tokens is not None and output_tokens is not None:
+        return UsageMetadata(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+        )
+    return None
+
+
 def _convert_from_v1_to_ollama_format(message: MessageV1) -> dict[str, Any]:
     """Convert v1 message to Ollama API format."""
     if isinstance(message, HumanMessageV1):
@@ -218,6 +234,7 @@ def _convert_to_v1_from_ollama_format(response: dict[str, Any]) -> AIMessageV1:
     return AIMessageV1(
         content=content,
         response_metadata=response_metadata,
+        usage_metadata=_get_usage_metadata_from_response(response),
     )
 
 
@@ -280,7 +297,12 @@ def _convert_chunk_to_v1(chunk: dict[str, Any]) -> AIMessageChunkV1:
     if "eval_duration" in chunk:
         response_metadata["eval_duration"] = chunk["eval_duration"]  # type: ignore[typeddict-unknown-key]
 
+    usage_metadata = None
+    if chunk.get("done") is True:
+        usage_metadata = _get_usage_metadata_from_response(chunk)
+
     return AIMessageChunkV1(
         content=content,
         response_metadata=response_metadata or ResponseMetadata(),
+        usage_metadata=usage_metadata,
     )
diff --git a/libs/partners/ollama/langchain_ollama/chat_models_v1.py b/libs/partners/ollama/langchain_ollama/chat_models_v1.py
index 43a70874c83..f37c19b01be 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models_v1.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models_v1.py
@@ -25,7 +25,6 @@ from langchain_core.language_models.v1.chat_models import (
     agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
 from langchain_core.messages.v1 import MessageV1
@@ -58,21 +57,6 @@ from ._utils import validate_model
 log = logging.getLogger(__name__)
 
 
-def _get_usage_metadata_from_response(
-    response: dict[str, Any],
-) -> Optional[UsageMetadata]:
-    """Extract usage metadata from Ollama response."""
-    input_tokens = response.get("prompt_eval_count")
-    output_tokens = response.get("eval_count")
-    if input_tokens is not None and output_tokens is not None:
-        return UsageMetadata(
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
-            total_tokens=input_tokens + output_tokens,
-        )
-    return None
-
-
 def _parse_json_string(
     json_string: str,
     *,
@@ -578,12 +562,6 @@ class ChatOllamaV1(BaseChatModelV1):
 
                 chunk = _convert_chunk_to_v1(part)
 
-                # Add usage metadata for final chunks
-                if part.get("done") is True:
-                    usage_metadata = _get_usage_metadata_from_response(part)
-                    if usage_metadata:
-                        chunk.usage_metadata = usage_metadata
-
                 if run_manager:
                     text_content = "".join(
                         str(block.get("text", ""))
@@ -599,9 +577,6 @@ class ChatOllamaV1(BaseChatModelV1):
             # Non-streaming case
             response = self._client.chat(**chat_params)
             ai_message = _convert_to_v1_from_ollama_format(response)
-            usage_metadata = _get_usage_metadata_from_response(response)
-            if usage_metadata:
-                ai_message.usage_metadata = usage_metadata
             # Convert to chunk for yielding
             chunk = AIMessageChunkV1(
                 content=ai_message.content,
@@ -637,12 +612,6 @@ class ChatOllamaV1(BaseChatModelV1):
 
                 chunk = _convert_chunk_to_v1(part)
 
-                # Add usage metadata for final chunks
-                if part.get("done") is True:
-                    usage_metadata = _get_usage_metadata_from_response(part)
-                    if usage_metadata:
-                        chunk.usage_metadata = usage_metadata
-
                 if run_manager:
                     text_content = "".join(
                         str(block.get("text", ""))
@@ -658,9 +627,6 @@ class ChatOllamaV1(BaseChatModelV1):
             # Non-streaming case
             response = await self._async_client.chat(**chat_params)
             ai_message = _convert_to_v1_from_ollama_format(response)
-            usage_metadata = _get_usage_metadata_from_response(response)
-            if usage_metadata:
-                ai_message.usage_metadata = usage_metadata
             # Convert to chunk for yielding
             chunk = AIMessageChunkV1(
                 content=ai_message.content,
diff --git a/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py b/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py
index 89329e205dd..cf844ad554d 100644
--- a/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py
+++ b/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py
@@ -211,23 +211,63 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
         mock_sync_client_class.return_value = mock_sync_client
         mock_async_client_class.return_value = mock_async_client
 
-        def mock_chat_response(*_args: Any, **_kwargs: Any) -> Iterator[dict[str, Any]]:
-            return iter(
-                [
-                    {
-                        "model": MODEL_NAME,
-                        "created_at": "2024-01-01T00:00:00Z",
-                        "message": {"role": "assistant", "content": "Test response"},
-                        "done": True,
-                        "done_reason": "stop",
-                    }
-                ]
+        def mock_chat_response(*args: Any, **kwargs: Any) -> Iterator[dict[str, Any]]:
+            # Check request characteristics
+            request_data = kwargs.get("messages", [])
+            has_tools = "tools" in kwargs
+
+            # Check if this is a reasoning request
+            is_reasoning_request = any(
+                isinstance(msg, dict)
+                and "Think step by step" in str(msg.get("content", ""))
+                for msg in request_data
             )
 
+            # Basic response structure
+            base_response = {
+                "model": MODEL_NAME,
+                "created_at": "2024-01-01T00:00:00Z",
+                "done": True,
+                "done_reason": "stop",
+                "prompt_eval_count": 10,
+                "eval_count": 20,
+            }
+
+            # Generate appropriate response based on request type
+            if has_tools:
+                # Mock tool call response
+                base_response["message"] = {
+                    "role": "assistant",
+                    "content": "",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "sample_tool",
+                                "arguments": '{"query": "test"}',
+                            }
+                        }
+                    ],
+                }
+            elif is_reasoning_request:
+                # Mock response with reasoning content block
+                base_response["message"] = {
+                    "role": "assistant",
+                    "content": "The answer is 4.",
+                    "thinking": "Let me think step by step: 2 + 2 = 4",
+                }
+            else:
+                # Regular text response
+                base_response["message"] = {
+                    "role": "assistant",
+                    "content": "Test response",
+                }
+
+            return iter([base_response])
+
         async def mock_async_chat_iterator(
-            *_args: Any, **_kwargs: Any
+            *args: Any, **kwargs: Any
         ) -> AsyncIterator[dict[str, Any]]:
-            for item in mock_chat_response(*_args, **_kwargs):
+            for item in mock_chat_response(*args, **kwargs):
                 yield item
 
         mock_sync_client.chat.side_effect = mock_chat_response