mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
ollama: add validate_model_on_init, catch more errors (#31784)
* Ensure access to local model during `ChatOllama` instantiation (#27720). This adds a new param `validate_model_on_init` (default: `true`) * Catch a few more errors from the Ollama client to assist users
This commit is contained in:
@@ -4,6 +4,7 @@ import json
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from httpx import Client, Request, Response
|
||||
@@ -12,6 +13,8 @@ from langchain_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_call
|
||||
|
||||
MODEL_NAME = "llama3.1"
|
||||
|
||||
|
||||
class TestChatOllama(ChatModelUnitTests):
|
||||
@property
|
||||
@@ -49,7 +52,7 @@ def test_arbitrary_roles_accepted_in_chatmessages(
|
||||
|
||||
llm = ChatOllama(
|
||||
base_url="http://whocares:11434",
|
||||
model="granite3.2",
|
||||
model=MODEL_NAME,
|
||||
verbose=True,
|
||||
format=None,
|
||||
)
|
||||
@@ -64,3 +67,20 @@ def test_arbitrary_roles_accepted_in_chatmessages(
|
||||
]
|
||||
|
||||
llm.invoke(messages)
|
||||
|
||||
|
||||
@patch("langchain_ollama.chat_models.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
mock_validate_model.reset_mock()
|
||||
|
||||
# Test that validate_model is NOT called when validate_model_on_init=False
|
||||
ChatOllama(model=MODEL_NAME, validate_model_on_init=False)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
# Test that validate_model is NOT called by default
|
||||
ChatOllama(model=MODEL_NAME)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
@@ -1,8 +1,30 @@
|
||||
"""Test embedding model integration."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_ollama.embeddings import OllamaEmbeddings
|
||||
|
||||
MODEL_NAME = "llama3.1"
|
||||
|
||||
|
||||
def test_initialization() -> None:
|
||||
"""Test embedding model initialization."""
|
||||
OllamaEmbeddings(model="llama3", keep_alive=1)
|
||||
|
||||
|
||||
@patch("langchain_ollama.embeddings.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
mock_validate_model.reset_mock()
|
||||
|
||||
# Test that validate_model is NOT called when validate_model_on_init=False
|
||||
OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
# Test that validate_model is NOT called by default
|
||||
OllamaEmbeddings(model=MODEL_NAME)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""Test Ollama Chat API wrapper."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_ollama import OllamaLLM
|
||||
|
||||
MODEL_NAME = "llama3.1"
|
||||
|
||||
|
||||
def test_initialization() -> None:
|
||||
"""Test integration initialization."""
|
||||
@@ -26,3 +31,20 @@ def test_model_params() -> None:
|
||||
"ls_model_name": "llama3",
|
||||
"ls_max_tokens": 3,
|
||||
}
|
||||
|
||||
|
||||
@patch("langchain_ollama.llms.validate_model")
|
||||
def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
||||
"""Test that the model is validated on initialization when requested."""
|
||||
# Test that validate_model is called when validate_model_on_init=True
|
||||
OllamaLLM(model=MODEL_NAME, validate_model_on_init=True)
|
||||
mock_validate_model.assert_called_once()
|
||||
mock_validate_model.reset_mock()
|
||||
|
||||
# Test that validate_model is NOT called when validate_model_on_init=False
|
||||
OllamaLLM(model=MODEL_NAME, validate_model_on_init=False)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
# Test that validate_model is NOT called by default
|
||||
OllamaLLM(model=MODEL_NAME)
|
||||
mock_validate_model.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user