Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-09 06:24:47 +00:00
ollama: add validate_model_on_init, catch more errors (#31784)

* Ensure access to the local model during `ChatOllama` instantiation (#27720). This adds a new param `validate_model_on_init` (default: `False`).
* Catch a few more errors from the Ollama client to assist users.
This commit is contained in:
parent 1a3a8db3c9
commit 572020c4d8
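Opting in is a one-line change at construction time. A minimal usage sketch (the model name `llama3.1` is illustrative, and a locally running Ollama server is assumed):

    from langchain_ollama import ChatOllama

    # With validate_model_on_init=True, construction checks that the model
    # is available locally and raises immediately if it is not.
    llm = ChatOllama(model="llama3.1", validate_model_on_init=True)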
libs/partners/ollama/langchain_ollama/_utils.py (new file, 37 lines)

@@ -0,0 +1,37 @@
+"""Utility functions for validating Ollama models."""
+
+from httpx import ConnectError
+from ollama import Client, ResponseError
+
+
+def validate_model(client: Client, model_name: str) -> None:
+    """Validate that a model exists in the Ollama instance.
+
+    Args:
+        client: The Ollama client.
+        model_name: The name of the model to validate.
+
+    Raises:
+        ValueError: If the model is not found or if there's a connection issue.
+    """
+    try:
+        response = client.list()
+        model_names: list[str] = [model["name"] for model in response["models"]]
+        if not any(
+            model_name == m or m.startswith(f"{model_name}:") for m in model_names
+        ):
+            raise ValueError(
+                f"Model `{model_name}` not found in Ollama. Please pull the "
+                f"model (using `ollama pull {model_name}`) or specify a valid "
+                f"model name. Available local models: {', '.join(model_names)}"
+            )
+    except ConnectError as e:
+        raise ValueError(
+            "Connection to Ollama failed. Please make sure Ollama is running "
+            f"and accessible at {client._client.base_url}. "
+        ) from e
+    except ResponseError as e:
+        raise ValueError(
+            "Received an error from the Ollama API. "
+            "Please check your Ollama server logs."
+        ) from e
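The lookup above accepts either an exact match or any tag-qualified variant of the requested name (`name` matches `name:tag`). A standalone sketch of that matching rule, for illustration only:

    def matches(requested: str, local_models: list[str]) -> bool:
        # Mirrors the check in validate_model above.
        return any(requested == m or m.startswith(f"{requested}:") for m in local_models)

    assert matches("llama3.1", ["llama3.1:latest"])       # bare name matches its tags
    assert matches("llama3.1:8b", ["llama3.1:8b"])        # exact match
    assert not matches("llama3", ["llama3.1:latest"])     # no partial-name match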
libs/partners/ollama/langchain_ollama/chat_models.py

@@ -55,6 +55,8 @@ from pydantic.json_schema import JsonSchemaValue
 from pydantic.v1 import BaseModel as BaseModelV1
 from typing_extensions import Self, is_typeddict
 
+from ._utils import validate_model
+
 DEFAULT_THINK_TOKEN_START: Final[str] = "<think>"
 DEFAULT_THINK_TOKEN_END: Final[str] = "</think>"
 
@@ -350,6 +352,9 @@ class ChatOllama(BaseChatModel):
     model: str
     """Model name to use."""
 
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in Ollama locally on initialization."""
+
     extract_reasoning: Optional[Union[bool, tuple[str, str]]] = False
     """Whether to extract the reasoning tokens in think blocks.
     Extracts `chunk.content` to `chunk.additional_kwargs.reasoning_content`.
@@ -529,6 +534,8 @@ class ChatOllama(BaseChatModel):
 
         self._client = Client(host=self.base_url, **sync_client_kwargs)
         self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
         return self
 
     def _convert_messages_to_ollama_messages(
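Because the check runs inside the model's pydantic `model_validator`, a failed lookup surfaces to callers as a `ValidationError` at construction time (the integration tests below rely on this). A sketch of catching it, with `missing-model` as a placeholder name:

    from pydantic import ValidationError

    from langchain_ollama import ChatOllama

    try:
        ChatOllama(model="missing-model", validate_model_on_init=True)
    except ValidationError as err:
        # pydantic wraps the ValueError raised by validate_model
        print(err)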
@@ -1226,7 +1233,7 @@ class ChatOllama(BaseChatModel):
                     "schema": schema,
                 },
             )
-            output_parser = PydanticOutputParser(pydantic_object=schema)
+            output_parser = PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
         else:
             if is_typeddict(schema):
                 response_format = convert_to_json_schema(schema)
libs/partners/ollama/langchain_ollama/embeddings.py

@@ -12,6 +12,8 @@ from pydantic import (
 )
 from typing_extensions import Self
 
+from ._utils import validate_model
+
 
 class OllamaEmbeddings(BaseModel, Embeddings):
     """Ollama embedding model integration.
@@ -123,6 +125,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     model: str
     """Model name to use."""
 
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in ollama locally on initialization."""
+
     base_url: Optional[str] = None
     """Base url the model is hosted under."""
 
@@ -259,6 +264,8 @@ class OllamaEmbeddings(BaseModel, Embeddings):
 
         self._client = Client(host=self.base_url, **sync_client_kwargs)
         self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
         return self
 
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
libs/partners/ollama/langchain_ollama/llms.py

@@ -18,6 +18,8 @@ from ollama import AsyncClient, Client, Options
 from pydantic import PrivateAttr, model_validator
 from typing_extensions import Self
 
+from ._utils import validate_model
+
 
 class OllamaLLM(BaseLLM):
     """OllamaLLM large language models.
@@ -34,6 +36,9 @@ class OllamaLLM(BaseLLM):
     model: str
     """Model name to use."""
 
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in ollama locally on initialization."""
+
     mirostat: Optional[int] = None
     """Enable Mirostat sampling for controlling perplexity.
     (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
@@ -215,6 +220,8 @@ class OllamaLLM(BaseLLM):
 
         self._client = Client(host=self.base_url, **sync_client_kwargs)
         self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
         return self
 
     async def _acreate_generate_stream(
File diff suppressed because one or more lines are too long
libs/partners/ollama/tests/integration_tests/test_chat_models.py

@@ -1,8 +1,13 @@
 """Test chat model integration using standard integration tests."""
 
+from unittest.mock import MagicMock, patch
+
 import pytest
+from httpx import ConnectError
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import ChatModelIntegrationTests
+from ollama import ResponseError
+from pydantic import ValidationError
 
 from langchain_ollama.chat_models import ChatOllama
 
@@ -47,3 +52,29 @@ class TestChatOllama(ChatModelIntegrationTests):
         )
     async def test_tool_calling_async(self, model: BaseChatModel) -> None:
         await super().test_tool_calling_async(model)
+
+    @patch("langchain_ollama.chat_models.Client.list")
+    def test_init_model_not_found(self, mock_list: MagicMock) -> None:
+        """Test that a ValueError is raised when the model is not found."""
+        mock_list.side_effect = ValueError("Test model not found")
+        with pytest.raises(ValueError) as excinfo:
+            ChatOllama(model="non-existent-model", validate_model_on_init=True)
+        assert "Test model not found" in str(excinfo.value)
+
+    @patch("langchain_ollama.chat_models.Client.list")
+    def test_init_connection_error(self, mock_list: MagicMock) -> None:
+        """Test that a ValidationError is raised on connect failure during init."""
+        mock_list.side_effect = ConnectError("Test connection error")
+
+        with pytest.raises(ValidationError) as excinfo:
+            ChatOllama(model="any-model", validate_model_on_init=True)
+        assert "not found in Ollama" in str(excinfo.value)
+
+    @patch("langchain_ollama.chat_models.Client.list")
+    def test_init_response_error(self, mock_list: MagicMock) -> None:
+        """Test that a ResponseError is raised."""
+        mock_list.side_effect = ResponseError("Test response error")
+
+        with pytest.raises(ValidationError) as excinfo:
+            ChatOllama(model="any-model", validate_model_on_init=True)
+        assert "Received an error from the Ollama API" in str(excinfo.value)
libs/partners/ollama/tests/unit_tests/test_chat_models.py

@@ -4,6 +4,7 @@ import json
 from collections.abc import Generator
 from contextlib import contextmanager
 from typing import Any
+from unittest.mock import patch
 
 import pytest
 from httpx import Client, Request, Response
@@ -12,6 +13,8 @@ from langchain_tests.unit_tests import ChatModelUnitTests
 
 from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_call
 
+MODEL_NAME = "llama3.1"
+
 
 class TestChatOllama(ChatModelUnitTests):
     @property
@@ -49,7 +52,7 @@ def test_arbitrary_roles_accepted_in_chatmessages(
 
     llm = ChatOllama(
         base_url="http://whocares:11434",
-        model="granite3.2",
+        model=MODEL_NAME,
         verbose=True,
         format=None,
     )
@@ -64,3 +67,20 @@ def test_arbitrary_roles_accepted_in_chatmessages(
     ]
 
     llm.invoke(messages)
+
+
+@patch("langchain_ollama.chat_models.validate_model")
+def test_validate_model_on_init(mock_validate_model: Any) -> None:
+    """Test that the model is validated on initialization when requested."""
+    # Test that validate_model is called when validate_model_on_init=True
+    ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
+    mock_validate_model.assert_called_once()
+    mock_validate_model.reset_mock()
+
+    # Test that validate_model is NOT called when validate_model_on_init=False
+    ChatOllama(model=MODEL_NAME, validate_model_on_init=False)
+    mock_validate_model.assert_not_called()
+
+    # Test that validate_model is NOT called by default
+    ChatOllama(model=MODEL_NAME)
+    mock_validate_model.assert_not_called()
libs/partners/ollama/tests/unit_tests/test_embeddings.py

@@ -1,8 +1,30 @@
 """Test embedding model integration."""
 
+from typing import Any
+from unittest.mock import patch
+
 from langchain_ollama.embeddings import OllamaEmbeddings
 
+MODEL_NAME = "llama3.1"
+
 
 def test_initialization() -> None:
     """Test embedding model initialization."""
     OllamaEmbeddings(model="llama3", keep_alive=1)
+
+
+@patch("langchain_ollama.embeddings.validate_model")
+def test_validate_model_on_init(mock_validate_model: Any) -> None:
+    """Test that the model is validated on initialization when requested."""
+    # Test that validate_model is called when validate_model_on_init=True
+    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
+    mock_validate_model.assert_called_once()
+    mock_validate_model.reset_mock()
+
+    # Test that validate_model is NOT called when validate_model_on_init=False
+    OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False)
+    mock_validate_model.assert_not_called()
+
+    # Test that validate_model is NOT called by default
+    OllamaEmbeddings(model=MODEL_NAME)
+    mock_validate_model.assert_not_called()
libs/partners/ollama/tests/unit_tests/test_llms.py

@@ -1,7 +1,12 @@
 """Test Ollama Chat API wrapper."""
 
+from typing import Any
+from unittest.mock import patch
+
 from langchain_ollama import OllamaLLM
 
+MODEL_NAME = "llama3.1"
+
 
 def test_initialization() -> None:
     """Test integration initialization."""
@@ -26,3 +31,20 @@ def test_model_params() -> None:
         "ls_model_name": "llama3",
         "ls_max_tokens": 3,
     }
+
+
+@patch("langchain_ollama.llms.validate_model")
+def test_validate_model_on_init(mock_validate_model: Any) -> None:
+    """Test that the model is validated on initialization when requested."""
+    # Test that validate_model is called when validate_model_on_init=True
+    OllamaLLM(model=MODEL_NAME, validate_model_on_init=True)
+    mock_validate_model.assert_called_once()
+    mock_validate_model.reset_mock()
+
+    # Test that validate_model is NOT called when validate_model_on_init=False
+    OllamaLLM(model=MODEL_NAME, validate_model_on_init=False)
+    mock_validate_model.assert_not_called()
+
+    # Test that validate_model is NOT called by default
+    OllamaLLM(model=MODEL_NAME)
+    mock_validate_model.assert_not_called()
libs/partners/ollama/uv.lock

@@ -363,7 +363,7 @@ typing = [
 
 [[package]]
 name = "langchain-ollama"
-version = "0.3.4"
+version = "0.3.3"
 source = { editable = "." }
 dependencies = [
     { name = "langchain-core" },