mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-09 22:45:49 +00:00)
ollama: add validate_model_on_init, catch more errors (#31784)
* Ensure access to the local model during `ChatOllama` instantiation (#27720). This adds a new param `validate_model_on_init` (default: `False`); see the usage sketch below.
* Catch a few more errors from the Ollama client to assist users.
parent 1a3a8db3c9
commit 572020c4d8
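The new flag turns a late, inference-time failure into an eager one at construction time. A minimal sketch of the opt-in behavior, assuming an Ollama server at the default localhost address and "llama3.1" as a stand-in model name:

    from langchain_ollama import ChatOllama

    # With validation enabled, construction fails fast if the model has not
    # been pulled locally or the Ollama server is unreachable.
    try:
        llm = ChatOllama(model="llama3.1", validate_model_on_init=True)
    except ValueError as exc:
        # Pydantic runs the check inside a model validator, so the error can
        # surface as a ValidationError (a ValueError subclass).
        print(f"Model validation failed: {exc}")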
libs/partners/ollama/langchain_ollama/_utils.py (new file, +37)
@@ -0,0 +1,37 @@
+"""Utility functions for validating Ollama models."""
+
+from httpx import ConnectError
+from ollama import Client, ResponseError
+
+
+def validate_model(client: Client, model_name: str) -> None:
+    """Validate that a model exists in the Ollama instance.
+
+    Args:
+        client: The Ollama client.
+        model_name: The name of the model to validate.
+
+    Raises:
+        ValueError: If the model is not found or if there's a connection issue.
+    """
+    try:
+        response = client.list()
+        model_names: list[str] = [model["name"] for model in response["models"]]
+        if not any(
+            model_name == m or m.startswith(f"{model_name}:") for m in model_names
+        ):
+            raise ValueError(
+                f"Model `{model_name}` not found in Ollama. Please pull the "
+                f"model (using `ollama pull {model_name}`) or specify a valid "
+                f"model name. Available local models: {', '.join(model_names)}"
+            )
+    except ConnectError as e:
+        raise ValueError(
+            "Connection to Ollama failed. Please make sure Ollama is running "
+            f"and accessible at {client._client.base_url}. "
+        ) from e
+    except ResponseError as e:
+        raise ValueError(
+            "Received an error from the Ollama API. "
+            "Please check your Ollama server logs."
+        ) from e
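The helper can also be exercised on its own. A short sketch, assuming a local server at the default http://localhost:11434 and "llama3.1" as a placeholder model name:

    from ollama import Client

    from langchain_ollama._utils import validate_model

    client = Client(host="http://localhost:11434")
    # Raises ValueError unless a local model is named "llama3.1" exactly or
    # carries a tag such as "llama3.1:latest" or "llama3.1:8b" -- the
    # startswith(f"{model_name}:") check accepts any tag of the requested model.
    validate_model(client, "llama3.1")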
libs/partners/ollama/langchain_ollama/chat_models.py
@@ -55,6 +55,8 @@ from pydantic.json_schema import JsonSchemaValue
 from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self, is_typeddict
 
+from ._utils import validate_model
+
 DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
 DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
 
@@ -350,6 +352,9 @@ class ChatOllama(BaseChatModel):
     model: str
     """Model name to use."""
 
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in Ollama locally on initialization."""
+
     extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
     """Whether to extract the reasoning tokens in think blocks.
     Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
@@ -529,6 +534,8 @@ class ChatOllama(BaseChatModel):
 
         self._client = Client(host=self.base_url, **sync_client_kwargs)
         self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
         return self
 
     def _convert_messages_to_ollama_messages(
@@ -1226,7 +1233,7 @@ class ChatOllama(BaseChatModel):
                         "schema": schema,
                     },
                 )
-                output_parser = PydanticOutputParser(pydantic_object=schema)
+                output_parser = PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
             else:
                 if is_typeddict(schema):
                     response_format = convert_to_json_schema(schema)
libs/partners/ollama/langchain_ollama/embeddings.py
@@ -12,6 +12,8 @@ from pydantic import (
 )
 from typing_extensions import Self
 
+from ._utils import validate_model
+
 
 class OllamaEmbeddings(BaseModel, Embeddings):
     """Ollama embedding model integration.
@@ -123,6 +125,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     model: str
     """Model name to use."""
 
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in ollama locally on initialization."""
+
     base_url: Optional[str] = None
     """Base url the model is hosted under."""
 
@@ -259,6 +264,8 @@ class OllamaEmbeddings(BaseModel, Embeddings):
 
         self._client = Client(host=self.base_url, **sync_client_kwargs)
         self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
         return self
 
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
libs/partners/ollama/langchain_ollama/llms.py
@@ -18,6 +18,8 @@ from ollama import AsyncClient, Client, Options
 from pydantic import PrivateAttr, model_validator
 from typing_extensions import Self
 
+from ._utils import validate_model
+
 
 class OllamaLLM(BaseLLM):
     """OllamaLLM large language models.
@@ -34,6 +36,9 @@ class OllamaLLM(BaseLLM):
     model: str
     """Model name to use."""
 
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in ollama locally on initialization."""
+
     mirostat: Optional[int] = None
     """Enable Mirostat sampling for controlling perplexity.
     (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
@@ -215,6 +220,8 @@ class OllamaLLM(BaseLLM):
 
         self._client = Client(host=self.base_url, **sync_client_kwargs)
         self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
         return self
 
     async def _acreate_generate_stream(
File diff suppressed because one or more lines are too long
@@ -1,8 +1,13 @@
 """Test chat model integration using standard integration tests."""
 
+from unittest.mock import MagicMock, patch
+
 import pytest
+from httpx import ConnectError
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import ChatModelIntegrationTests
+from ollama import ResponseError
+from pydantic import ValidationError
 
 from langchain_ollama.chat_models import ChatOllama
 
@@ -47,3 +52,29 @@ class TestChatOllama(ChatModelIntegrationTests):
     )
     async def test_tool_calling_async(self, model: BaseChatModel) -> None:
         await super().test_tool_calling_async(model)
+
+    @patch("langchain_ollama.chat_models.Client.list")
+    def test_init_model_not_found(self, mock_list: MagicMock) -> None:
+        """Test that a ValueError is raised when the model is not found."""
+        mock_list.side_effect = ValueError("Test model not found")
+        with pytest.raises(ValueError) as excinfo:
+            ChatOllama(model="non-existent-model", validate_model_on_init=True)
+        assert "Test model not found" in str(excinfo.value)
+
+    @patch("langchain_ollama.chat_models.Client.list")
+    def test_init_connection_error(self, mock_list: MagicMock) -> None:
+        """Test that a ValidationError is raised on connect failure during init."""
+        mock_list.side_effect = ConnectError("Test connection error")
+
+        with pytest.raises(ValidationError) as excinfo:
+            ChatOllama(model="any-model", validate_model_on_init=True)
+        assert "not found in Ollama" in str(excinfo.value)
+
+    @patch("langchain_ollama.chat_models.Client.list")
+    def test_init_response_error(self, mock_list: MagicMock) -> None:
+        """Test that a ResponseError is raised."""
+        mock_list.side_effect = ResponseError("Test response error")
+
+        with pytest.raises(ValidationError) as excinfo:
+            ChatOllama(model="any-model", validate_model_on_init=True)
+        assert "Received an error from the Ollama API" in str(excinfo.value)
@@ -4,6 +4,7 @@ import json
 from collections.abc import Generator
 from contextlib import contextmanager
 from typing import Any
+from unittest.mock import patch
 
 import pytest
 from httpx import Client, Request, Response
@@ -12,6 +13,8 @@ from langchain_tests.unit_tests import ChatModelUnitTests
 
 from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_call
 
+MODEL_NAME = "llama3.1"
+
 
 class TestChatOllama(ChatModelUnitTests):
     @property
@@ -49,7 +52,7 @@ def test_arbitrary_roles_accepted_in_chatmessages(
 
     llm = ChatOllama(
         base_url="http://whocares:11434",
-        model="granite3.2",
+        model=MODEL_NAME,
         verbose=True,
         format=None,
     )
@@ -64,3 +67,20 @@ def test_arbitrary_roles_accepted_in_chatmessages(
     ]
 
     llm.invoke(messages)
+
+
+@patch("langchain_ollama.chat_models.validate_model")
+def test_validate_model_on_init(mock_validate_model: Any) -> None:
+    """Test that the model is validated on initialization when requested."""
+    # Test that validate_model is called when validate_model_on_init=True
+    ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
+    mock_validate_model.assert_called_once()
+    mock_validate_model.reset_mock()
+
+    # Test that validate_model is NOT called when validate_model_on_init=False
+    ChatOllama(model=MODEL_NAME, validate_model_on_init=False)
+    mock_validate_model.assert_not_called()
+
+    # Test that validate_model is NOT called by default
+    ChatOllama(model=MODEL_NAME)
+    mock_validate_model.assert_not_called()
@@ -1,8 +1,30 @@
 """Test embedding model integration."""
 
+from typing import Any
+from unittest.mock import patch
+
 from langchain_ollama.embeddings import OllamaEmbeddings
 
+MODEL_NAME = "llama3.1"
+
 
 def test_initialization() -> None:
     """Test embedding model initialization."""
     OllamaEmbeddings(model="llama3", keep_alive=1)
+
+
+@patch("langchain_ollama.embeddings.validate_model")
+def test_validate_model_on_init(mock_validate_model: Any) -> None:
+    """Test that the model is validated on initialization when requested."""
+    # Test that validate_model is called when validate_model_on_init=True
+    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
+    mock_validate_model.assert_called_once()
+    mock_validate_model.reset_mock()
+
+    # Test that validate_model is NOT called when validate_model_on_init=False
+    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False)
+    mock_validate_model.assert_not_called()
+
+    # Test that validate_model is NOT called by default
+    OllamaEmbeddings(model=MODEL_NAME)
+    mock_validate_model.assert_not_called()
@@ -1,7 +1,12 @@
 """Test Ollama Chat API wrapper."""
 
+from typing import Any
+from unittest.mock import patch
+
 from langchain_ollama import OllamaLLM
 
+MODEL_NAME = "llama3.1"
+
 
 def test_initialization() -> None:
     """Test integration initialization."""
@@ -26,3 +31,20 @@ def test_model_params() -> None:
         "ls_model_name": "llama3",
         "ls_max_tokens": 3,
     }
+
+
+@patch("langchain_ollama.llms.validate_model")
+def test_validate_model_on_init(mock_validate_model: Any) -> None:
+    """Test that the model is validated on initialization when requested."""
+    # Test that validate_model is called when validate_model_on_init=True
+    OllamaLLM(model=MODEL_NAME, validate_model_on_init=True)
+    mock_validate_model.assert_called_once()
+    mock_validate_model.reset_mock()
+
+    # Test that validate_model is NOT called when validate_model_on_init=False
+    OllamaLLM(model=MODEL_NAME, validate_model_on_init=False)
+    mock_validate_model.assert_not_called()
+
+    # Test that validate_model is NOT called by default
+    OllamaLLM(model=MODEL_NAME)
+    mock_validate_model.assert_not_called()
libs/partners/ollama/uv.lock
@@ -363,7 +363,7 @@ typing = [
 
 [[package]]
 name = "langchain-ollama"
-version = "0.3.4"
+version = "0.3.3"
 source = { editable = "." }
 dependencies = [
     { name = "langchain-core" },