mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-13 22:59:05 +00:00)

commit 2f27d5b6f1 (parent ffda528f36)

add ollama integration tests
@@ -0,0 +1,153 @@
"""Ollama-specific v1 chat model integration tests.

Standard tests are handled in `test_chat_models_v1_standard.py`.
"""

from __future__ import annotations

from typing import Annotated, Optional

import pytest
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langchain_ollama.chat_models_v1 import ChatOllama

DEFAULT_MODEL_NAME = "llama3.1"
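
# NOTE: These are live integration tests; they assume a local Ollama server is
# running and that the model named above has already been pulled.
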
@pytest.mark.parametrize("method", ["function_calling", "json_schema"])
def test_structured_output(method: str) -> None:
    """Test to verify structured output via tool calling and the `format` parameter."""

    class Joke(BaseModel):
        """Joke to tell user."""

        setup: str = Field(description="question to set up a joke")
        punchline: str = Field(description="answer to resolve the joke")

    llm = ChatOllama(model=DEFAULT_MODEL_NAME, temperature=0)
    query = "Tell me a joke about cats."

    # Pydantic
    if method == "function_calling":
        structured_llm = llm.with_structured_output(Joke, method="function_calling")
        result = structured_llm.invoke(query)
        assert isinstance(result, Joke)

        for chunk in structured_llm.stream(query):
            assert isinstance(chunk, Joke)

    # JSON Schema
    if method == "json_schema":
        structured_llm = llm.with_structured_output(
            Joke.model_json_schema(), method="json_schema"
        )
        result = structured_llm.invoke(query)
        assert isinstance(result, dict)
        assert set(result.keys()) == {"setup", "punchline"}

        for chunk in structured_llm.stream(query):
            assert isinstance(chunk, dict)
            assert set(chunk.keys()) == {"setup", "punchline"}

        # TypedDict
        class JokeSchema(TypedDict):
            """Joke to tell user."""

            setup: Annotated[str, "question to set up a joke"]
            punchline: Annotated[str, "answer to resolve the joke"]

        structured_llm = llm.with_structured_output(JokeSchema, method="json_schema")
        result = structured_llm.invoke(query)
        assert isinstance(result, dict)
        assert set(result.keys()) == {"setup", "punchline"}

        for chunk in structured_llm.stream(query):
            assert isinstance(chunk, dict)
            assert set(chunk.keys()) == {"setup", "punchline"}
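
# No `method` is passed in the test below, so `with_structured_output` uses
# its default strategy; the assertions expect parsed Pydantic objects back.
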
@pytest.mark.parametrize("model", [DEFAULT_MODEL_NAME])
def test_structured_output_deeply_nested(model: str) -> None:
    """Test to verify structured output with nested objects."""
    llm = ChatOllama(model=model, temperature=0)

    class Person(BaseModel):
        """Information about a person."""

        name: Optional[str] = Field(default=None, description="The name of the person")
        hair_color: Optional[str] = Field(
            default=None, description="The color of the person's hair if known"
        )
        height_in_meters: Optional[str] = Field(
            default=None, description="Height measured in meters"
        )

    class Data(BaseModel):
        """Extracted data about people."""

        people: list[Person]

    chat = llm.with_structured_output(Data)
    text = (
        "Alan Smith is 6 feet tall and has blond hair. "
        "Alan Poe is 3 feet tall and has grey hair."
    )
    result = chat.invoke(text)
    assert isinstance(result, Data)

    for chunk in chat.stream(text):
        assert isinstance(chunk, Data)
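
# The commented-out tests below sketch planned coverage for reasoning and
# image content blocks; they are kept as placeholders and are not executed.
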
# def test_reasoning_content_blocks() -> None:
#     """Test that the model supports reasoning content blocks."""
#     llm = ChatOllama(model=DEFAULT_MODEL_NAME, temperature=0)

#     # Test with a reasoning prompt
#     messages = [HumanMessage("Think step by step and solve: What is 2 + 2?")]

#     result = llm.invoke(messages)

#     # Check that we get an AIMessage with content blocks
#     assert isinstance(result, AIMessage)
#     assert len(result.content) > 0

#     # For streaming, check that reasoning blocks are properly handled
#     chunks = []
#     for chunk in llm.stream(messages):
#         chunks.append(chunk)
#         assert isinstance(chunk, AIMessageChunk)

#     assert len(chunks) > 0


# def test_multimodal_support() -> None:
#     """Test that the model supports image content blocks."""
#     llm = ChatOllama(model=DEFAULT_MODEL_NAME, temperature=0)

#     # Create a message with an image content block
#     from langchain_core.messages.content_blocks import (
#         create_image_block,
#         create_text_block,
#     )

#     # Test with a simple base64 placeholder (real integration would use actual image)
#     message = HumanMessage(
#         content=[
#             create_text_block("Describe this image:"),
#             create_image_block(
#                 base64="iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="  # noqa: E501
#             ),
#         ]
#     )

#     result = llm.invoke([message])

#     # Check that we get a response (even if it's just acknowledging the image)
#     assert isinstance(result, AIMessage)
#     assert len(result.content) > 0

@@ -0,0 +1,199 @@
"""Test chat model v1 integration using standard integration tests."""

from unittest.mock import MagicMock, patch

import pytest
from httpx import ConnectError
from langchain_core.messages.content_blocks import ToolCallChunk
from langchain_core.tools import tool
from langchain_core.v1.chat_models import BaseChatModel
from langchain_core.v1.messages import AIMessageChunk, HumanMessage
from langchain_tests.integration_tests.chat_models_v1 import ChatModelV1IntegrationTests
from ollama import ResponseError
from pydantic import ValidationError

from langchain_ollama.chat_models_v1 import ChatOllama

DEFAULT_MODEL_NAME = "llama3.1"
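
# Stub tool for the tool-calling tests below; it returns canned data so the
# assertions depend only on the model's tool-call output, not on a live API.
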
@tool
def get_current_weather(location: str) -> dict:
    """Gets the current weather in a given location."""
    if "boston" in location.lower():
        return {"temperature": "15°F", "conditions": "snow"}
    return {"temperature": "unknown", "conditions": "unknown"}
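
# `ChatModelV1IntegrationTests` supplies the standard suite; the properties
# below declare which optional capabilities ChatOllama opts into or out of.
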
class TestChatOllamaV1(ChatModelV1IntegrationTests):
    @property
    def chat_model_class(self) -> type[ChatOllama]:
        return ChatOllama

    @property
    def chat_model_params(self) -> dict:
        return {"model": DEFAULT_MODEL_NAME}

    @property
    def supports_reasoning_content_blocks(self) -> bool:
        """ChatOllama supports reasoning content blocks."""
        return True

    @property
    def supports_image_content_blocks(self) -> bool:
        """ChatOllama supports image content blocks."""
        return True

    # TODO: ensure has_tool_calling tests are run

    @property
    def supports_invalid_tool_calls(self) -> bool:
        """ChatOllama supports invalid tool call handling."""
        return True

    @property
    def supports_non_standard_blocks(self) -> bool:
        """ChatOllama does not support non-standard content blocks."""
        return False

    @property
    def supports_json_mode(self) -> bool:
        return True

    @property
    def has_tool_choice(self) -> bool:
        # TODO: update after Ollama implements tool choice
        # https://github.com/ollama/ollama/blob/main/docs/openai.md
        return False
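
    # The two streaming tests below check that partial tool-call chunks are
    # surfaced while streaming and that they accumulate into a single,
    # consistent tool call with a stable ID.
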
    def test_tool_streaming(self, model: BaseChatModel) -> None:
        """Test that the model can stream tool calls."""
        chat_model_with_tools = model.bind_tools([get_current_weather])

        prompt = [HumanMessage("What is the weather today in Boston?")]

        # Flags and collectors for validation
        tool_chunk_found = False
        final_tool_calls = []
        collected_tool_chunks: list[ToolCallChunk] = []

        # Stream the response and inspect the chunks
        for chunk in chat_model_with_tools.stream(prompt):
            assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"

            if chunk.tool_call_chunks:
                tool_chunk_found = True
                collected_tool_chunks.extend(chunk.tool_call_chunks)

            if chunk.tool_calls:
                final_tool_calls.extend(chunk.tool_calls)

        assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
        assert len(final_tool_calls) == 1, (
            f"Expected 1 final tool call, but got {len(final_tool_calls)}"
        )

        final_tool_call = final_tool_calls[0]
        assert final_tool_call["name"] == "get_current_weather"
        assert final_tool_call["args"] == {"location": "Boston"}

        assert len(collected_tool_chunks) > 0
        assert collected_tool_chunks[0]["name"] == "get_current_weather"

        # The ID should be consistent across chunks that have it
        tool_call_id = collected_tool_chunks[0].get("id")
        assert tool_call_id is not None
        assert all(
            chunk.get("id") == tool_call_id
            for chunk in collected_tool_chunks
            if chunk.get("id")
        )
        assert final_tool_call["id"] == tool_call_id

    async def test_tool_astreaming(self, model: BaseChatModel) -> None:
        """Test that the model can stream tool calls asynchronously."""
        chat_model_with_tools = model.bind_tools([get_current_weather])

        prompt = [HumanMessage("What is the weather today in Boston?")]

        # Flags and collectors for validation
        tool_chunk_found = False
        final_tool_calls = []
        collected_tool_chunks: list[ToolCallChunk] = []

        # Stream the response and inspect the chunks
        async for chunk in chat_model_with_tools.astream(prompt):
            assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"

            if chunk.tool_call_chunks:
                tool_chunk_found = True
                collected_tool_chunks.extend(chunk.tool_call_chunks)

            if chunk.tool_calls:
                final_tool_calls.extend(chunk.tool_calls)

        assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
        assert len(final_tool_calls) == 1, (
            f"Expected 1 final tool call, but got {len(final_tool_calls)}"
        )

        final_tool_call = final_tool_calls[0]
        assert final_tool_call["name"] == "get_current_weather"
        assert final_tool_call["args"] == {"location": "Boston"}

        assert len(collected_tool_chunks) > 0
        assert collected_tool_chunks[0]["name"] == "get_current_weather"

        # The ID should be consistent across chunks that have it
        tool_call_id = collected_tool_chunks[0].get("id")
        assert tool_call_id is not None
        assert all(
            chunk.get("id") == tool_call_id
            for chunk in collected_tool_chunks
            if chunk.get("id")
        )
        assert final_tool_call["id"] == tool_call_id

    @pytest.mark.xfail(
        reason=(
            "Will sometimes encounter AssertionErrors where tool responses are "
            "`'3'` instead of `3`"
        )
    )
    def test_tool_calling(self, model: BaseChatModel) -> None:
        super().test_tool_calling(model)

    @pytest.mark.xfail(
        reason=(
            "Will sometimes encounter AssertionErrors where tool responses are "
            "`'3'` instead of `3`"
        )
    )
    async def test_tool_calling_async(self, model: BaseChatModel) -> None:
        await super().test_tool_calling_async(model)
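
    # The init tests below mock `Client.list`, the call made when
    # `validate_model_on_init=True`, so they run without a live Ollama server.
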
@patch("langchain_ollama.chat_models_v1.Client.list")
|
||||
def test_init_model_not_found(self, mock_list: MagicMock) -> None:
|
||||
"""Test that a ValueError is raised when the model is not found."""
|
||||
mock_list.side_effect = ValueError("Test model not found")
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
ChatOllama(model="non-existent-model", validate_model_on_init=True)
|
||||
assert "Test model not found" in str(excinfo.value)
|
||||
|
||||
@patch("langchain_ollama.chat_models_v1.Client.list")
|
||||
def test_init_connection_error(self, mock_list: MagicMock) -> None:
|
||||
"""Test that a ValidationError is raised on connect failure during init."""
|
||||
mock_list.side_effect = ConnectError("Test connection error")
|
||||
|
||||
with pytest.raises(ValidationError) as excinfo:
|
||||
ChatOllama(model="any-model", validate_model_on_init=True)
|
||||
assert "Failed to connect to Ollama" in str(excinfo.value)
|
||||
|
||||
@patch("langchain_ollama.chat_models_v1.Client.list")
|
||||
def test_init_response_error(self, mock_list: MagicMock) -> None:
|
||||
"""Test that a ResponseError is raised."""
|
||||
mock_list.side_effect = ResponseError("Test response error")
|
||||
|
||||
with pytest.raises(ValidationError) as excinfo:
|
||||
ChatOllama(model="any-model", validate_model_on_init=True)
|
||||
assert "Received an error from the Ollama API" in str(excinfo.value)
|