diff --git a/libs/partners/ollama/langchain_ollama/__init__.py b/libs/partners/ollama/langchain_ollama/__init__.py index 3789c6de8a4..891071b3cf4 100644 --- a/libs/partners/ollama/langchain_ollama/__init__.py +++ b/libs/partners/ollama/langchain_ollama/__init__.py @@ -14,14 +14,20 @@ service. """ from importlib import metadata +from importlib.metadata import PackageNotFoundError from langchain_ollama.chat_models import ChatOllama from langchain_ollama.embeddings import OllamaEmbeddings from langchain_ollama.llms import OllamaLLM + +def _raise_package_not_found_error() -> None: + raise PackageNotFoundError + + try: if __package__ is None: - raise metadata.PackageNotFoundError + _raise_package_not_found_error() __version__ = metadata.version(__package__) except metadata.PackageNotFoundError: # Case where package metadata is not available. diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py index aebc70b7f3b..f6af547a86d 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models.py +++ b/libs/partners/ollama/langchain_ollama/chat_models.py @@ -132,7 +132,6 @@ def _parse_arguments_from_tool_call( Should be removed/changed if fixed upstream. 
See https://github.com/ollama/ollama/issues/6155 - """ if "function" not in raw_tool_call: return None @@ -834,7 +833,7 @@ class ChatOllama(BaseChatModel): messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, - verbose: bool = False, # noqa: FBT001, FBT002 + verbose: bool = False, # noqa: FBT002 **kwargs: Any, ) -> ChatGenerationChunk: final_chunk = None @@ -860,7 +859,7 @@ class ChatOllama(BaseChatModel): messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - verbose: bool = False, # noqa: FBT001, FBT002 + verbose: bool = False, # noqa: FBT002 **kwargs: Any, ) -> ChatGenerationChunk: final_chunk = None diff --git a/libs/partners/ollama/langchain_ollama/llms.py b/libs/partners/ollama/langchain_ollama/llms.py index 06f91502bab..4d9fd4bc63c 100644 --- a/libs/partners/ollama/langchain_ollama/llms.py +++ b/libs/partners/ollama/langchain_ollama/llms.py @@ -368,7 +368,7 @@ class OllamaLLM(BaseLLM): prompt: str, stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - verbose: bool = False, # noqa: FBT001, FBT002 + verbose: bool = False, # noqa: FBT002 **kwargs: Any, ) -> GenerationChunk: final_chunk = None @@ -410,7 +410,7 @@ class OllamaLLM(BaseLLM): prompt: str, stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, - verbose: bool = False, # noqa: FBT001, FBT002 + verbose: bool = False, # noqa: FBT002 **kwargs: Any, ) -> GenerationChunk: final_chunk = None diff --git a/libs/partners/ollama/pyproject.toml b/libs/partners/ollama/pyproject.toml index 55f9f662d51..28832012620 100644 --- a/libs/partners/ollama/pyproject.toml +++ b/libs/partners/ollama/pyproject.toml @@ -60,15 +60,17 @@ ignore = [ "SLF001", # Private member access "UP007", # pyupgrade: non-pep604-annotation-union "UP045", # pyupgrade: non-pep604-annotation-optional - "PLR0912", - "C901", - 
"PLR0915", + "FIX002", # TODOs + "TD002", # TODO authors + "TC002", # Incorrect type-checking block + "TC003", # Incorrect type-checking block + "PLR0912", # Too many branches + "PLR0915", # Too many statements + "C901", # Function too complex + "FBT001", # Boolean function param - # TODO: + # TODO "ANN401", - "TC002", - "TC003", - "TRY301", ] unfixable = ["B028"] # People should intentionally tune the stacklevel @@ -91,16 +93,11 @@ asyncio_mode = "auto" [tool.ruff.lint.extend-per-file-ignores] "tests/**/*.py" = [ - "S101", # Tests need assertions - "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes - "PLR2004", - - # TODO - "ANN401", - "ARG001", - "PT011", - "FIX", - "TD", + "S101", # Tests need assertions + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "ARG001", # Unused function arguments in tests (e.g. kwargs) + "PLR2004", # Magic value in comparisons + "PT011", # `pytest.raises()` is too broad ] "scripts/*.py" = [ "INP001", # Not a package diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py index 650f77f184d..f6a08d5ba64 100644 --- a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py +++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py @@ -3,9 +3,16 @@ from __future__ import annotations from typing import Annotated, Optional +from unittest.mock import MagicMock, patch import pytest -from pydantic import BaseModel, Field +from httpx import ConnectError +from langchain_core.messages.ai import AIMessageChunk +from langchain_core.messages.human import HumanMessage +from langchain_core.messages.tool import ToolCallChunk +from langchain_core.tools import tool +from ollama import ResponseError +from pydantic import BaseModel, Field, ValidationError from typing_extensions import TypedDict from langchain_ollama import 
ChatOllama @@ -13,6 +20,43 @@ from langchain_ollama import ChatOllama DEFAULT_MODEL_NAME = "llama3.1" +@tool +def get_current_weather(location: str) -> dict: + """Gets the current weather in a given location.""" + if "boston" in location.lower(): + return {"temperature": "15°F", "conditions": "snow"} + return {"temperature": "unknown", "conditions": "unknown"} + + +@patch("langchain_ollama.chat_models.Client.list") +def test_init_model_not_found(mock_list: MagicMock) -> None: + """Test that a ValueError is raised when the model is not found.""" + mock_list.side_effect = ValueError("Test model not found") + with pytest.raises(ValueError) as excinfo: + ChatOllama(model="non-existent-model", validate_model_on_init=True) + assert "Test model not found" in str(excinfo.value) + + +@patch("langchain_ollama.chat_models.Client.list") +def test_init_connection_error(mock_list: MagicMock) -> None: + """Test that a ValidationError is raised on connect failure during init.""" + mock_list.side_effect = ConnectError("Test connection error") + + with pytest.raises(ValidationError) as excinfo: + ChatOllama(model="any-model", validate_model_on_init=True) + assert "Failed to connect to Ollama" in str(excinfo.value) + + +@patch("langchain_ollama.chat_models.Client.list") +def test_init_response_error(mock_list: MagicMock) -> None: + """Test that a ResponseError is raised.""" + mock_list.side_effect = ResponseError("Test response error") + + with pytest.raises(ValidationError) as excinfo: + ChatOllama(model="any-model", validate_model_on_init=True) + assert "Received an error from the Ollama API" in str(excinfo.value) + + @pytest.mark.parametrize(("method"), [("function_calling"), ("json_schema")]) def test_structured_output(method: str) -> None: """Test to verify structured output via tool calling and `format` parameter.""" @@ -98,3 +142,97 @@ def test_structured_output_deeply_nested(model: str) -> None: for chunk in chat.stream(text): assert isinstance(chunk, Data) + + 
+@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)]) +def test_tool_streaming(model: str) -> None: + """Test that the model can stream tool calls.""" + llm = ChatOllama(model=model) + chat_model_with_tools = llm.bind_tools([get_current_weather]) + + prompt = [HumanMessage("What is the weather today in Boston?")] + + # Flags and collectors for validation + tool_chunk_found = False + final_tool_calls = [] + collected_tool_chunks: list[ToolCallChunk] = [] + + # Stream the response and inspect the chunks + for chunk in chat_model_with_tools.stream(prompt): + assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type" + + if chunk.tool_call_chunks: + tool_chunk_found = True + collected_tool_chunks.extend(chunk.tool_call_chunks) + + if chunk.tool_calls: + final_tool_calls.extend(chunk.tool_calls) + + assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks." + assert len(final_tool_calls) == 1, ( + f"Expected 1 final tool call, but got {len(final_tool_calls)}" + ) + + final_tool_call = final_tool_calls[0] + assert final_tool_call["name"] == "get_current_weather" + assert final_tool_call["args"] == {"location": "Boston"} + + assert len(collected_tool_chunks) > 0 + assert collected_tool_chunks[0]["name"] == "get_current_weather" + + # The ID should be consistent across chunks that have it + tool_call_id = collected_tool_chunks[0].get("id") + assert tool_call_id is not None + assert all( + chunk.get("id") == tool_call_id + for chunk in collected_tool_chunks + if chunk.get("id") + ) + assert final_tool_call["id"] == tool_call_id + + +@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)]) +async def test_tool_astreaming(model: str) -> None: + """Test that the model can stream tool calls.""" + llm = ChatOllama(model=model) + chat_model_with_tools = llm.bind_tools([get_current_weather]) + + prompt = [HumanMessage("What is the weather today in Boston?")] + + # Flags and collectors for validation + tool_chunk_found = False + 
final_tool_calls = [] + collected_tool_chunks: list[ToolCallChunk] = [] + + # Stream the response and inspect the chunks + async for chunk in chat_model_with_tools.astream(prompt): + assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type" + + if chunk.tool_call_chunks: + tool_chunk_found = True + collected_tool_chunks.extend(chunk.tool_call_chunks) + + if chunk.tool_calls: + final_tool_calls.extend(chunk.tool_calls) + + assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks." + assert len(final_tool_calls) == 1, ( + f"Expected 1 final tool call, but got {len(final_tool_calls)}" + ) + + final_tool_call = final_tool_calls[0] + assert final_tool_call["name"] == "get_current_weather" + assert final_tool_call["args"] == {"location": "Boston"} + + assert len(collected_tool_chunks) > 0 + assert collected_tool_chunks[0]["name"] == "get_current_weather" + + # The ID should be consistent across chunks that have it + tool_call_id = collected_tool_chunks[0].get("id") + assert tool_call_id is not None + assert all( + chunk.get("id") == tool_call_id + for chunk in collected_tool_chunks + if chunk.get("id") + ) + assert final_tool_call["id"] == tool_call_id diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py index 86ba74c0069..9d31e76bea8 100644 --- a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py +++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_reasoning.py @@ -8,9 +8,10 @@ from langchain_ollama import ChatOllama SAMPLE = "What is 3^3?" 
-@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_stream_no_reasoning(model: str) -> None: - """Test streaming with `reasoning=False`""" +@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"]) +@pytest.mark.parametrize("use_async", [False, True]) +async def test_stream_no_reasoning(model: str, use_async: bool) -> None: + """Test streaming with ``reasoning=False``.""" llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False) messages = [ { @@ -19,12 +20,20 @@ def test_stream_no_reasoning(model: str) -> None: } ] result = None - for chunk in llm.stream(messages): - assert isinstance(chunk, BaseMessageChunk) - if result is None: - result = chunk - continue - result += chunk + if use_async: + async for chunk in llm.astream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + else: + for chunk in llm.stream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk assert isinstance(result, AIMessageChunk) assert result.content assert "" not in result.content @@ -32,33 +41,10 @@ def test_stream_no_reasoning(model: str) -> None: assert "reasoning_content" not in result.additional_kwargs -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -async def test_astream_no_reasoning(model: str) -> None: - """Test async streaming with `reasoning=False`""" - llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False) - messages = [ - { - "role": "user", - "content": SAMPLE, - } - ] - result = None - async for chunk in llm.astream(messages): - assert isinstance(chunk, BaseMessageChunk) - if result is None: - result = chunk - continue - result += chunk - assert isinstance(result, AIMessageChunk) - assert result.content - assert "" not in result.content - assert "" not in result.content - assert "reasoning_content" not in result.additional_kwargs - - -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def 
test_stream_reasoning_none(model: str) -> None: - """Test streaming with `reasoning=None`""" +@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"]) +@pytest.mark.parametrize("use_async", [False, True]) +async def test_stream_reasoning_none(model: str, use_async: bool) -> None: + """Test streaming with ``reasoning=None``.""" llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None) messages = [ { @@ -67,12 +53,20 @@ def test_stream_reasoning_none(model: str) -> None: } ] result = None - for chunk in llm.stream(messages): - assert isinstance(chunk, BaseMessageChunk) - if result is None: - result = chunk - continue - result += chunk + if use_async: + async for chunk in llm.astream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + else: + for chunk in llm.stream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk assert isinstance(result, AIMessageChunk) assert result.content assert "" in result.content @@ -82,35 +76,10 @@ def test_stream_reasoning_none(model: str) -> None: assert "" not in result.additional_kwargs.get("reasoning_content", "") -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -async def test_astream_reasoning_none(model: str) -> None: - """Test async streaming with `reasoning=None`""" - llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None) - messages = [ - { - "role": "user", - "content": SAMPLE, - } - ] - result = None - async for chunk in llm.astream(messages): - assert isinstance(chunk, BaseMessageChunk) - if result is None: - result = chunk - continue - result += chunk - assert isinstance(result, AIMessageChunk) - assert result.content - assert "" in result.content - assert "" in result.content - assert "reasoning_content" not in result.additional_kwargs - assert "" not in result.additional_kwargs.get("reasoning_content", "") - assert "" not in 
result.additional_kwargs.get("reasoning_content", "") - - -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_reasoning_stream(model: str) -> None: - """Test streaming with `reasoning=True`""" +@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"]) +@pytest.mark.parametrize("use_async", [False, True]) +async def test_reasoning_stream(model: str, use_async: bool) -> None: + """Test streaming with ``reasoning=True``.""" llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True) messages = [ { @@ -119,12 +88,20 @@ def test_reasoning_stream(model: str) -> None: } ] result = None - for chunk in llm.stream(messages): - assert isinstance(chunk, BaseMessageChunk) - if result is None: - result = chunk - continue - result += chunk + if use_async: + async for chunk in llm.astream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk + else: + for chunk in llm.stream(messages): + assert isinstance(chunk, BaseMessageChunk) + if result is None: + result = chunk + continue + result += chunk assert isinstance(result, AIMessageChunk) assert result.content assert "reasoning_content" in result.additional_kwargs @@ -135,63 +112,32 @@ def test_reasoning_stream(model: str) -> None: assert "" not in result.additional_kwargs["reasoning_content"] -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -async def test_reasoning_astream(model: str) -> None: - """Test async streaming with `reasoning=True`""" - llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True) - messages = [ - { - "role": "user", - "content": SAMPLE, - } - ] - result = None - async for chunk in llm.astream(messages): - assert isinstance(chunk, BaseMessageChunk) - if result is None: - result = chunk - continue - result += chunk - assert isinstance(result, AIMessageChunk) - assert result.content - assert "reasoning_content" in result.additional_kwargs - assert len(result.additional_kwargs["reasoning_content"]) > 0 - assert 
"" not in result.content - assert "" not in result.content - assert "" not in result.additional_kwargs["reasoning_content"] - assert "" not in result.additional_kwargs["reasoning_content"] - - -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_invoke_no_reasoning(model: str) -> None: - """Test using invoke with `reasoning=False`""" +@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"]) +@pytest.mark.parametrize("use_async", [False, True]) +async def test_invoke_no_reasoning(model: str, use_async: bool) -> None: + """Test invoke with ``reasoning=False``.""" llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False) message = HumanMessage(content=SAMPLE) - result = llm.invoke([message]) + if use_async: + result = await llm.ainvoke([message]) + else: + result = llm.invoke([message]) assert result.content assert "reasoning_content" not in result.additional_kwargs assert "" not in result.content assert "" not in result.content -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -async def test_ainvoke_no_reasoning(model: str) -> None: - """Test using async invoke with `reasoning=False`""" - llm = ChatOllama(model=model, num_ctx=2**12, reasoning=False) - message = HumanMessage(content=SAMPLE) - result = await llm.ainvoke([message]) - assert result.content - assert "reasoning_content" not in result.additional_kwargs - assert "" not in result.content - assert "" not in result.content - - -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_invoke_reasoning_none(model: str) -> None: - """Test using invoke with `reasoning=None`""" +@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"]) +@pytest.mark.parametrize("use_async", [False, True]) +async def test_invoke_reasoning_none(model: str, use_async: bool) -> None: + """Test invoke with ``reasoning=None``.""" llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None) message = HumanMessage(content=SAMPLE) - result = llm.invoke([message]) + if use_async: + result = await 
llm.ainvoke([message]) + else: + result = llm.invoke([message]) assert result.content assert "reasoning_content" not in result.additional_kwargs assert "" in result.content @@ -200,26 +146,16 @@ def test_invoke_reasoning_none(model: str) -> None: assert "" not in result.additional_kwargs.get("reasoning_content", "") -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -async def test_ainvoke_reasoning_none(model: str) -> None: - """Test using async invoke with `reasoning=None`""" - llm = ChatOllama(model=model, num_ctx=2**12, reasoning=None) - message = HumanMessage(content=SAMPLE) - result = await llm.ainvoke([message]) - assert result.content - assert "reasoning_content" not in result.additional_kwargs - assert "" in result.content - assert "" in result.content - assert "" not in result.additional_kwargs.get("reasoning_content", "") - assert "" not in result.additional_kwargs.get("reasoning_content", "") - - -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -def test_reasoning_invoke(model: str) -> None: - """Test invoke with `reasoning=True`""" +@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"]) +@pytest.mark.parametrize("use_async", [False, True]) +async def test_reasoning_invoke(model: str, use_async: bool) -> None: + """Test invoke with ``reasoning=True``.""" llm = ChatOllama(model=model, num_ctx=2**12, reasoning=True) message = HumanMessage(content=SAMPLE) - result = llm.invoke([message]) + if use_async: + result = await llm.ainvoke([message]) + else: + result = llm.invoke([message]) assert result.content assert "reasoning_content" in result.additional_kwargs assert len(result.additional_kwargs["reasoning_content"]) > 0 @@ -229,22 +165,7 @@ def test_reasoning_invoke(model: str) -> None: assert "" not in result.additional_kwargs["reasoning_content"] -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) -async def test_reasoning_ainvoke(model: str) -> None: - """Test invoke with `reasoning=True`""" - llm = ChatOllama(model=model, 
num_ctx=2**12, reasoning=True) - message = HumanMessage(content=SAMPLE) - result = await llm.ainvoke([message]) - assert result.content - assert "reasoning_content" in result.additional_kwargs - assert len(result.additional_kwargs["reasoning_content"]) > 0 - assert "" not in result.content - assert "" not in result.content - assert "" not in result.additional_kwargs["reasoning_content"] - assert "" not in result.additional_kwargs["reasoning_content"] - - -@pytest.mark.parametrize(("model"), [("deepseek-r1:1.5b")]) +@pytest.mark.parametrize("model", ["deepseek-r1:1.5b"]) def test_think_tag_stripping_necessity(model: str) -> None: """Test that demonstrates why ``_strip_think_tags`` is necessary. diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py index 95742e924f2..984dacb96ca 100644 --- a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py +++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py @@ -1,29 +1,14 @@ """Test chat model integration using standard integration tests.""" -from unittest.mock import MagicMock, patch - import pytest -from httpx import ConnectError from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessageChunk, HumanMessage, ToolCallChunk -from langchain_core.tools import tool from langchain_tests.integration_tests import ChatModelIntegrationTests -from ollama import ResponseError -from pydantic import ValidationError from langchain_ollama.chat_models import ChatOllama DEFAULT_MODEL_NAME = "llama3.1" -@tool -def get_current_weather(location: str) -> dict: - """Gets the current weather in a given location.""" - if "boston" in location.lower(): - return {"temperature": "15°F", "conditions": "snow"} - return {"temperature": "unknown", "conditions": "unknown"} - - class 
TestChatOllama(ChatModelIntegrationTests): @property def chat_model_class(self) -> type[ChatOllama]: @@ -47,94 +32,6 @@ class TestChatOllama(ChatModelIntegrationTests): def supports_image_inputs(self) -> bool: return True - def test_tool_streaming(self, model: BaseChatModel) -> None: - """Test that the model can stream tool calls.""" - chat_model_with_tools = model.bind_tools([get_current_weather]) - - prompt = [HumanMessage("What is the weather today in Boston?")] - - # Flags and collectors for validation - tool_chunk_found = False - final_tool_calls = [] - collected_tool_chunks: list[ToolCallChunk] = [] - - # Stream the response and inspect the chunks - for chunk in chat_model_with_tools.stream(prompt): - assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type" - - if chunk.tool_call_chunks: - tool_chunk_found = True - collected_tool_chunks.extend(chunk.tool_call_chunks) - - if chunk.tool_calls: - final_tool_calls.extend(chunk.tool_calls) - - assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks." 
- assert len(final_tool_calls) == 1, ( - f"Expected 1 final tool call, but got {len(final_tool_calls)}" - ) - - final_tool_call = final_tool_calls[0] - assert final_tool_call["name"] == "get_current_weather" - assert final_tool_call["args"] == {"location": "Boston"} - - assert len(collected_tool_chunks) > 0 - assert collected_tool_chunks[0]["name"] == "get_current_weather" - - # The ID should be consistent across chunks that have it - tool_call_id = collected_tool_chunks[0].get("id") - assert tool_call_id is not None - assert all( - chunk.get("id") == tool_call_id - for chunk in collected_tool_chunks - if chunk.get("id") - ) - assert final_tool_call["id"] == tool_call_id - - async def test_tool_astreaming(self, model: BaseChatModel) -> None: - """Test that the model can stream tool calls.""" - chat_model_with_tools = model.bind_tools([get_current_weather]) - - prompt = [HumanMessage("What is the weather today in Boston?")] - - # Flags and collectors for validation - tool_chunk_found = False - final_tool_calls = [] - collected_tool_chunks: list[ToolCallChunk] = [] - - # Stream the response and inspect the chunks - async for chunk in chat_model_with_tools.astream(prompt): - assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type" - - if chunk.tool_call_chunks: - tool_chunk_found = True - collected_tool_chunks.extend(chunk.tool_call_chunks) - - if chunk.tool_calls: - final_tool_calls.extend(chunk.tool_calls) - - assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks." 
- assert len(final_tool_calls) == 1, ( - f"Expected 1 final tool call, but got {len(final_tool_calls)}" - ) - - final_tool_call = final_tool_calls[0] - assert final_tool_call["name"] == "get_current_weather" - assert final_tool_call["args"] == {"location": "Boston"} - - assert len(collected_tool_chunks) > 0 - assert collected_tool_chunks[0]["name"] == "get_current_weather" - - # The ID should be consistent across chunks that have it - tool_call_id = collected_tool_chunks[0].get("id") - assert tool_call_id is not None - assert all( - chunk.get("id") == tool_call_id - for chunk in collected_tool_chunks - if chunk.get("id") - ) - assert final_tool_call["id"] == tool_call_id - @pytest.mark.xfail( reason=( "Will sometime encounter AssertionErrors where tool responses are " @@ -153,28 +50,13 @@ class TestChatOllama(ChatModelIntegrationTests): async def test_tool_calling_async(self, model: BaseChatModel) -> None: await super().test_tool_calling_async(model) - @patch("langchain_ollama.chat_models.Client.list") - def test_init_model_not_found(self, mock_list: MagicMock) -> None: - """Test that a ValueError is raised when the model is not found.""" - mock_list.side_effect = ValueError("Test model not found") - with pytest.raises(ValueError) as excinfo: - ChatOllama(model="non-existent-model", validate_model_on_init=True) - assert "Test model not found" in str(excinfo.value) - - @patch("langchain_ollama.chat_models.Client.list") - def test_init_connection_error(self, mock_list: MagicMock) -> None: - """Test that a ValidationError is raised on connect failure during init.""" - mock_list.side_effect = ConnectError("Test connection error") - - with pytest.raises(ValidationError) as excinfo: - ChatOllama(model="any-model", validate_model_on_init=True) - assert "Failed to connect to Ollama" in str(excinfo.value) - - @patch("langchain_ollama.chat_models.Client.list") - def test_init_response_error(self, mock_list: MagicMock) -> None: - """Test that a ResponseError is raised.""" - 
mock_list.side_effect = ResponseError("Test response error") - - with pytest.raises(ValidationError) as excinfo: - ChatOllama(model="any-model", validate_model_on_init=True) - assert "Received an error from the Ollama API" in str(excinfo.value) + @pytest.mark.xfail( + reason=( + "Will sometimes fail due to Ollama's inconsistent tool call argument " + "structure (see https://github.com/ollama/ollama/issues/6155). " + "Args may contain unexpected keys like 'conversations' instead of " + "empty dict." + ) + ) + def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: + super().test_tool_calling_with_no_arguments(model) diff --git a/libs/partners/ollama/tests/integration_tests/test_llms.py b/libs/partners/ollama/tests/integration_tests/test_llms.py index 9989309d9ee..f307eb15076 100644 --- a/libs/partners/ollama/tests/integration_tests/test_llms.py +++ b/libs/partners/ollama/tests/integration_tests/test_llms.py @@ -13,11 +13,74 @@ REASONING_MODEL_NAME = os.environ.get("OLLAMA_REASONING_TEST_MODEL", "deepseek-r SAMPLE = "What is 3^3?" 
+def test_invoke() -> None: + """Test sync invoke returning a string.""" + llm = OllamaLLM(model=MODEL_NAME) + result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"])) + assert isinstance(result, str) + + +async def test_ainvoke() -> None: + """Test async invoke returning a string.""" + llm = OllamaLLM(model=MODEL_NAME) + + result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"])) + assert isinstance(result, str) + + +def test_batch() -> None: + """Test batch sync token generation from `OllamaLLM`.""" + llm = OllamaLLM(model=MODEL_NAME) + + result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) + for token in result: + assert isinstance(token, str) + + +async def test_abatch() -> None: + """Test batch async token generation from `OllamaLLM`.""" + llm = OllamaLLM(model=MODEL_NAME) + + result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) + for token in result: + assert isinstance(token, str) + + +def test_batch_tags() -> None: + """Test batch sync token generation with tags.""" + llm = OllamaLLM(model=MODEL_NAME) + + result = llm.batch( + ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} + ) + for token in result: + assert isinstance(token, str) + + +async def test_abatch_tags() -> None: + """Test batch async token generation with tags.""" + llm = OllamaLLM(model=MODEL_NAME) + + result = await llm.abatch( + ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} + ) + for token in result: + assert isinstance(token, str) + + def test_stream_text_tokens() -> None: """Test streaming raw string tokens from `OllamaLLM`.""" llm = OllamaLLM(model=MODEL_NAME) - for token in llm.stream("I'm Pickle Rick"): + for token in llm.stream("Hi."): + assert isinstance(token, str) + + +async def test_astream_text_tokens() -> None: + """Test async streaming raw string tokens from `OllamaLLM`.""" + llm = OllamaLLM(model=MODEL_NAME) + + async for token in llm.astream("Hi."): assert 
isinstance(token, str) @@ -28,7 +91,6 @@ def test__stream_no_reasoning(model: str) -> None: result_chunk = None for chunk in llm._stream(SAMPLE): - # Should be a GenerationChunk assert isinstance(chunk, GenerationChunk) if result_chunk is None: result_chunk = chunk @@ -38,8 +100,28 @@ def test__stream_no_reasoning(model: str) -> None: # The final result must be a GenerationChunk with visible content assert isinstance(result_chunk, GenerationChunk) assert result_chunk.text - # No separate reasoning_content - assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator] + assert result_chunk.generation_info + assert not result_chunk.generation_info.get("reasoning_content") + + +@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)]) +async def test__astream_no_reasoning(model: str) -> None: + """Test low-level async chunk streaming with `reasoning=False`.""" + llm = OllamaLLM(model=model, num_ctx=2**12) + + result_chunk = None + async for chunk in llm._astream(SAMPLE): + assert isinstance(chunk, GenerationChunk) + if result_chunk is None: + result_chunk = chunk + else: + result_chunk += chunk + + # The final result must be a GenerationChunk with visible content + assert isinstance(result_chunk, GenerationChunk) + assert result_chunk.text + assert result_chunk.generation_info + assert not result_chunk.generation_info.get("reasoning_content") @pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)]) @@ -57,40 +139,17 @@ def test__stream_with_reasoning(model: str) -> None: assert isinstance(result_chunk, GenerationChunk) assert result_chunk.text + # Should have extracted reasoning into generation_info - assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator] - assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index] + assert result_chunk.generation_info + reasoning_content = result_chunk.generation_info.get("reasoning_content") + assert reasoning_content + assert 
len(reasoning_content) > 0 # And neither the visible nor the hidden portion contains tags assert "<think>" not in result_chunk.text assert "</think>" not in result_chunk.text - assert "<think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index] - assert "</think>" not in result_chunk.generation_info["reasoning_content"] # type: ignore[index] - - -async def test_astream_text_tokens() -> None: - """Test async streaming raw string tokens from `OllamaLLM`.""" - llm = OllamaLLM(model=MODEL_NAME) - - async for token in llm.astream("I'm Pickle Rick"): - assert isinstance(token, str) - - -@pytest.mark.parametrize(("model"), [(REASONING_MODEL_NAME)]) -async def test__astream_no_reasoning(model: str) -> None: - """Test low-level async chunk streaming with `reasoning=False`.""" - llm = OllamaLLM(model=model, num_ctx=2**12) - - result_chunk = None - async for chunk in llm._astream(SAMPLE): - assert isinstance(chunk, GenerationChunk) - if result_chunk is None: - result_chunk = chunk - else: - result_chunk += chunk - - assert isinstance(result_chunk, GenerationChunk) - assert result_chunk.text - assert "reasoning_content" not in result_chunk.generation_info # type: ignore[operator] + assert "<think>" not in reasoning_content + assert "</think>" not in reasoning_content @@ -108,49 +167,14 @@ async def test__astream_with_reasoning(model: str) -> None: assert isinstance(result_chunk, GenerationChunk) assert result_chunk.text - assert "reasoning_content" in result_chunk.generation_info # type: ignore[operator] - assert len(result_chunk.generation_info["reasoning_content"]) > 0 # type: ignore[index] - -async def test_abatch() -> None: - """Test batch sync token generation from `OllamaLLM`.""" - llm = OllamaLLM(model=MODEL_NAME) - - result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token, str) - - -async def test_abatch_tags() -> None: - """Test batch sync token generation with 
tags.""" - llm = OllamaLLM(model=MODEL_NAME) - - result = await llm.abatch( - ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} - ) - for token in result: - assert isinstance(token, str) - - -def test_batch() -> None: - """Test batch token generation from `OllamaLLM`.""" - llm = OllamaLLM(model=MODEL_NAME) - - result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token, str) - - -async def test_ainvoke() -> None: - """Test async invoke returning a string.""" - llm = OllamaLLM(model=MODEL_NAME) - - result = await llm.ainvoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"])) - assert isinstance(result, str) - - -def test_invoke() -> None: - """Test sync invoke returning a string.""" - llm = OllamaLLM(model=MODEL_NAME) - result = llm.invoke("I'm Pickle Rick", config=RunnableConfig(tags=["foo"])) - assert isinstance(result, str) + # Should have extracted reasoning into generation_info + assert result_chunk.generation_info + reasoning_content = result_chunk.generation_info.get("reasoning_content") + assert reasoning_content + assert len(reasoning_content) > 0 + # And neither the visible nor the hidden portion contains tags + assert "<think>" not in result_chunk.text + assert "</think>" not in result_chunk.text + assert "<think>" not in reasoning_content + assert "</think>" not in reasoning_content diff --git a/libs/partners/ollama/tests/unit_tests/test_chat_models.py b/libs/partners/ollama/tests/unit_tests/test_chat_models.py index 640d7e66c8d..d5708a4937a 100644 --- a/libs/partners/ollama/tests/unit_tests/test_chat_models.py +++ b/libs/partners/ollama/tests/unit_tests/test_chat_models.py @@ -22,6 +22,22 @@ from langchain_ollama.chat_models import ( MODEL_NAME = "llama3.1" +@contextmanager +def _mock_httpx_client_stream( + *args: Any, **kwargs: Any +) -> Generator[Response, Any, Any]: + yield Response( + status_code=200, + content='{"message": {"role": "assistant", "content": "The meaning ..."}}', + 
request=Request(method="POST", url="http://whocares:11434"), + ) + + +dummy_raw_tool_call = { + "function": {"name": "test_func", "arguments": ""}, +} + + class TestChatOllama(ChatModelUnitTests): @property def chat_model_class(self) -> type[ChatOllama]: @@ -35,19 +51,24 @@ class TestChatOllama(ChatModelUnitTests): def test__parse_arguments_from_tool_call() -> None: """Test that string arguments are preserved as strings in tool call parsing. - This test verifies the fix for PR #30154 which addressed an issue where - string-typed tool arguments (like IDs or long strings) were being incorrectly + PR #30154 + String-typed tool arguments (like IDs or long strings) were being incorrectly processed. The parser should preserve string values as strings rather than attempting to parse them as JSON when they're already valid string arguments. - The test uses a long string ID to ensure string arguments maintain their - original type after parsing, which is critical for tools expecting string inputs. + Use a long string ID to ensure string arguments maintain their original type after + parsing, which is critical for tools expecting string inputs. 
""" - raw_response = '{"model":"sample-model","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"get_profile_details","arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}' # noqa: E501 + raw_response = ( + '{"model":"sample-model","message":{"role":"assistant","content":"",' + '"tool_calls":[{"function":{"name":"get_profile_details",' + '"arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}' + ) raw_tool_calls = json.loads(raw_response)["message"]["tool_calls"] response = _parse_arguments_from_tool_call(raw_tool_calls[0]) assert response is not None assert isinstance(response["arg_1"], str) + assert response["arg_1"] == "12345678901234567890123456" def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None: @@ -57,7 +78,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None: that just echoes the function name. This should be filtered out for no-argument tools to return an empty dictionary. 
""" - # Test case where arguments contain functionName metadata raw_tool_call_with_metadata = { "function": { "name": "magic_function_no_args", @@ -67,7 +87,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None: response = _parse_arguments_from_tool_call(raw_tool_call_with_metadata) assert response == {} - # Test case where arguments contain both real args and metadata + # Arguments contain both real args and metadata raw_tool_call_mixed = { "function": { "name": "some_function", @@ -77,7 +97,7 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None: response_mixed = _parse_arguments_from_tool_call(raw_tool_call_mixed) assert response_mixed == {"real_arg": "value"} - # Test case where functionName has different value (should be preserved) + # functionName has different value (should be preserved) raw_tool_call_different = { "function": {"name": "function_a", "arguments": {"functionName": "function_b"}} } @@ -85,17 +105,6 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None: assert response_different == {"functionName": "function_b"} -@contextmanager -def _mock_httpx_client_stream( - *args: Any, **kwargs: Any -) -> Generator[Response, Any, Any]: - yield Response( - status_code=200, - content='{"message": {"role": "assistant", "content": "The meaning ..."}}', - request=Request(method="POST", url="http://whocares:11434"), - ) - - def test_arbitrary_roles_accepted_in_chatmessages( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -120,26 +129,16 @@ def test_arbitrary_roles_accepted_in_chatmessages( @patch("langchain_ollama.chat_models.validate_model") def test_validate_model_on_init(mock_validate_model: Any) -> None: """Test that the model is validated on initialization when requested.""" - # Test that validate_model is called when validate_model_on_init=True ChatOllama(model=MODEL_NAME, validate_model_on_init=True) mock_validate_model.assert_called_once() mock_validate_model.reset_mock() 
- # Test that validate_model is NOT called when validate_model_on_init=False ChatOllama(model=MODEL_NAME, validate_model_on_init=False) mock_validate_model.assert_not_called() - - # Test that validate_model is NOT called by default ChatOllama(model=MODEL_NAME) mock_validate_model.assert_not_called() -# Define a dummy raw_tool_call for the function signature -dummy_raw_tool_call = { - "function": {"name": "test_func", "arguments": ""}, -} - - @pytest.mark.parametrize( ("input_string", "expected_output"), [ @@ -164,7 +163,7 @@ def test_parse_json_string_success_cases( def test_parse_json_string_failure_case_raises_exception() -> None: """Tests that `_parse_json_string` raises an exception for malformed strings.""" - malformed_string = "{'key': 'value',,}" + malformed_string = "{'key': 'value',,}" # Double comma is invalid raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}} with pytest.raises(OutputParserException): _parse_json_string( @@ -181,7 +180,7 @@ def test_parse_json_string_skip_returns_input_on_failure() -> None: result = _parse_json_string( malformed_string, raw_tool_call=raw_tool_call, - skip=True, + skip=True, # We want the original invalid string back ) assert result == malformed_string diff --git a/libs/partners/ollama/tests/unit_tests/test_embeddings.py b/libs/partners/ollama/tests/unit_tests/test_embeddings.py index feca762deb6..1c7c0939cc2 100644 --- a/libs/partners/ollama/tests/unit_tests/test_embeddings.py +++ b/libs/partners/ollama/tests/unit_tests/test_embeddings.py @@ -16,16 +16,12 @@ def test_initialization() -> None: @patch("langchain_ollama.embeddings.validate_model") def test_validate_model_on_init(mock_validate_model: Any) -> None: """Test that the model is validated on initialization when requested.""" - # Test that validate_model is called when validate_model_on_init=True OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=True) mock_validate_model.assert_called_once() mock_validate_model.reset_mock() - 
# Test that validate_model is NOT called when validate_model_on_init=False OllamaEmbeddings(model=MODEL_NAME, validate_model_on_init=False) mock_validate_model.assert_not_called() - - # Test that validate_model is NOT called by default OllamaEmbeddings(model=MODEL_NAME) mock_validate_model.assert_not_called() @@ -33,20 +29,13 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None: @patch("langchain_ollama.embeddings.Client") def test_embed_documents_passes_options(mock_client_class: Any) -> None: """Test that `embed_documents()` passes options, including `num_gpu`.""" - # Create a mock client instance mock_client = Mock() mock_client_class.return_value = mock_client - - # Mock the embed method response mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3]]} - # Create embeddings with num_gpu parameter embeddings = OllamaEmbeddings(model=MODEL_NAME, num_gpu=4, temperature=0.5) - - # Call embed_documents result = embeddings.embed_documents(["test text"]) - # Verify the result assert result == [[0.1, 0.2, 0.3]] # Check that embed was called with correct arguments diff --git a/libs/partners/ollama/tests/unit_tests/test_llms.py b/libs/partners/ollama/tests/unit_tests/test_llms.py index 55116688af3..ab49c591a50 100644 --- a/libs/partners/ollama/tests/unit_tests/test_llms.py +++ b/libs/partners/ollama/tests/unit_tests/test_llms.py @@ -14,7 +14,7 @@ def test_initialization() -> None: def test_model_params() -> None: - # Test standard tracing params + """Test standard tracing params""" llm = OllamaLLM(model=MODEL_NAME) ls_params = llm._get_ls_params() assert ls_params == { @@ -36,16 +36,12 @@ def test_model_params() -> None: @patch("langchain_ollama.llms.validate_model") def test_validate_model_on_init(mock_validate_model: Any) -> None: """Test that the model is validated on initialization when requested.""" - # Test that validate_model is called when validate_model_on_init=True OllamaLLM(model=MODEL_NAME, validate_model_on_init=True) 
mock_validate_model.assert_called_once() mock_validate_model.reset_mock() - # Test that validate_model is NOT called when validate_model_on_init=False OllamaLLM(model=MODEL_NAME, validate_model_on_init=False) mock_validate_model.assert_not_called() - - # Test that validate_model is NOT called by default OllamaLLM(model=MODEL_NAME) mock_validate_model.assert_not_called()