mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(ollama): raise error when clients are not initialized (#35185)
## Summary

- When `self._client` is `None` in `_create_chat_stream()`, the method silently produces an empty generator instead of failing.
- The error only surfaces later as a misleading `"No data received from Ollama stream"` `ValueError`, making it difficult to diagnose the actual root cause (uninitialized client).
- Changed to raise `RuntimeError` immediately with a clear message when the sync client is not initialized.

## Why this matters

Users who hit this path see a confusing error message that points them in the wrong direction. An explicit error at the point of failure makes debugging straightforward.

## Test plan

- [x] Added `test_create_chat_stream_raises_when_client_none`
- [x] Existing tests still pass

> This PR was authored with the help of AI tools.

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
@@ -934,6 +934,12 @@ class ChatOllama(BaseChatModel):
|
|||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Mapping[str, Any] | str]:
|
) -> AsyncIterator[Mapping[str, Any] | str]:
|
||||||
|
if not self._async_client:
|
||||||
|
msg = (
|
||||||
|
"Ollama async client is not initialized. "
|
||||||
|
"Make sure the model was properly constructed."
|
||||||
|
)
|
||||||
|
raise RuntimeError(msg)
|
||||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||||
|
|
||||||
if chat_params["stream"]:
|
if chat_params["stream"]:
|
||||||
@@ -948,12 +954,17 @@ class ChatOllama(BaseChatModel):
|
|||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[Mapping[str, Any] | str]:
|
) -> Iterator[Mapping[str, Any] | str]:
|
||||||
|
if not self._client:
|
||||||
|
msg = (
|
||||||
|
"Ollama sync client is not initialized. "
|
||||||
|
"Make sure the model was properly constructed."
|
||||||
|
)
|
||||||
|
raise RuntimeError(msg)
|
||||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||||
|
|
||||||
if chat_params["stream"]:
|
if chat_params["stream"]:
|
||||||
if self._client:
|
|
||||||
yield from self._client.chat(**chat_params)
|
yield from self._client.chat(**chat_params)
|
||||||
elif self._client:
|
else:
|
||||||
yield self._client.chat(**chat_params)
|
yield self._client.chat(**chat_params)
|
||||||
|
|
||||||
def _chat_stream_with_aggregation(
|
def _chat_stream_with_aggregation(
|
||||||
|
|||||||
@@ -298,10 +298,10 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Embed search docs."""
|
"""Embed search docs."""
|
||||||
if not self._client:
|
if not self._client:
|
||||||
msg = (
|
msg = (
|
||||||
"Ollama client is not initialized. "
|
"Ollama sync client is not initialized. "
|
||||||
"Please ensure Ollama is running and the model is loaded."
|
"Make sure the model was properly constructed."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise RuntimeError(msg)
|
||||||
return self._client.embed(
|
return self._client.embed(
|
||||||
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
|
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
|
||||||
)["embeddings"]
|
)["embeddings"]
|
||||||
@@ -314,10 +314,10 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Embed search docs."""
|
"""Embed search docs."""
|
||||||
if not self._async_client:
|
if not self._async_client:
|
||||||
msg = (
|
msg = (
|
||||||
"Ollama client is not initialized. "
|
"Ollama async client is not initialized. "
|
||||||
"Please ensure Ollama is running and the model is loaded."
|
"Make sure the model was properly constructed."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise RuntimeError(msg)
|
||||||
return (
|
return (
|
||||||
await self._async_client.embed(
|
await self._async_client.embed(
|
||||||
self.model,
|
self.model,
|
||||||
|
|||||||
@@ -347,7 +347,12 @@ class OllamaLLM(BaseLLM):
|
|||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Mapping[str, Any] | str]:
|
) -> AsyncIterator[Mapping[str, Any] | str]:
|
||||||
if self._async_client:
|
if not self._async_client:
|
||||||
|
msg = (
|
||||||
|
"Ollama async client is not initialized. "
|
||||||
|
"Make sure the model was properly constructed."
|
||||||
|
)
|
||||||
|
raise RuntimeError(msg)
|
||||||
async for part in await self._async_client.generate(
|
async for part in await self._async_client.generate(
|
||||||
**self._generate_params(prompt, stop=stop, **kwargs)
|
**self._generate_params(prompt, stop=stop, **kwargs)
|
||||||
):
|
):
|
||||||
@@ -359,7 +364,12 @@ class OllamaLLM(BaseLLM):
|
|||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[Mapping[str, Any] | str]:
|
) -> Iterator[Mapping[str, Any] | str]:
|
||||||
if self._client:
|
if not self._client:
|
||||||
|
msg = (
|
||||||
|
"Ollama sync client is not initialized. "
|
||||||
|
"Make sure the model was properly constructed."
|
||||||
|
)
|
||||||
|
raise RuntimeError(msg)
|
||||||
yield from self._client.generate(
|
yield from self._client.generate(
|
||||||
**self._generate_params(prompt, stop=stop, **kwargs)
|
**self._generate_params(prompt, stop=stop, **kwargs)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -449,6 +449,42 @@ def test_reasoning_param_passed_to_client() -> None:
|
|||||||
assert call_kwargs["think"] is True
|
assert call_kwargs["think"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_chat_stream_raises_when_client_none() -> None:
|
||||||
|
"""Test that _create_chat_stream raises RuntimeError when client is None."""
|
||||||
|
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
||||||
|
mock_client_class.return_value = MagicMock()
|
||||||
|
llm = ChatOllama(model="test-model")
|
||||||
|
# Force _client to None to simulate uninitialized state
|
||||||
|
llm._client = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="sync client is not initialized"):
|
||||||
|
list(llm._create_chat_stream([HumanMessage("Hello")]))
|
||||||
|
|
||||||
|
|
||||||
|
async def test_acreate_chat_stream_raises_when_client_none() -> None:
|
||||||
|
"""Test that _acreate_chat_stream raises RuntimeError when client is None."""
|
||||||
|
with patch("langchain_ollama.chat_models.AsyncClient") as mock_client_class:
|
||||||
|
mock_client_class.return_value = MagicMock()
|
||||||
|
llm = ChatOllama(model="test-model")
|
||||||
|
# Force _async_client to None to simulate uninitialized state
|
||||||
|
llm._async_client = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="async client is not initialized"):
|
||||||
|
async for _ in llm._acreate_chat_stream([HumanMessage("Hello")]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_raises_when_client_none() -> None:
|
||||||
|
"""Test that RuntimeError propagates through the public invoke() API."""
|
||||||
|
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
|
||||||
|
mock_client_class.return_value = MagicMock()
|
||||||
|
llm = ChatOllama(model="test-model")
|
||||||
|
llm._client = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="sync client is not initialized"):
|
||||||
|
llm.invoke([HumanMessage("Hello")])
|
||||||
|
|
||||||
|
|
||||||
def test_chat_ollama_ignores_strict_arg() -> None:
|
def test_chat_ollama_ignores_strict_arg() -> None:
|
||||||
"""Test that ChatOllama ignores the 'strict' argument."""
|
"""Test that ChatOllama ignores the 'strict' argument."""
|
||||||
response = [
|
response = [
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
"""Test embedding model integration."""
|
"""Test embedding model integration."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain_ollama.embeddings import OllamaEmbeddings
|
from langchain_ollama.embeddings import OllamaEmbeddings
|
||||||
|
|
||||||
@@ -50,3 +52,25 @@ def test_embed_documents_passes_options(mock_client_class: Any) -> None:
|
|||||||
options = call_args.kwargs["options"]
|
options = call_args.kwargs["options"]
|
||||||
assert options["num_gpu"] == 4
|
assert options["num_gpu"] == 4
|
||||||
assert options["temperature"] == 0.5
|
assert options["temperature"] == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_documents_raises_when_client_none() -> None:
|
||||||
|
"""Test that embed_documents raises RuntimeError when client is None."""
|
||||||
|
with patch("langchain_ollama.embeddings.Client") as mock_client_class:
|
||||||
|
mock_client_class.return_value = MagicMock()
|
||||||
|
embeddings = OllamaEmbeddings(model="test-model")
|
||||||
|
embeddings._client = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="sync client is not initialized"):
|
||||||
|
embeddings.embed_documents(["test"])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aembed_documents_raises_when_client_none() -> None:
|
||||||
|
"""Test that aembed_documents raises RuntimeError when async client is None."""
|
||||||
|
with patch("langchain_ollama.embeddings.AsyncClient") as mock_client_class:
|
||||||
|
mock_client_class.return_value = MagicMock()
|
||||||
|
embeddings = OllamaEmbeddings(model="test-model")
|
||||||
|
embeddings._async_client = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="async client is not initialized"):
|
||||||
|
await embeddings.aembed_documents(["test"])
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
"""Test Ollama Chat API wrapper."""
|
"""Test Ollama Chat API wrapper."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain_ollama import OllamaLLM
|
from langchain_ollama import OllamaLLM
|
||||||
|
|
||||||
@@ -65,3 +67,26 @@ def test_reasoning_aggregation() -> None:
|
|||||||
result.generations[0][0].generation_info["thinking"]
|
result.generations[0][0].generation_info["thinking"]
|
||||||
== "I am thinking. Still thinking."
|
== "I am thinking. Still thinking."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_generate_stream_raises_when_client_none() -> None:
|
||||||
|
"""Test that _create_generate_stream raises RuntimeError when client is None."""
|
||||||
|
with patch("langchain_ollama.llms.Client") as mock_client_class:
|
||||||
|
mock_client_class.return_value = MagicMock()
|
||||||
|
llm = OllamaLLM(model="test-model")
|
||||||
|
llm._client = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="sync client is not initialized"):
|
||||||
|
list(llm._create_generate_stream("Hello"))
|
||||||
|
|
||||||
|
|
||||||
|
async def test_acreate_generate_stream_raises_when_client_none() -> None:
|
||||||
|
"""Test that _acreate_generate_stream raises RuntimeError when client is None."""
|
||||||
|
with patch("langchain_ollama.llms.AsyncClient") as mock_client_class:
|
||||||
|
mock_client_class.return_value = MagicMock()
|
||||||
|
llm = OllamaLLM(model="test-model")
|
||||||
|
llm._async_client = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="async client is not initialized"):
|
||||||
|
async for _ in llm._acreate_generate_stream("Hello"):
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user