diff --git a/libs/partners/ollama/langchain_ollama/chat_models_v1.py b/libs/partners/ollama/langchain_ollama/chat_models_v1.py
index 02fb01de251..43a70874c83 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models_v1.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models_v1.py
@@ -621,7 +621,7 @@ class ChatOllamaV1(BaseChatModelV1):
         chat_params = self._chat_params(messages, stop, **kwargs)
 
         if chat_params["stream"]:
-            async for part in await self._async_client.chat(**chat_params):
+            async for part in self._async_client.chat(**chat_params):  # type: ignore[attr-defined]
                 if not isinstance(part, str):
                     # Skip empty load responses
                     if (
diff --git a/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py b/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py
index e0350d0211b..b36b73aeba3 100644
--- a/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py
+++ b/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py
@@ -2,6 +2,7 @@
 
 import json
 import logging
+from collections.abc import AsyncIterator, Iterator
 from typing import Any
 from unittest.mock import MagicMock, patch
 
@@ -160,52 +161,115 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
     def chat_model_params(self) -> dict:
         return {"model": MODEL_NAME}
 
+    @property
+    def supports_non_standard_blocks(self) -> bool:
+        """Override to indicate Ollama doesn't support non-standard content blocks."""
+        return False
+
+    @pytest.fixture
+    def model(self) -> Iterator[ChatOllamaV1]:
+        """Create a ChatOllamaV1 instance for testing."""
+        sync_patcher = patch("langchain_ollama.chat_models_v1.Client")
+        async_patcher = patch("langchain_ollama.chat_models_v1.AsyncClient")
+
+        mock_sync_client_class = sync_patcher.start()
+        mock_async_client_class = async_patcher.start()
+
+        mock_sync_client = MagicMock()
+        mock_async_client = MagicMock()
+
+        mock_sync_client_class.return_value = mock_sync_client
+        mock_async_client_class.return_value = mock_async_client
+
+        def mock_chat_response(*_args: Any, **_kwargs: Any) -> Iterator[dict[str, Any]]:
+            return iter(
+                [
+                    {
+                        "model": MODEL_NAME,
+                        "created_at": "2024-01-01T00:00:00Z",
+                        "message": {"role": "assistant", "content": "Test response"},
+                        "done": True,
+                        "done_reason": "stop",
+                    }
+                ]
+            )
+
+        async def mock_async_chat_iterator(
+            *_args: Any, **_kwargs: Any
+        ) -> AsyncIterator[dict[str, Any]]:
+            for item in mock_chat_response(*_args, **_kwargs):
+                yield item
+
+        mock_sync_client.chat.side_effect = mock_chat_response
+        mock_async_client.chat.side_effect = mock_async_chat_iterator
+
+        model_instance = self.chat_model_class(**self.chat_model_params)
+        yield model_instance
+        sync_patcher.stop()
+        async_patcher.stop()
+
     def test_initialization(self) -> None:
         """Test `ChatOllamaV1` initialization."""
-        llm = ChatOllamaV1(model=MODEL_NAME)
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
+            llm = ChatOllamaV1(model=MODEL_NAME)
 
-        assert llm.model == MODEL_NAME
-        assert llm._llm_type == "chat-ollama-v1"
+            assert llm.model == MODEL_NAME
+            assert llm._llm_type == "chat-ollama-v1"
 
     def test_chat_params(self) -> None:
         """Test `_chat_params()`."""
-        llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.7)
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
+            llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.7)
 
-        messages: list[MessageV1] = [HumanMessageV1("Hello")]
+            messages: list[MessageV1] = [HumanMessageV1("Hello")]
 
-        params = llm._chat_params(messages)
+            params = llm._chat_params(messages)
 
-        assert params["model"] == MODEL_NAME
-        assert len(params["messages"]) == 1
-        assert params["messages"][0]["role"] == "user"
-        assert params["messages"][0]["content"] == "Hello"
+            assert params["model"] == MODEL_NAME
+            assert len(params["messages"]) == 1
+            assert params["messages"][0]["role"] == "user"
+            assert params["messages"][0]["content"] == "Hello"
 
-        # Ensure options carry over
-        assert params["options"].temperature == 0.7
+            # Ensure options carry over
+            assert params["options"].temperature == 0.7
 
     def test_ls_params(self) -> None:
         """Test LangSmith parameters."""
-        llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.5)
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
+            llm = ChatOllamaV1(model=MODEL_NAME, temperature=0.5)
 
-        ls_params = llm._get_ls_params()
+            ls_params = llm._get_ls_params()
 
-        assert ls_params.get("ls_provider") == "ollama"
-        assert ls_params.get("ls_model_name") == MODEL_NAME
-        assert ls_params.get("ls_model_type") == "chat"
-        assert ls_params.get("ls_temperature") == 0.5
+            assert ls_params.get("ls_provider") == "ollama"
+            assert ls_params.get("ls_model_name") == MODEL_NAME
+            assert ls_params.get("ls_model_type") == "chat"
+            assert ls_params.get("ls_temperature") == 0.5
 
     def test_bind_tools_basic(self) -> None:
         """Test basic tool binding functionality."""
-        llm = ChatOllamaV1(model=MODEL_NAME)
+        with (
+            patch("langchain_ollama.chat_models_v1.Client"),
+            patch("langchain_ollama.chat_models_v1.AsyncClient"),
+        ):
+            llm = ChatOllamaV1(model=MODEL_NAME)
 
-        def test_tool(query: str) -> str:
-            """A test tool."""
-            return f"Result for: {query}"
+            def test_tool(query: str) -> str:
+                """A test tool."""
+                return f"Result for: {query}"
 
-        bound_llm = llm.bind_tools([test_tool])
+            bound_llm = llm.bind_tools([test_tool])
 
-        # Should return a bound model
-        assert bound_llm is not None
+            # Should return a bound model
+            assert bound_llm is not None
 
     # Missing: `test_arbitrary_roles_accepted_in_chatmessages`
@@ -213,9 +277,15 @@ class TestChatOllamaV1(ChatModelV1UnitTests):
     # But can be added if needed in the future.
-@patch("langchain_ollama.chat_models.validate_model") -def test_validate_model_on_init(mock_validate_model: Any) -> None: +@patch("langchain_ollama.chat_models_v1.validate_model") +@patch("langchain_ollama.chat_models_v1.Client") +def test_validate_model_on_init( + mock_client_class: Any, mock_validate_model: Any +) -> None: """Test that the model is validated on initialization when requested.""" + mock_client = MagicMock() + mock_client_class.return_value = mock_client + # Test that validate_model is called when validate_model_on_init=True ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=True) mock_validate_model.assert_called_once() @@ -314,16 +384,16 @@ def test_load_response_with_empty_content_is_skipped( } ] - with patch("langchain_ollama.chat_models.Client") as mock_client_class: + with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client - mock_client.chat.return_value = load_only_response + mock_client.chat.return_value = iter(load_only_response) llm = ChatOllamaV1(model="test-model") with ( caplog.at_level(logging.WARNING), - pytest.raises(ValueError, match="No data received from Ollama stream"), + pytest.raises(ValueError, match="No generations found in stream"), ): llm.invoke([HumanMessageV1("Hello")]) @@ -344,16 +414,16 @@ def test_load_response_with_whitespace_content_is_skipped( } ] - with patch("langchain_ollama.chat_models.Client") as mock_client_class: + with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client - mock_client.chat.return_value = load_whitespace_response + mock_client.chat.return_value = iter(load_whitespace_response) llm = ChatOllamaV1(model="test-model") with ( caplog.at_level(logging.WARNING), - pytest.raises(ValueError, match="No data received from Ollama stream"), + pytest.raises(ValueError, match="No generations found in stream"), ): llm.invoke([HumanMessageV1("Hello")]) assert "Ollama returned empty response with done_reason='load'" in caplog.text @@ -383,10 +453,10 @@ def test_load_followed_by_content_response( }, ] - with patch("langchain_ollama.chat_models.Client") as mock_client_class: + with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client - mock_client.chat.return_value = load_then_content_response + mock_client.chat.return_value = iter(load_then_content_response) llm = ChatOllamaV1(model="test-model") @@ -394,7 +464,8 @@ def test_load_followed_by_content_response( result = llm.invoke([HumanMessageV1("Hello")]) assert "Ollama returned empty response with done_reason='load'" in caplog.text - assert result.content == "Hello! How can I help you today?" + assert len(result.content) == 1 + assert result.text == "Hello! How can I help you today?" 
         assert result.response_metadata.get("done_reason") == "stop"
 
 
@@ -412,16 +483,17 @@ def test_load_response_with_actual_content_is_not_skipped(
         }
     ]
 
-    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+    with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
         mock_client = MagicMock()
         mock_client_class.return_value = mock_client
-        mock_client.chat.return_value = load_with_content_response
+        mock_client.chat.return_value = iter(load_with_content_response)
 
         llm = ChatOllamaV1(model="test-model")
 
         with caplog.at_level(logging.WARNING):
            result = llm.invoke([HumanMessageV1("Hello")])
 
-        assert result.content == "This is actual content"
+        assert len(result.content) == 1
+        assert result.text == "This is actual content"
         assert result.response_metadata.get("done_reason") == "load"
         assert not caplog.text
diff --git a/libs/standard-tests/langchain_tests/unit_tests/chat_models_v1.py b/libs/standard-tests/langchain_tests/unit_tests/chat_models_v1.py
index bd7a624ba87..a51a963a757 100644
--- a/libs/standard-tests/langchain_tests/unit_tests/chat_models_v1.py
+++ b/libs/standard-tests/langchain_tests/unit_tests/chat_models_v1.py
@@ -207,38 +207,6 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
         params = model._identifying_params
         assert isinstance(params, dict)  # Should be dict-like mapping
 
-    # Token Counting Tests
-    def test_get_token_ids(self, model: BaseChatModelV1) -> None:
-        """Test that ``get_token_ids`` returns a list of integers."""
-        text = "Hello, world!"
-        token_ids = model.get_token_ids(text)
-        assert isinstance(token_ids, list)
-        assert all(isinstance(token_id, int) for token_id in token_ids)
-        assert len(token_ids) > 0
-
-    def test_get_num_tokens(self, model: BaseChatModelV1) -> None:
-        """Test that ``get_num_tokens`` returns a positive integer."""
-        text = "Hello, world!"
-        num_tokens = model.get_num_tokens(text)
-        assert isinstance(num_tokens, int)
-        assert num_tokens > 0
-
-    def test_get_num_tokens_from_messages(self, model: BaseChatModelV1) -> None:
-        """Test that ``get_num_tokens_from_messages`` returns a positive integer."""
-        messages = [HumanMessage("Hello, world!")]
-        num_tokens = model.get_num_tokens_from_messages(messages)  # type: ignore[arg-type]
-        assert isinstance(num_tokens, int)
-        assert num_tokens > 0
-
-    def test_token_counting_consistency(self, model: BaseChatModelV1) -> None:
-        """Test that token counting methods are consistent with each other."""
-        text = "Hello, world!"
-        token_ids = model.get_token_ids(text)
-        num_tokens = model.get_num_tokens(text)
-
-        # Number of tokens should match length of token IDs list
-        assert len(token_ids) == num_tokens
-
     # Serialization Tests
     def test_dump_serialization(self, model: BaseChatModelV1) -> None:
         """Test that ``dump()`` returns proper serialization."""
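Note on the mocking pattern: the sketch below (not part of the patch, with hypothetical names) illustrates why the fixture can set an async generator function as `side_effect` and why the streaming call is consumed with `async for` without a preceding `await`, assuming the mocked `chat` call returns an async iterator directly, as the `model` fixture above does.

```python
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import MagicMock


async def fake_chat(*_args: Any, **_kwargs: Any) -> AsyncIterator[dict[str, Any]]:
    # The `yield` makes this an async generator function: calling it returns an
    # async iterator immediately, not a coroutine that must be awaited.
    yield {"message": {"role": "assistant", "content": "Test response"}, "done": True}


async def main() -> None:
    client = MagicMock()
    # With a function as side_effect, calling the mock calls fake_chat and
    # returns its result, i.e. an async generator object.
    client.chat.side_effect = fake_chat
    async for part in client.chat(model="test-model", messages=[]):  # no `await`
        print(part["message"]["content"])


asyncio.run(main())
```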