fix(ollama): raise error when clients are not initialized (#35185)

## Summary
- When `self._client` is `None` in `_create_chat_stream()`, the method
silently produces an empty generator instead of failing.
- The error only surfaces later as a misleading `"No data received from
Ollama stream"` ValueError, making it difficult to diagnose the actual
root cause (uninitialized client).
- Changed to raise `RuntimeError` immediately with a clear message when
the sync client is not initialized.

## Why this matters
Users who hit this path see a confusing error message that points them
in the wrong direction. An explicit error at the point of failure makes
debugging straightforward.

## Test plan
- [x] Added `test_create_chat_stream_raises_when_client_none`
- [x] Existing tests still pass

> This PR was authored with the help of AI tools.

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Yi Liu
2026-02-13 00:56:53 +08:00
committed by GitHub
parent a50d86c353
commit 19ddd42891
6 changed files with 125 additions and 19 deletions

View File

@@ -934,6 +934,12 @@ class ChatOllama(BaseChatModel):
stop: list[str] | None = None, stop: list[str] | None = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Mapping[str, Any] | str]: ) -> AsyncIterator[Mapping[str, Any] | str]:
if not self._async_client:
msg = (
"Ollama async client is not initialized. "
"Make sure the model was properly constructed."
)
raise RuntimeError(msg)
chat_params = self._chat_params(messages, stop, **kwargs) chat_params = self._chat_params(messages, stop, **kwargs)
if chat_params["stream"]: if chat_params["stream"]:
@@ -948,12 +954,17 @@ class ChatOllama(BaseChatModel):
stop: list[str] | None = None, stop: list[str] | None = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Mapping[str, Any] | str]: ) -> Iterator[Mapping[str, Any] | str]:
if not self._client:
msg = (
"Ollama sync client is not initialized. "
"Make sure the model was properly constructed."
)
raise RuntimeError(msg)
chat_params = self._chat_params(messages, stop, **kwargs) chat_params = self._chat_params(messages, stop, **kwargs)
if chat_params["stream"]: if chat_params["stream"]:
if self._client:
yield from self._client.chat(**chat_params) yield from self._client.chat(**chat_params)
elif self._client: else:
yield self._client.chat(**chat_params) yield self._client.chat(**chat_params)
def _chat_stream_with_aggregation( def _chat_stream_with_aggregation(

View File

@@ -298,10 +298,10 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""Embed search docs.""" """Embed search docs."""
if not self._client: if not self._client:
msg = ( msg = (
"Ollama client is not initialized. " "Ollama sync client is not initialized. "
"Please ensure Ollama is running and the model is loaded." "Make sure the model was properly constructed."
) )
raise ValueError(msg) raise RuntimeError(msg)
return self._client.embed( return self._client.embed(
self.model, texts, options=self._default_params, keep_alive=self.keep_alive self.model, texts, options=self._default_params, keep_alive=self.keep_alive
)["embeddings"] )["embeddings"]
@@ -314,10 +314,10 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""Embed search docs.""" """Embed search docs."""
if not self._async_client: if not self._async_client:
msg = ( msg = (
"Ollama client is not initialized. " "Ollama async client is not initialized. "
"Please ensure Ollama is running and the model is loaded." "Make sure the model was properly constructed."
) )
raise ValueError(msg) raise RuntimeError(msg)
return ( return (
await self._async_client.embed( await self._async_client.embed(
self.model, self.model,

View File

@@ -347,7 +347,12 @@ class OllamaLLM(BaseLLM):
stop: list[str] | None = None, stop: list[str] | None = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Mapping[str, Any] | str]: ) -> AsyncIterator[Mapping[str, Any] | str]:
if self._async_client: if not self._async_client:
msg = (
"Ollama async client is not initialized. "
"Make sure the model was properly constructed."
)
raise RuntimeError(msg)
async for part in await self._async_client.generate( async for part in await self._async_client.generate(
**self._generate_params(prompt, stop=stop, **kwargs) **self._generate_params(prompt, stop=stop, **kwargs)
): ):
@@ -359,7 +364,12 @@ class OllamaLLM(BaseLLM):
stop: list[str] | None = None, stop: list[str] | None = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Mapping[str, Any] | str]: ) -> Iterator[Mapping[str, Any] | str]:
if self._client: if not self._client:
msg = (
"Ollama sync client is not initialized. "
"Make sure the model was properly constructed."
)
raise RuntimeError(msg)
yield from self._client.generate( yield from self._client.generate(
**self._generate_params(prompt, stop=stop, **kwargs) **self._generate_params(prompt, stop=stop, **kwargs)
) )

View File

@@ -449,6 +449,42 @@ def test_reasoning_param_passed_to_client() -> None:
assert call_kwargs["think"] is True assert call_kwargs["think"] is True
def test_create_chat_stream_raises_when_client_none() -> None:
"""Test that _create_chat_stream raises RuntimeError when client is None."""
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
mock_client_class.return_value = MagicMock()
llm = ChatOllama(model="test-model")
# Force _client to None to simulate uninitialized state
llm._client = None # type: ignore[assignment]
with pytest.raises(RuntimeError, match="sync client is not initialized"):
list(llm._create_chat_stream([HumanMessage("Hello")]))
async def test_acreate_chat_stream_raises_when_client_none() -> None:
"""Test that _acreate_chat_stream raises RuntimeError when client is None."""
with patch("langchain_ollama.chat_models.AsyncClient") as mock_client_class:
mock_client_class.return_value = MagicMock()
llm = ChatOllama(model="test-model")
# Force _async_client to None to simulate uninitialized state
llm._async_client = None # type: ignore[assignment]
with pytest.raises(RuntimeError, match="async client is not initialized"):
async for _ in llm._acreate_chat_stream([HumanMessage("Hello")]):
pass
def test_invoke_raises_when_client_none() -> None:
"""Test that RuntimeError propagates through the public invoke() API."""
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
mock_client_class.return_value = MagicMock()
llm = ChatOllama(model="test-model")
llm._client = None # type: ignore[assignment]
with pytest.raises(RuntimeError, match="sync client is not initialized"):
llm.invoke([HumanMessage("Hello")])
def test_chat_ollama_ignores_strict_arg() -> None: def test_chat_ollama_ignores_strict_arg() -> None:
"""Test that ChatOllama ignores the 'strict' argument.""" """Test that ChatOllama ignores the 'strict' argument."""
response = [ response = [

View File

@@ -1,7 +1,9 @@
"""Test embedding model integration.""" """Test embedding model integration."""
from typing import Any from typing import Any
from unittest.mock import Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest
from langchain_ollama.embeddings import OllamaEmbeddings from langchain_ollama.embeddings import OllamaEmbeddings
@@ -50,3 +52,25 @@ def test_embed_documents_passes_options(mock_client_class: Any) -> None:
options = call_args.kwargs["options"] options = call_args.kwargs["options"]
assert options["num_gpu"] == 4 assert options["num_gpu"] == 4
assert options["temperature"] == 0.5 assert options["temperature"] == 0.5
def test_embed_documents_raises_when_client_none() -> None:
"""Test that embed_documents raises RuntimeError when client is None."""
with patch("langchain_ollama.embeddings.Client") as mock_client_class:
mock_client_class.return_value = MagicMock()
embeddings = OllamaEmbeddings(model="test-model")
embeddings._client = None # type: ignore[assignment]
with pytest.raises(RuntimeError, match="sync client is not initialized"):
embeddings.embed_documents(["test"])
async def test_aembed_documents_raises_when_client_none() -> None:
"""Test that aembed_documents raises RuntimeError when async client is None."""
with patch("langchain_ollama.embeddings.AsyncClient") as mock_client_class:
mock_client_class.return_value = MagicMock()
embeddings = OllamaEmbeddings(model="test-model")
embeddings._async_client = None # type: ignore[assignment]
with pytest.raises(RuntimeError, match="async client is not initialized"):
await embeddings.aembed_documents(["test"])

View File

@@ -1,7 +1,9 @@
"""Test Ollama Chat API wrapper.""" """Test Ollama Chat API wrapper."""
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import MagicMock, patch
import pytest
from langchain_ollama import OllamaLLM from langchain_ollama import OllamaLLM
@@ -65,3 +67,26 @@ def test_reasoning_aggregation() -> None:
result.generations[0][0].generation_info["thinking"] result.generations[0][0].generation_info["thinking"]
== "I am thinking. Still thinking." == "I am thinking. Still thinking."
) )
def test_create_generate_stream_raises_when_client_none() -> None:
"""Test that _create_generate_stream raises RuntimeError when client is None."""
with patch("langchain_ollama.llms.Client") as mock_client_class:
mock_client_class.return_value = MagicMock()
llm = OllamaLLM(model="test-model")
llm._client = None # type: ignore[assignment]
with pytest.raises(RuntimeError, match="sync client is not initialized"):
list(llm._create_generate_stream("Hello"))
async def test_acreate_generate_stream_raises_when_client_none() -> None:
"""Test that _acreate_generate_stream raises RuntimeError when client is None."""
with patch("langchain_ollama.llms.AsyncClient") as mock_client_class:
mock_client_class.return_value = MagicMock()
llm = OllamaLLM(model="test-model")
llm._async_client = None # type: ignore[assignment]
with pytest.raises(RuntimeError, match="async client is not initialized"):
async for _ in llm._acreate_generate_stream("Hello"):
pass