ollama: add validate_model_on_init, catch more errors (#31784)

* Ensure access to the local model during `ChatOllama` instantiation
(#27720). This adds a new param `validate_model_on_init` (default:
`False`); a short usage sketch follows the commit metadata below.
* Catch a few more errors from the Ollama client to assist users
Mason Daugherty 2025-07-03 11:07:11 -04:00 committed by GitHub
parent 1a3a8db3c9
commit 572020c4d8
10 changed files with 188 additions and 3 deletions
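For orientation, a minimal usage sketch of the new flag (illustrative only, not part of this commit; it assumes a locally running Ollama server, and `llama3.1` is an arbitrary example model name):

from pydantic import ValidationError

from langchain_ollama import ChatOllama

try:
    llm = ChatOllama(model="llama3.1", validate_model_on_init=True)
except (ValidationError, ValueError) as exc:
    # Raised when the model is not pulled locally or the Ollama server is unreachable.
    print(f"Model validation failed: {exc}")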


@@ -0,0 +1,37 @@
+"""Utility functions for validating Ollama models."""
+
+from httpx import ConnectError
+from ollama import Client, ResponseError
+
+
+def validate_model(client: Client, model_name: str) -> None:
+    """Validate that a model exists in the Ollama instance.
+
+    Args:
+        client: The Ollama client.
+        model_name: The name of the model to validate.
+
+    Raises:
+        ValueError: If the model is not found or if there's a connection issue.
+    """
+    try:
+        response = client.list()
+        model_names: list[str] = [model["name"] for model in response["models"]]
+        if not any(
+            model_name == m or m.startswith(f"{model_name}:") for m in model_names
+        ):
+            raise ValueError(
+                f"Model `{model_name}` not found in Ollama. Please pull the "
+                f"model (using `ollama pull {model_name}`) or specify a valid "
+                f"model name. Available local models: {', '.join(model_names)}"
+            )
+    except ConnectError as e:
+        raise ValueError(
+            "Connection to Ollama failed. Please make sure Ollama is running "
+            f"and accessible at {client._client.base_url}. "
+        ) from e
+    except ResponseError as e:
+        raise ValueError(
+            "Received an error from the Ollama API. "
+            "Please check your Ollama server logs."
+        ) from e
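The helper can also be called directly; a minimal sketch (not part of this commit; note that `_utils` is a private module and the host shown is the Ollama default):

from ollama import Client

from langchain_ollama._utils import validate_model

client = Client(host="http://localhost:11434")
# Raises ValueError if the model is not pulled or the server is unreachable.
validate_model(client, "llama3.1")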


@@ -55,6 +55,8 @@ from pydantic.json_schema import JsonSchemaValue
 from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self, is_typeddict
 
+from ._utils import validate_model
+
 
 DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
 DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
@@ -350,6 +352,9 @@ class ChatOllama(BaseChatModel):
     model: str
     """Model name to use."""
 
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in Ollama locally on initialization."""
+
     extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
     """Whether to extract the reasoning tokens in think blocks.
     Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
@@ -529,6 +534,8 @@ class ChatOllama(BaseChatModel):
         self._client = Client(host=self.base_url, **sync_client_kwargs)
         self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
         return self
 
     def _convert_messages_to_ollama_messages(
@@ -1226,7 +1233,7 @@ class ChatOllama(BaseChatModel):
                         "schema": schema,
                     },
                 )
-                output_parser = PydanticOutputParser(pydantic_object=schema)
+                output_parser = PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
             else:
                 if is_typeddict(schema):
                     response_format = convert_to_json_schema(schema)


@@ -12,6 +12,8 @@ from pydantic import (
 )
 from typing_extensions import Self
 
+from ._utils import validate_model
+
 
 class OllamaEmbeddings(BaseModel, Embeddings):
     """Ollama embedding model integration.
@@ -123,6 +125,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     model: str
     """Model name to use."""
 
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in ollama locally on initialization."""
+
     base_url: Optional[str] = None
     """Base url the model is hosted under."""
@@ -259,6 +264,8 @@ class OllamaEmbeddings(BaseModel, Embeddings):
         self._client = Client(host=self.base_url, **sync_client_kwargs)
         self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
         return self
 
     def embed_documents(self, texts: list[str]) -> list[list[float]]:


@@ -18,6 +18,8 @@ from ollama import AsyncClient, Client, Options
 from pydantic import PrivateAttr, model_validator
 from typing_extensions import Self
 
+from ._utils import validate_model
+
 
 class OllamaLLM(BaseLLM):
     """OllamaLLM large language models.
@@ -34,6 +36,9 @@ class OllamaLLM(BaseLLM):
     model: str
     """Model name to use."""
 
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in ollama locally on initialization."""
+
     mirostat: Optional[int] = None
     """Enable Mirostat sampling for controlling perplexity.
     (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
@@ -215,6 +220,8 @@ class OllamaLLM(BaseLLM):
         self._client = Client(host=self.base_url, **sync_client_kwargs)
         self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
         return self
 
     async def _acreate_generate_stream(
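The same flag is wired into `OllamaEmbeddings` and `OllamaLLM`, as the hunks above show; a minimal sketch of enabling it there (illustrative only; the model names are arbitrary examples):

from langchain_ollama import OllamaEmbeddings, OllamaLLM

embeddings = OllamaEmbeddings(model="nomic-embed-text", validate_model_on_init=True)
llm = OllamaLLM(model="llama3.1", validate_model_on_init=True)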


@@ -1,8 +1,13 @@
 """Test chat model integration using standard integration tests."""
 
+from unittest.mock import MagicMock, patch
+
 import pytest
+from httpx import ConnectError
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import ChatModelIntegrationTests
+from ollama import ResponseError
+from pydantic import ValidationError
 
 from langchain_ollama.chat_models import ChatOllama
@@ -47,3 +52,29 @@ class TestChatOllama(ChatModelIntegrationTests):
     )
     async def test_tool_calling_async(self, model: BaseChatModel) -> None:
         await super().test_tool_calling_async(model)
+
+    @patch("langchain_ollama.chat_models.Client.list")
+    def test_init_model_not_found(self, mock_list: MagicMock) -> None:
+        """Test that a ValueError is raised when the model is not found."""
+        mock_list.side_effect = ValueError("Test model not found")
+        with pytest.raises(ValueError) as excinfo:
+            ChatOllama(model="non-existent-model", validate_model_on_init=True)
+        assert "Test model not found" in str(excinfo.value)
+
+    @patch("langchain_ollama.chat_models.Client.list")
+    def test_init_connection_error(self, mock_list: MagicMock) -> None:
+        """Test that a ValidationError is raised on connect failure during init."""
+        mock_list.side_effect = ConnectError("Test connection error")
+        with pytest.raises(ValidationError) as excinfo:
+            ChatOllama(model="any-model", validate_model_on_init=True)
+        assert "Connection to Ollama failed" in str(excinfo.value)
+
+    @patch("langchain_ollama.chat_models.Client.list")
+    def test_init_response_error(self, mock_list: MagicMock) -> None:
+        """Test that a ResponseError from the client surfaces as a validation error."""
+        mock_list.side_effect = ResponseError("Test response error")
+        with pytest.raises(ValidationError) as excinfo:
+            ChatOllama(model="any-model", validate_model_on_init=True)
+        assert "Received an error from the Ollama API" in str(excinfo.value)


@@ -4,6 +4,7 @@ import json
 from collections.abc import Generator
 from contextlib import contextmanager
 from typing import Any
+from unittest.mock import patch
 
 import pytest
 from httpx import Client, Request, Response
@@ -12,6 +13,8 @@ from langchain_tests.unit_tests import ChatModelUnitTests
 
 from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_call
 
+MODEL_NAME = "llama3.1"
+
 
 class TestChatOllama(ChatModelUnitTests):
     @property
@@ -49,7 +52,7 @@ def test_arbitrary_roles_accepted_in_chatmessages(
     llm = ChatOllama(
         base_url="http://whocares:11434",
-        model="granite3.2",
+        model=MODEL_NAME,
         verbose=True,
         format=None,
     )
@@ -64,3 +67,20 @@ def test_arbitrary_roles_accepted_in_chatmessages(
     ]
 
     llm.invoke(messages)
+
+
+@patch("langchain_ollama.chat_models.validate_model")
+def test_validate_model_on_init(mock_validate_model: Any) -> None:
+    """Test that the model is validated on initialization when requested."""
+    # Test that validate_model is called when validate_model_on_init=True
+    ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
+    mock_validate_model.assert_called_once()
+    mock_validate_model.reset_mock()
+
+    # Test that validate_model is NOT called when validate_model_on_init=False
+    ChatOllama(model=MODEL_NAME, validate_model_on_init=False)
+    mock_validate_model.assert_not_called()
+
+    # Test that validate_model is NOT called by default
+    ChatOllama(model=MODEL_NAME)
+    mock_validate_model.assert_not_called()


@@ -1,8 +1,30 @@
 """Test embedding model integration."""
 
+from typing import Any
+from unittest.mock import patch
+
 from langchain_ollama.embeddings import OllamaEmbeddings
 
+MODEL_NAME = "llama3.1"
+
 
 def test_initialization() -> None:
     """Test embedding model initialization."""
     OllamaEmbeddings(model="llama3", keep_alive=1)
+
+
+@patch("langchain_ollama.embeddings.validate_model")
+def test_validate_model_on_init(mock_validate_model: Any) -> None:
+    """Test that the model is validated on initialization when requested."""
+    # Test that validate_model is called when validate_model_on_init=True
+    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
+    mock_validate_model.assert_called_once()
+    mock_validate_model.reset_mock()
+
+    # Test that validate_model is NOT called when validate_model_on_init=False
+    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False)
+    mock_validate_model.assert_not_called()
+
+    # Test that validate_model is NOT called by default
+    OllamaEmbeddings(model=MODEL_NAME)
+    mock_validate_model.assert_not_called()


@@ -1,7 +1,12 @@
 """Test Ollama Chat API wrapper."""
 
+from typing import Any
+from unittest.mock import patch
+
 from langchain_ollama import OllamaLLM
 
+MODEL_NAME = "llama3.1"
+
 
 def test_initialization() -> None:
     """Test integration initialization."""
@@ -26,3 +31,20 @@ def test_model_params() -> None:
         "ls_model_name": "llama3",
         "ls_max_tokens": 3,
     }
+
+
+@patch("langchain_ollama.llms.validate_model")
+def test_validate_model_on_init(mock_validate_model: Any) -> None:
+    """Test that the model is validated on initialization when requested."""
+    # Test that validate_model is called when validate_model_on_init=True
+    OllamaLLM(model=MODEL_NAME, validate_model_on_init=True)
+    mock_validate_model.assert_called_once()
+    mock_validate_model.reset_mock()
+
+    # Test that validate_model is NOT called when validate_model_on_init=False
+    OllamaLLM(model=MODEL_NAME, validate_model_on_init=False)
+    mock_validate_model.assert_not_called()
+
+    # Test that validate_model is NOT called by default
+    OllamaLLM(model=MODEL_NAME)
+    mock_validate_model.assert_not_called()


@@ -363,7 +363,7 @@ typing = [
 [[package]]
 name = "langchain-ollama"
-version = "0.3.4"
+version = "0.3.3"
 source = { editable = "." }
 dependencies = [
     { name = "langchain-core" },