From f001cc15cd695c3ff2883b5a78f1a1bc67d1f291 Mon Sep 17 00:00:00 2001
From: Mason Daugherty
Date: Thu, 12 Feb 2026 11:54:21 -0500
Subject: [PATCH] cr

---
 .../ollama/langchain_ollama/chat_models.py    |  6 +++++
 .../ollama/langchain_ollama/embeddings.py     | 12 ++++-----
 libs/partners/ollama/langchain_ollama/llms.py | 26 ++++++++++++------
 .../tests/unit_tests/test_chat_models.py      | 24 +++++++++++++++++
 .../tests/unit_tests/test_embeddings.py       | 26 +++++++++++++++++-
 .../ollama/tests/unit_tests/test_llms.py      | 27 ++++++++++++++++++-
 6 files changed, 105 insertions(+), 16 deletions(-)

diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py
index 8e21f9e2703..d47fafc79a2 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models.py
@@ -934,6 +934,12 @@ class ChatOllama(BaseChatModel):
         stop: list[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[Mapping[str, Any] | str]:
+        if not self._async_client:
+            msg = (
+                "Ollama async client is not initialized. "
+                "Make sure the model was properly constructed."
+            )
+            raise RuntimeError(msg)
         chat_params = self._chat_params(messages, stop, **kwargs)
 
         if chat_params["stream"]:
diff --git a/libs/partners/ollama/langchain_ollama/embeddings.py b/libs/partners/ollama/langchain_ollama/embeddings.py
index 06bde6f1f53..b8fa44ed41e 100644
--- a/libs/partners/ollama/langchain_ollama/embeddings.py
+++ b/libs/partners/ollama/langchain_ollama/embeddings.py
@@ -298,10 +298,10 @@ class OllamaEmbeddings(BaseModel, Embeddings):
         """Embed search docs."""
         if not self._client:
             msg = (
-                "Ollama client is not initialized. "
-                "Please ensure Ollama is running and the model is loaded."
+                "Ollama sync client is not initialized. "
+                "Make sure the model was properly constructed."
             )
-            raise ValueError(msg)
+            raise RuntimeError(msg)
         return self._client.embed(
             self.model, texts, options=self._default_params, keep_alive=self.keep_alive
         )["embeddings"]
@@ -314,10 +314,10 @@ class OllamaEmbeddings(BaseModel, Embeddings):
         """Embed search docs."""
         if not self._async_client:
             msg = (
-                "Ollama client is not initialized. "
-                "Please ensure Ollama is running and the model is loaded."
+                "Ollama async client is not initialized. "
+                "Make sure the model was properly constructed."
             )
-            raise ValueError(msg)
+            raise RuntimeError(msg)
         return (
             await self._async_client.embed(
                 self.model,
diff --git a/libs/partners/ollama/langchain_ollama/llms.py b/libs/partners/ollama/langchain_ollama/llms.py
index 434839877f4..9e3334ce7a4 100644
--- a/libs/partners/ollama/langchain_ollama/llms.py
+++ b/libs/partners/ollama/langchain_ollama/llms.py
@@ -347,11 +347,16 @@ class OllamaLLM(BaseLLM):
         stop: list[str] | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[Mapping[str, Any] | str]:
-        if self._async_client:
-            async for part in await self._async_client.generate(
-                **self._generate_params(prompt, stop=stop, **kwargs)
-            ):
-                yield part
+        if not self._async_client:
+            msg = (
+                "Ollama async client is not initialized. "
+                "Make sure the model was properly constructed."
+            )
+            raise RuntimeError(msg)
+        async for part in await self._async_client.generate(
+            **self._generate_params(prompt, stop=stop, **kwargs)
+        ):
+            yield part
 
     def _create_generate_stream(
         self,
@@ -359,10 +364,15 @@
         stop: list[str] | None = None,
         **kwargs: Any,
     ) -> Iterator[Mapping[str, Any] | str]:
-        if self._client:
-            yield from self._client.generate(
-                **self._generate_params(prompt, stop=stop, **kwargs)
+        if not self._client:
+            msg = (
+                "Ollama sync client is not initialized. "
+                "Make sure the model was properly constructed."
             )
+            raise RuntimeError(msg)
+        yield from self._client.generate(
+            **self._generate_params(prompt, stop=stop, **kwargs)
+        )
 
     async def _astream_with_aggregation(
         self,
diff --git a/libs/partners/ollama/tests/unit_tests/test_chat_models.py b/libs/partners/ollama/tests/unit_tests/test_chat_models.py
index 4e8bd256983..f6f97d3284f 100644
--- a/libs/partners/ollama/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/ollama/tests/unit_tests/test_chat_models.py
@@ -461,6 +461,30 @@ def test_create_chat_stream_raises_when_client_none() -> None:
         list(llm._create_chat_stream([HumanMessage("Hello")]))
 
 
+async def test_acreate_chat_stream_raises_when_client_none() -> None:
+    """Test that _acreate_chat_stream raises RuntimeError when client is None."""
+    with patch("langchain_ollama.chat_models.AsyncClient") as mock_client_class:
+        mock_client_class.return_value = MagicMock()
+        llm = ChatOllama(model="test-model")
+        # Force _async_client to None to simulate uninitialized state
+        llm._async_client = None  # type: ignore[assignment]
+
+        with pytest.raises(RuntimeError, match="async client is not initialized"):
+            async for _ in llm._acreate_chat_stream([HumanMessage("Hello")]):
+                pass
+
+
+def test_invoke_raises_when_client_none() -> None:
+    """Test that RuntimeError propagates through the public invoke() API."""
+    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
+        mock_client_class.return_value = MagicMock()
+        llm = ChatOllama(model="test-model")
+        llm._client = None  # type: ignore[assignment]
+
+        with pytest.raises(RuntimeError, match="sync client is not initialized"):
+            llm.invoke([HumanMessage("Hello")])
+
+
 def test_chat_ollama_ignores_strict_arg() -> None:
     """Test that ChatOllama ignores the 'strict' argument."""
     response = [
diff --git a/libs/partners/ollama/tests/unit_tests/test_embeddings.py b/libs/partners/ollama/tests/unit_tests/test_embeddings.py
index 1c7c0939cc2..c6b45700110 100644
--- a/libs/partners/ollama/tests/unit_tests/test_embeddings.py
+++ b/libs/partners/ollama/tests/unit_tests/test_embeddings.py
@@ -1,7 +1,9 @@
 """Test embedding model integration."""
 
 from typing import Any
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
+
+import pytest
 
 from langchain_ollama.embeddings import OllamaEmbeddings
 
@@ -50,3 +52,25 @@ def test_embed_documents_passes_options(mock_client_class: Any) -> None:
     options = call_args.kwargs["options"]
     assert options["num_gpu"] == 4
     assert options["temperature"] == 0.5
+
+
+def test_embed_documents_raises_when_client_none() -> None:
+    """Test that embed_documents raises RuntimeError when client is None."""
+    with patch("langchain_ollama.embeddings.Client") as mock_client_class:
+        mock_client_class.return_value = MagicMock()
+        embeddings = OllamaEmbeddings(model="test-model")
+        embeddings._client = None  # type: ignore[assignment]
+
+        with pytest.raises(RuntimeError, match="sync client is not initialized"):
+            embeddings.embed_documents(["test"])
+
+
+async def test_aembed_documents_raises_when_client_none() -> None:
+    """Test that aembed_documents raises RuntimeError when async client is None."""
+    with patch("langchain_ollama.embeddings.AsyncClient") as mock_client_class:
+        mock_client_class.return_value = MagicMock()
+        embeddings = OllamaEmbeddings(model="test-model")
+        embeddings._async_client = None  # type: ignore[assignment]
+
+        with pytest.raises(RuntimeError, match="async client is not initialized"):
+            await embeddings.aembed_documents(["test"])
diff --git a/libs/partners/ollama/tests/unit_tests/test_llms.py b/libs/partners/ollama/tests/unit_tests/test_llms.py
index ab49c591a50..d561d3e0da0 100644
--- a/libs/partners/ollama/tests/unit_tests/test_llms.py
+++ b/libs/partners/ollama/tests/unit_tests/test_llms.py
@@ -1,7 +1,9 @@
 """Test Ollama Chat API wrapper."""
 
 from typing import Any
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
+
+import pytest
 
 from langchain_ollama import OllamaLLM
 
@@ -65,3 +67,26 @@ def test_reasoning_aggregation() -> None:
         result.generations[0][0].generation_info["thinking"]
         == "I am thinking. Still thinking."
     )
+
+
+def test_create_generate_stream_raises_when_client_none() -> None:
+    """Test that _create_generate_stream raises RuntimeError when client is None."""
+    with patch("langchain_ollama.llms.Client") as mock_client_class:
+        mock_client_class.return_value = MagicMock()
+        llm = OllamaLLM(model="test-model")
+        llm._client = None  # type: ignore[assignment]
+
+        with pytest.raises(RuntimeError, match="sync client is not initialized"):
+            list(llm._create_generate_stream("Hello"))
+
+
+async def test_acreate_generate_stream_raises_when_client_none() -> None:
+    """Test that _acreate_generate_stream raises RuntimeError when client is None."""
+    with patch("langchain_ollama.llms.AsyncClient") as mock_client_class:
+        mock_client_class.return_value = MagicMock()
+        llm = OllamaLLM(model="test-model")
+        llm._async_client = None  # type: ignore[assignment]
+
+        with pytest.raises(RuntimeError, match="async client is not initialized"):
+            async for _ in llm._acreate_generate_stream("Hello"):
+                pass
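
Note (reviewer sketch, not part of the patch): in llms.py the old "if self._client:" guard made the generator silently yield nothing when the client was missing; with this change all three modules raise a consistent RuntimeError instead. A minimal illustration of how the new error surfaces through the public API, mirroring test_invoke_raises_when_client_none above. It assumes langchain-ollama with this patch is importable; the model name is arbitrary, and the client is forced to None only to simulate a half-constructed instance.

    from unittest.mock import MagicMock, patch

    from langchain_core.messages import HumanMessage
    from langchain_ollama import ChatOllama

    # Patch the Ollama Client so construction never touches a real server.
    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
        mock_client_class.return_value = MagicMock()
        llm = ChatOllama(model="test-model")
        llm._client = None  # simulate an uninitialized sync client

    try:
        llm.invoke([HumanMessage("Hello")])
    except RuntimeError as err:
        print(err)
        # -> Ollama sync client is not initialized. Make sure the model was
        #    properly constructed.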