ollama: add validate_model_on_init, catch more errors (#31784)

* Ensure access to the local model during `ChatOllama` instantiation (#27720). This adds a new param `validate_model_on_init` (default: `false`); see the usage sketch below.
* Catch a few more errors from the Ollama client to assist users
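A minimal usage sketch (illustrative model name; the check is off by default):

```python
from langchain_ollama import ChatOllama

# With validate_model_on_init=True, construction fails fast with a ValueError
# if the model has not been pulled locally or the Ollama server is unreachable.
llm = ChatOllama(model="llama3.1", validate_model_on_init=True)
```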
Mason Daugherty 2025-07-03 11:07:11 -04:00 committed by GitHub
parent 1a3a8db3c9
commit 572020c4d8
10 changed files with 188 additions and 3 deletions

View File

@@ -0,0 +1,37 @@
"""Utility functions for validating Ollama models."""
from httpx import ConnectError
from ollama import Client, ResponseError
def validate_model(client: Client, model_name: str) -> None:
"""Validate that a model exists in the Ollama instance.
Args:
client: The Ollama client.
model_name: The name of the model to validate.
Raises:
ValueError: If the model is not found or if there's a connection issue.
"""
try:
response = client.list()
model_names: list[str] = [model["name"] for model in response["models"]]
if not any(
model_name == m or m.startswith(f"{model_name}:") for m in model_names
):
raise ValueError(
f"Model `{model_name}` not found in Ollama. Please pull the "
f"model (using `ollama pull {model_name}`) or specify a valid "
f"model name. Available local models: {', '.join(model_names)}"
)
except ConnectError as e:
raise ValueError(
"Connection to Ollama failed. Please make sure Ollama is running "
f"and accessible at {client._client.base_url}. "
) from e
except ResponseError as e:
raise ValueError(
"Received an error from the Ollama API. "
"Please check your Ollama server logs."
) from e
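Illustrative only (not part of the diff): the helper can also be called directly against a client, assuming the default local Ollama host:

```python
from ollama import Client

from langchain_ollama._utils import validate_model

client = Client(host="http://localhost:11434")
try:
    validate_model(client, "llama3.1")
except ValueError as exc:
    # A missing model, a connection failure, and an Ollama API error all
    # surface as a ValueError with a descriptive message.
    print(exc)
```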

View File

@@ -55,6 +55,8 @@ from pydantic.json_schema import JsonSchemaValue
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self, is_typeddict
from ._utils import validate_model
DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
@@ -350,6 +352,9 @@ class ChatOllama(BaseChatModel):
model: str
"""Model name to use."""
validate_model_on_init: bool = False
"""Whether to validate the model exists in Ollama locally on initialization."""
extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
"""Whether to extract the reasoning tokens in think blocks.
Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
@@ -529,6 +534,8 @@ class ChatOllama(BaseChatModel):
self._client = Client(host=self.base_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
if self.validate_model_on_init:
validate_model(self._client, self.model)
return self
def _convert_messages_to_ollama_messages(
@@ -1226,7 +1233,7 @@ class ChatOllama(BaseChatModel):
"schema": schema,
},
)
output_parser = PydanticOutputParser(pydantic_object=schema)
output_parser = PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
else:
if is_typeddict(schema):
response_format = convert_to_json_schema(schema)

View File

@@ -12,6 +12,8 @@ from pydantic (
)
from typing_extensions import Self
from ._utils import validate_model
class OllamaEmbeddings(BaseModel, Embeddings):
"""Ollama embedding model integration.
@@ -123,6 +125,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
model: str
"""Model name to use."""
validate_model_on_init: bool = False
"""Whether to validate the model exists in ollama locally on initialization."""
base_url: Optional[str] = None
"""Base url the model is hosted under."""
@@ -259,6 +264,8 @@ class OllamaEmbeddings(BaseModel, Embeddings):
self._client = Client(host=self.base_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
if self.validate_model_on_init:
validate_model(self._client, self.model)
return self
def embed_documents(self, texts: list[str]) -> list[list[float]]:

View File

@@ -18,6 +18,8 @@ from ollama import AsyncClient, Client, Options
from pydantic import PrivateAttr, model_validator
from typing_extensions import Self
from ._utils import validate_model
class OllamaLLM(BaseLLM):
"""OllamaLLM large language models.
@@ -34,6 +36,9 @@ class OllamaLLM(BaseLLM):
model: str
"""Model name to use."""
validate_model_on_init: bool = False
"""Whether to validate the model exists in ollama locally on initialization."""
mirostat: Optional[int] = None
"""Enable Mirostat sampling for controlling perplexity.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
@@ -215,6 +220,8 @@ class OllamaLLM(BaseLLM):
self._client = Client(host=self.base_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
if self.validate_model_on_init:
validate_model(self._client, self.model)
return self
async def _acreate_generate_stream(

View File

@@ -1,8 +1,13 @@
"""Test chat model integration using standard integration tests."""
from unittest.mock import MagicMock, patch
import pytest
from httpx import ConnectError
from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ChatModelIntegrationTests
from ollama import ResponseError
from pydantic import ValidationError
from langchain_ollama.chat_models import ChatOllama
@@ -47,3 +52,29 @@ class TestChatOllama(ChatModelIntegrationTests):
)
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
await super().test_tool_calling_async(model)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_model_not_found(self, mock_list: MagicMock) -> None:
"""Test that a ValueError is raised when the model is not found."""
mock_list.side_effect = ValueError("Test model not found")
with pytest.raises(ValueError) as excinfo:
ChatOllama(model="non-existent-model", validate_model_on_init=True)
assert "Test model not found" in str(excinfo.value)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_connection_error(self, mock_list: MagicMock) -> None:
"""Test that a ValidationError is raised on connect failure during init."""
mock_list.side_effect = ConnectError("Test connection error")
with pytest.raises(ValidationError) as excinfo:
ChatOllama(model="any-model", validate_model_on_init=True)
assert "not found in Ollama" in str(excinfo.value)
@patch("langchain_ollama.chat_models.Client.list")
def test_init_response_error(self, mock_list: MagicMock) -> None:
"""Test that a ResponseError is raised."""
mock_list.side_effect = ResponseError("Test response error")
with pytest.raises(ValidationError) as excinfo:
ChatOllama(model="any-model", validate_model_on_init=True)
assert "Received an error from the Ollama API" in str(excinfo.value)

View File

@@ -4,6 +4,7 @@ import json
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import patch
import pytest
from httpx import Client, Request, Response
@@ -12,6 +13,8 @@ from langchain_tests.unit_tests import ChatModelUnitTests
from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_call
MODEL_NAME = "llama3.1"
class TestChatOllama(ChatModelUnitTests):
@property
@@ -49,7 +52,7 @@ def test_arbitrary_roles_accepted_in_chatmessages(
llm = ChatOllama(
base_url="http://whocares:11434",
model="granite3.2",
model=MODEL_NAME,
verbose=True,
format=None,
)
@@ -64,3 +67,20 @@ def test_arbitrary_roles_accepted_in_chatmessages(
]
llm.invoke(messages)
@patch("langchain_ollama.chat_models.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
"""Test that the model is validated on initialization when requested."""
# Test that validate_model is called when validate_model_on_init=True
ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
mock_validate_model.assert_called_once()
mock_validate_model.reset_mock()
# Test that validate_model is NOT called when validate_model_on_init=False
ChatOllama(model=MODEL_NAME, validate_model_on_init=False)
mock_validate_model.assert_not_called()
# Test that validate_model is NOT called by default
ChatOllama(model=MODEL_NAME)
mock_validate_model.assert_not_called()

View File

@@ -1,8 +1,30 @@
"""Test embedding model integration."""
from typing import Any
from unittest.mock import patch
from langchain_ollama.embeddings import OllamaEmbeddings
MODEL_NAME = "llama3.1"
def test_initialization() -> None:
"""Test embedding model initialization."""
OllamaEmbeddings(model="llama3", keep_alive=1)
@patch("langchain_ollama.embeddings.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
"""Test that the model is validated on initialization when requested."""
# Test that validate_model is called when validate_model_on_init=True
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
mock_validate_model.assert_called_once()
mock_validate_model.reset_mock()
# Test that validate_model is NOT called when validate_model_on_init=False
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False)
mock_validate_model.assert_not_called()
# Test that validate_model is NOT called by default
OllamaEmbeddings(model=MODEL_NAME)
mock_validate_model.assert_not_called()

View File

@@ -1,7 +1,12 @@
"""Test Ollama Chat API wrapper."""
from typing import Any
from unittest.mock import patch
from langchain_ollama import OllamaLLM
MODEL_NAME = "llama3.1"
def test_initialization() -> None:
"""Test integration initialization."""
@@ -26,3 +31,20 @@ def test_model_params() -> None:
"ls_model_name": "llama3",
"ls_max_tokens": 3,
}
@patch("langchain_ollama.llms.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
"""Test that the model is validated on initialization when requested."""
# Test that validate_model is called when validate_model_on_init=True
OllamaLLM(model=MODEL_NAME, validate_model_on_init=True)
mock_validate_model.assert_called_once()
mock_validate_model.reset_mock()
# Test that validate_model is NOT called when validate_model_on_init=False
OllamaLLM(model=MODEL_NAME, validate_model_on_init=False)
mock_validate_model.assert_not_called()
# Test that validate_model is NOT called by default
OllamaLLM(model=MODEL_NAME)
mock_validate_model.assert_not_called()

View File

@@ -363,7 +363,7 @@ typing = [
[[package]]
name = "langchain-ollama"
version = "0.3.4"
version = "0.3.3"
source = { editable = "." }
dependencies = [
{ name = "langchain-core" },