diff --git a/libs/partners/ollama/tests/integration_tests/test_chat_models_v1.py b/libs/partners/ollama/tests/integration_tests/test_chat_models_v1.py
new file mode 100644
index 00000000000..7281c819df9
--- /dev/null
+++ b/libs/partners/ollama/tests/integration_tests/test_chat_models_v1.py
@@ -0,0 +1,153 @@
+"""Ollama-specific v1 chat model integration tests.
+
+Standard tests are handled in `test_chat_models_v1_standard.py`.
+
+"""
+
+from __future__ import annotations
+
+from typing import Annotated, Optional
+
+import pytest
+from pydantic import BaseModel, Field
+from typing_extensions import TypedDict
+
+from langchain_ollama.chat_models_v1 import ChatOllama
+
+DEFAULT_MODEL_NAME = "llama3.1"
+
+
+@pytest.mark.parametrize(("method"), [("function_calling"), ("json_schema")])
+def test_structured_output(method: str) -> None:
+    """Test to verify structured output via tool calling and `format` parameter."""
+
+    class Joke(BaseModel):
+        """Joke to tell user."""
+
+        setup: str = Field(description="question to set up a joke")
+        punchline: str = Field(description="answer to resolve the joke")
+
+    llm = ChatOllama(model=DEFAULT_MODEL_NAME, temperature=0)
+    query = "Tell me a joke about cats."
+
+    # Pydantic
+    if method == "function_calling":
+        structured_llm = llm.with_structured_output(Joke, method="function_calling")
+        result = structured_llm.invoke(query)
+        assert isinstance(result, Joke)
+
+        for chunk in structured_llm.stream(query):
+            assert isinstance(chunk, Joke)
+
+    # JSON Schema
+    if method == "json_schema":
+        structured_llm = llm.with_structured_output(
+            Joke.model_json_schema(), method="json_schema"
+        )
+        result = structured_llm.invoke(query)
+        assert isinstance(result, dict)
+        assert set(result.keys()) == {"setup", "punchline"}
+
+        for chunk in structured_llm.stream(query):
+            assert isinstance(chunk, dict)
+        assert isinstance(chunk, dict)  # check the final aggregated chunk
+        assert set(chunk.keys()) == {"setup", "punchline"}
+
+        # Typed Dict
+        class JokeSchema(TypedDict):
+            """Joke to tell user."""
+
+            setup: Annotated[str, "question to set up a joke"]
+            punchline: Annotated[str, "answer to resolve the joke"]
+
+        structured_llm = llm.with_structured_output(JokeSchema, method="json_schema")
+        result = structured_llm.invoke(query)
+        assert isinstance(result, dict)
+        assert set(result.keys()) == {"setup", "punchline"}
+
+        for chunk in structured_llm.stream(query):
+            assert isinstance(chunk, dict)
+        assert isinstance(chunk, dict)  # check the final aggregated chunk
+        assert set(chunk.keys()) == {"setup", "punchline"}
+
+
+@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
+def test_structured_output_deeply_nested(model: str) -> None:
+    """Test to verify structured output with nested objects."""
+    llm = ChatOllama(model=model, temperature=0)
+
+    class Person(BaseModel):
+        """Information about a person."""
+
+        name: Optional[str] = Field(default=None, description="The name of the person")
+        hair_color: Optional[str] = Field(
+            default=None, description="The color of the person's hair if known"
+        )
+        height_in_meters: Optional[str] = Field(
+            default=None, description="Height measured in meters"
+        )
+
+    class Data(BaseModel):
+        """Extracted data about people."""
+
+        people: list[Person]
+
+    chat = llm.with_structured_output(Data)
+    text = (
+        "Alan Smith is 6 feet tall and has blond hair. "
+        "Alan Poe is 3 feet tall and has grey hair."
+    )
+    result = chat.invoke(text)
+    assert isinstance(result, Data)
+
+    for chunk in chat.stream(text):
+        assert isinstance(chunk, Data)
+
+
+# def test_reasoning_content_blocks() -> None:
+#     """Test that the model supports reasoning content blocks."""
+#     llm = ChatOllama(model=DEFAULT_MODEL_NAME, temperature=0)
+
+#     # Test with a reasoning prompt
+#     messages = [HumanMessage("Think step by step and solve: What is 2 + 2?")]
+
+#     result = llm.invoke(messages)
+
+#     # Check that we get an AIMessage with content blocks
+#     assert isinstance(result, AIMessage)
+#     assert len(result.content) > 0
+
+#     # For streaming, check that reasoning blocks are properly handled
+#     chunks = []
+#     for chunk in llm.stream(messages):
+#         chunks.append(chunk)
+#         assert isinstance(chunk, AIMessageChunk)
+
+#     assert len(chunks) > 0
+
+
+# def test_multimodal_support() -> None:
+#     """Test that the model supports image content blocks."""
+#     llm = ChatOllama(model=DEFAULT_MODEL_NAME, temperature=0)
+
+#     # Create a message with image content block
+#     from langchain_core.messages.content_blocks import (
+#         create_image_block,
+#         create_text_block,
+#     )
+
+#     # Test with a simple base64 placeholder (real integration would use actual image)
+#     message = HumanMessage(
+#         content=[
+#             create_text_block("Describe this image:"),
+#             create_image_block(
+#                 base64="iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="  # noqa: E501
+#             ),
+#         ]
+#     )
+
+#     result = llm.invoke([message])
+
+#     # Check that we get a response (even if it's just acknowledging the image)
+#     assert isinstance(result, AIMessage)
+#     assert len(result.content) > 0
diff --git a/libs/partners/ollama/tests/integration_tests/test_chat_models_v1_standard.py b/libs/partners/ollama/tests/integration_tests/test_chat_models_v1_standard.py
new file mode 100644
index 00000000000..5f4fb8dffff
--- /dev/null
+++ b/libs/partners/ollama/tests/integration_tests/test_chat_models_v1_standard.py
@@ -0,0 +1,199 @@
+"""Test chat model v1 integration using standard integration tests."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from httpx import ConnectError
+from langchain_core.messages.content_blocks import ToolCallChunk
+from langchain_core.tools import tool
+from langchain_core.v1.chat_models import BaseChatModel
+from langchain_core.v1.messages import AIMessageChunk, HumanMessage
+from langchain_tests.integration_tests.chat_models_v1 import ChatModelV1IntegrationTests
+from ollama import ResponseError
+from pydantic import ValidationError
+
+from langchain_ollama.chat_models_v1 import ChatOllama
+
+DEFAULT_MODEL_NAME = "llama3.1"
+
+
+@tool
+def get_current_weather(location: str) -> dict:
+    """Gets the current weather in a given location."""
+    if "boston" in location.lower():
+        return {"temperature": "15°F", "conditions": "snow"}
+    return {"temperature": "unknown", "conditions": "unknown"}
+
+
+class TestChatOllamaV1(ChatModelV1IntegrationTests):
+    @property
+    def chat_model_class(self) -> type[ChatOllama]:
+        return ChatOllama
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"model": DEFAULT_MODEL_NAME}
+
+    @property
+    def supports_reasoning_content_blocks(self) -> bool:
+        """ChatOllama supports reasoning content blocks."""
+        return True
+
+    @property
+    def supports_image_content_blocks(self) -> bool:
+        """ChatOllama supports image content blocks."""
+        return True
+
+    # TODO: ensure has_tool_calling tests are run
+
+    @property
+    def supports_invalid_tool_calls(self) -> bool:
+        """ChatOllama supports invalid tool call handling."""
+        return True
+
+    @property
+    def supports_non_standard_blocks(self) -> bool:
+        """ChatOllama does not support non-standard content blocks."""
+        return False
+
+    @property
+    def supports_json_mode(self) -> bool:
+        return True
+
+    @property
+    def has_tool_choice(self) -> bool:
+        # TODO: update after Ollama implements
+        # https://github.com/ollama/ollama/blob/main/docs/openai.md
+        return False
+
+    def test_tool_streaming(self, model: BaseChatModel) -> None:
+        """Test that the model can stream tool calls."""
+        chat_model_with_tools = model.bind_tools([get_current_weather])
+
+        prompt = [HumanMessage("What is the weather today in Boston?")]
+
+        # Flags and collectors for validation
+        tool_chunk_found = False
+        final_tool_calls = []
+        collected_tool_chunks: list[ToolCallChunk] = []
+
+        # Stream the response and inspect the chunks
+        for chunk in chat_model_with_tools.stream(prompt):
+            assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
+
+            if chunk.tool_call_chunks:
+                tool_chunk_found = True
+                collected_tool_chunks.extend(chunk.tool_call_chunks)
+
+            if chunk.tool_calls:
+                final_tool_calls.extend(chunk.tool_calls)
+
+        assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
+        assert len(final_tool_calls) == 1, (
+            f"Expected 1 final tool call, but got {len(final_tool_calls)}"
+        )
+
+        final_tool_call = final_tool_calls[0]
+        assert final_tool_call["name"] == "get_current_weather"
+        assert final_tool_call["args"] == {"location": "Boston"}
+
+        assert len(collected_tool_chunks) > 0
+        assert collected_tool_chunks[0]["name"] == "get_current_weather"
+
+        # The ID should be consistent across chunks that have it
+        tool_call_id = collected_tool_chunks[0].get("id")
+        assert tool_call_id is not None
+        assert all(
+            chunk.get("id") == tool_call_id
+            for chunk in collected_tool_chunks
+            if chunk.get("id")
+        )
+        assert final_tool_call["id"] == tool_call_id
+
+    async def test_tool_astreaming(self, model: BaseChatModel) -> None:
+        """Test that the model can stream tool calls asynchronously."""
+        chat_model_with_tools = model.bind_tools([get_current_weather])
+
+        prompt = [HumanMessage("What is the weather today in Boston?")]
+
+        # Flags and collectors for validation
+        tool_chunk_found = False
+        final_tool_calls = []
+        collected_tool_chunks: list[ToolCallChunk] = []
+
+        # Stream the response and inspect the chunks
+        async for chunk in chat_model_with_tools.astream(prompt):
+            assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
+
+            if chunk.tool_call_chunks:
+                tool_chunk_found = True
+                collected_tool_chunks.extend(chunk.tool_call_chunks)
+
+            if chunk.tool_calls:
+                final_tool_calls.extend(chunk.tool_calls)
+
+        assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
+        assert len(final_tool_calls) == 1, (
+            f"Expected 1 final tool call, but got {len(final_tool_calls)}"
+        )
+
+        final_tool_call = final_tool_calls[0]
+        assert final_tool_call["name"] == "get_current_weather"
+        assert final_tool_call["args"] == {"location": "Boston"}
+
+        assert len(collected_tool_chunks) > 0
+        assert collected_tool_chunks[0]["name"] == "get_current_weather"
+
+        # The ID should be consistent across chunks that have it
+        tool_call_id = collected_tool_chunks[0].get("id")
+        assert tool_call_id is not None
+        assert all(
+            chunk.get("id") == tool_call_id
+            for chunk in collected_tool_chunks
+            if chunk.get("id")
+        )
+        assert final_tool_call["id"] == tool_call_id
+
+    @pytest.mark.xfail(
+        reason=(
+            "Will sometimes encounter AssertionErrors where tool responses are "
+            "`'3'` instead of `3`"
+        )
+    )
+    def test_tool_calling(self, model: BaseChatModel) -> None:
+        super().test_tool_calling(model)
+
+    @pytest.mark.xfail(
+        reason=(
+            "Will sometimes encounter AssertionErrors where tool responses are "
+            "`'3'` instead of `3`"
+        )
+    )
+    async def test_tool_calling_async(self, model: BaseChatModel) -> None:
+        await super().test_tool_calling_async(model)
+
+    @patch("langchain_ollama.chat_models_v1.Client.list")
+    def test_init_model_not_found(self, mock_list: MagicMock) -> None:
+        """Test that a ValueError is raised when the model is not found."""
+        mock_list.side_effect = ValueError("Test model not found")
+        with pytest.raises(ValueError) as excinfo:
+            ChatOllama(model="non-existent-model", validate_model_on_init=True)
+        assert "Test model not found" in str(excinfo.value)
+
+    @patch("langchain_ollama.chat_models_v1.Client.list")
+    def test_init_connection_error(self, mock_list: MagicMock) -> None:
+        """Test that a ValidationError is raised on connect failure during init."""
+        mock_list.side_effect = ConnectError("Test connection error")
+
+        with pytest.raises(ValidationError) as excinfo:
+            ChatOllama(model="any-model", validate_model_on_init=True)
+        assert "Failed to connect to Ollama" in str(excinfo.value)
+
+    @patch("langchain_ollama.chat_models_v1.Client.list")
+    def test_init_response_error(self, mock_list: MagicMock) -> None:
+        """Test that a ResponseError surfaces as a ValidationError during init."""
+        mock_list.side_effect = ResponseError("Test response error")
+
+        with pytest.raises(ValidationError) as excinfo:
+            ChatOllama(model="any-model", validate_model_on_init=True)
+        assert "Received an error from the Ollama API" in str(excinfo.value)