From 80971b69d028678f4187989585f08a33e8f0ad2c Mon Sep 17 00:00:00 2001
From: Mason Daugherty
Date: Tue, 29 Jul 2025 14:34:22 -0400
Subject: [PATCH] implement

---
 .../ollama/langchain_ollama/__init__.py       |   2 +
 .../ollama/langchain_ollama/_compat.py        | 266 +++++
 .../ollama/langchain_ollama/chat_models_v1.py | 443 ++++++++++++++++++
 .../ollama/tests/unit_tests/test_base_v1.py   | 174 +++++++
 .../ollama/tests/unit_tests/test_imports.py   |   1 +
 5 files changed, 886 insertions(+)
 create mode 100644 libs/partners/ollama/langchain_ollama/_compat.py
 create mode 100644 libs/partners/ollama/langchain_ollama/chat_models_v1.py
 create mode 100644 libs/partners/ollama/tests/unit_tests/test_base_v1.py

diff --git a/libs/partners/ollama/langchain_ollama/__init__.py b/libs/partners/ollama/langchain_ollama/__init__.py
index 4d9864fc6a2..d514785b7a2 100644
--- a/libs/partners/ollama/langchain_ollama/__init__.py
+++ b/libs/partners/ollama/langchain_ollama/__init__.py
@@ -15,6 +15,7 @@ service.
 
 from importlib import metadata
 
 from langchain_ollama.chat_models import ChatOllama
+from langchain_ollama.chat_models_v1 import ChatOllamaV1
 from langchain_ollama.embeddings import OllamaEmbeddings
 from langchain_ollama.llms import OllamaLLM
@@ -27,6 +28,7 @@ del metadata  # optional, avoids polluting the results of dir(__package__)
 __all__ = [
     "ChatOllama",
+    "ChatOllamaV1",
     "OllamaEmbeddings",
     "OllamaLLM",
     "__version__",
diff --git a/libs/partners/ollama/langchain_ollama/_compat.py b/libs/partners/ollama/langchain_ollama/_compat.py
new file mode 100644
index 00000000000..9e1d2c235d2
--- /dev/null
+++ b/libs/partners/ollama/langchain_ollama/_compat.py
@@ -0,0 +1,266 @@
+"""V1 message conversion utilities for Ollama."""
+
+from __future__ import annotations
+
+from typing import Any, cast
+from uuid import uuid4
+
+from langchain_core.messages import content_blocks as types
+from langchain_core.messages.content_blocks import (
+    ImageContentBlock,
+    ReasoningContentBlock,
+    TextContentBlock,
+    ToolCall,
+)
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
+from langchain_core.messages.v1 import MessageV1, ResponseMetadata
+from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
+from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
+
+
+def _convert_from_v1_to_ollama_format(message: MessageV1) -> dict[str, Any]:
+    """Convert v1 message to Ollama API format."""
+    if isinstance(message, HumanMessageV1):
+        return _convert_human_message_v1(message)
+    if isinstance(message, AIMessageV1):
+        return _convert_ai_message_v1(message)
+    if isinstance(message, SystemMessageV1):
+        return _convert_system_message_v1(message)
+    if isinstance(message, ToolMessageV1):
+        return _convert_tool_message_v1(message)
+    msg = f"Unsupported message type: {type(message)}"
+    raise ValueError(msg)
+
+
+def _convert_content_blocks_to_ollama_format(
+    content: list[types.ContentBlock],
+) -> tuple[str, list[str], list[dict[str, Any]]]:
+    """Convert v1 content blocks to Ollama API format.
+
+    Returns:
+        Tuple of (text_content, images, tool_calls)
+    """
+    text_content = ""
+    images = []
+    tool_calls = []
+
+    for block in content:
+        block_type = block.get("type")
+        if block_type == "text":
+            text_block = cast(TextContentBlock, block)
+            text_content += text_block["text"]
+        elif block_type == "image":
+            image_block = cast(ImageContentBlock, block)
+            if image_block.get("source_type") == "base64":
+                images.append(image_block.get("data", ""))
+            else:
+                msg = "Only base64 image data is supported by Ollama"
+                raise ValueError(msg)
+        elif block_type == "audio":
+            msg = "Audio content blocks are not supported by Ollama"
+            raise ValueError(msg)
+        elif block_type == "tool_call":
+            tool_call_block = cast(ToolCall, block)
+            tool_calls.append(
+                {
+                    "type": "function",
+                    "id": tool_call_block["id"],
+                    "function": {
+                        "name": tool_call_block["name"],
+                        "arguments": tool_call_block["args"],
+                    },
+                }
+            )
+        # Skip other content block types that aren't supported
+
+    return text_content, images, tool_calls  # type: ignore[return-value]
+
+
+def _convert_human_message_v1(message: HumanMessageV1) -> dict[str, Any]:
+    """Convert HumanMessageV1 to Ollama format."""
+    text_content, images, _ = _convert_content_blocks_to_ollama_format(message.content)
+
+    msg: dict[str, Any] = {
+        "role": "user",
+        "content": text_content,
+        "images": images,
+    }
+    if message.name:
+        # Ollama doesn't have direct name support, include in content
+        msg["content"] = f"[{message.name}]: {text_content}"
+
+    return msg
+
+
+def _convert_ai_message_v1(message: AIMessageV1) -> dict[str, Any]:
+    """Convert AIMessageV1 to Ollama format."""
+    text_content, _, tool_calls = _convert_content_blocks_to_ollama_format(
+        message.content
+    )
+
+    msg: dict[str, Any] = {
+        "role": "assistant",
+        "content": text_content,
+    }
+
+    if tool_calls:
+        msg["tool_calls"] = tool_calls
+
+    if message.name:
+        # Ollama doesn't have direct name support, include in content
+        msg["content"] = f"[{message.name}]: {text_content}"
+
+    return msg
+
+
+def _convert_system_message_v1(message: SystemMessageV1) -> dict[str, Any]:
+    """Convert SystemMessageV1 to Ollama format."""
+    text_content, _, _ = _convert_content_blocks_to_ollama_format(message.content)
+
+    return {
+        "role": "system",
+        "content": text_content,
+    }
+
+
+def _convert_tool_message_v1(message: ToolMessageV1) -> dict[str, Any]:
+    """Convert ToolMessageV1 to Ollama format."""
+    text_content, _, _ = _convert_content_blocks_to_ollama_format(message.content)
+
+    return {
+        "role": "tool",
+        "content": text_content,
+        "tool_call_id": message.tool_call_id,
+    }
+
+
+def _convert_to_v1_from_ollama_format(response: dict[str, Any]) -> AIMessageV1:
+    """Convert Ollama API response to AIMessageV1."""
+    content: list[types.ContentBlock] = []
+
+    # Handle reasoning content first (should come before main response)
+    if "message" in response and "thinking" in response["message"]:
+        thinking_content = response["message"]["thinking"]
+        if thinking_content:
+            content.append(
+                ReasoningContentBlock(
+                    type="reasoning",
+                    reasoning=thinking_content,
+                )
+            )
+
+    # Handle text content
+    if "message" in response and "content" in response["message"]:
+        text_content = response["message"]["content"]
+        if text_content:
+            content.append(TextContentBlock(type="text", text=text_content))
+
+    # Handle tool calls
+    if "message" in response and "tool_calls" in response["message"]:
+        tool_calls = response["message"]["tool_calls"]
+        content.extend(
+            [
+                ToolCall(
+                    type="tool_call",
+                    id=tool_call.get("id", str(uuid4())),
+                    name=tool_call["function"]["name"],
+                    args=tool_call["function"]["arguments"],
+                )
+                for tool_call in tool_calls
+            ]
+        )
+
+    # Build response metadata
+    response_metadata = ResponseMetadata()
+    if "model" in response:
+        response_metadata["model_name"] = response["model"]
+    if "created_at" in response:
+        response_metadata["created_at"] = response["created_at"]  # type: ignore[typeddict-unknown-key]
+    if "done" in response:
+        response_metadata["done"] = response["done"]  # type: ignore[typeddict-unknown-key]
+    if "done_reason" in response:
+        response_metadata["done_reason"] = response["done_reason"]  # type: ignore[typeddict-unknown-key]
+    if "total_duration" in response:
+        response_metadata["total_duration"] = response["total_duration"]  # type: ignore[typeddict-unknown-key]
+    if "load_duration" in response:
+        response_metadata["load_duration"] = response["load_duration"]  # type: ignore[typeddict-unknown-key]
+    if "prompt_eval_count" in response:
+        response_metadata["prompt_eval_count"] = response["prompt_eval_count"]  # type: ignore[typeddict-unknown-key]
+    if "prompt_eval_duration" in response:
+        response_metadata["prompt_eval_duration"] = response["prompt_eval_duration"]  # type: ignore[typeddict-unknown-key]
+    if "eval_count" in response:
+        response_metadata["eval_count"] = response["eval_count"]  # type: ignore[typeddict-unknown-key]
+    if "eval_duration" in response:
+        response_metadata["eval_duration"] = response["eval_duration"]  # type: ignore[typeddict-unknown-key]
+
+    return AIMessageV1(
+        content=content,
+        response_metadata=response_metadata,
+    )
+
+
+def _convert_chunk_to_v1(chunk: dict[str, Any]) -> AIMessageChunkV1:
+    """Convert Ollama streaming chunk to AIMessageChunkV1."""
+    content: list[types.ContentBlock] = []
+
+    # Handle reasoning content first in chunks
+    if "message" in chunk and "thinking" in chunk["message"]:
+        thinking_content = chunk["message"]["thinking"]
+        if thinking_content:
+            content.append(
+                ReasoningContentBlock(
+                    type="reasoning",
+                    reasoning=thinking_content,
+                )
+            )
+
+    # Handle streaming text content
+    if "message" in chunk and "content" in chunk["message"]:
+        text_content = chunk["message"]["content"]
+        if text_content:
+            content.append(TextContentBlock(type="text", text=text_content))
+
+    # Handle streaming tool calls
+    if "message" in chunk and "tool_calls" in chunk["message"]:
+        tool_calls = chunk["message"]["tool_calls"]
+        content.extend(
+            [
+                ToolCall(
+                    type="tool_call",
+                    id=tool_call.get("id", str(uuid4())),
+                    name=tool_call.get("function", {}).get("name", ""),
+                    args=tool_call.get("function", {}).get("arguments", {}),
+                )
+                for tool_call in tool_calls
+            ]
+        )
+
+    # Build response metadata for final chunks
+    response_metadata = None
+    if chunk.get("done") is True:
+        response_metadata = ResponseMetadata()
+        if "model" in chunk:
+            response_metadata["model_name"] = chunk["model"]
+        if "created_at" in chunk:
+            response_metadata["created_at"] = chunk["created_at"]  # type: ignore[typeddict-unknown-key]
+        if "done_reason" in chunk:
+            response_metadata["done_reason"] = chunk["done_reason"]  # type: ignore[typeddict-unknown-key]
+        if "total_duration" in chunk:
+            response_metadata["total_duration"] = chunk["total_duration"]  # type: ignore[typeddict-unknown-key]
+        if "load_duration" in chunk:
+            response_metadata["load_duration"] = chunk["load_duration"]  # type: ignore[typeddict-unknown-key]
+        if "prompt_eval_count" in chunk:
+            response_metadata["prompt_eval_count"] = chunk["prompt_eval_count"]  # type: ignore[typeddict-unknown-key]
+        if "prompt_eval_duration" in chunk:
+            response_metadata["prompt_eval_duration"] = chunk["prompt_eval_duration"]  # type: ignore[typeddict-unknown-key]
+        if "eval_count" in chunk:
+            response_metadata["eval_count"] = chunk["eval_count"]  # type: ignore[typeddict-unknown-key]
+        if "eval_duration" in chunk:
+            response_metadata["eval_duration"] = chunk["eval_duration"]  # type: ignore[typeddict-unknown-key]
+
+    return AIMessageChunkV1(
+        content=content,
+        response_metadata=response_metadata or ResponseMetadata(),
+    )
diff --git a/libs/partners/ollama/langchain_ollama/chat_models_v1.py b/libs/partners/ollama/langchain_ollama/chat_models_v1.py
new file mode 100644
index 00000000000..301d358ee8a
--- /dev/null
+++ b/libs/partners/ollama/langchain_ollama/chat_models_v1.py
@@ -0,0 +1,443 @@
+"""Ollama chat model v1 implementation.
+
+This implementation provides native support for v1 messages with structured
+content blocks and always returns AIMessageV1 format responses.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import AsyncIterator, Iterator, Sequence
+from typing import Any, Callable, Literal, Optional, Union
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.language_models.chat_models import LangSmithParams
+from langchain_core.language_models.v1.chat_models import BaseChatModelV1
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import MessageV1
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
+from ollama import AsyncClient, Client, Options
+from pydantic import PrivateAttr, model_validator
+from pydantic.json_schema import JsonSchemaValue
+from typing_extensions import Self
+
+from ._compat import (
+    _convert_chunk_to_v1,
+    _convert_from_v1_to_ollama_format,
+    _convert_to_v1_from_ollama_format,
+)
+from ._utils import validate_model
+
+log = logging.getLogger(__name__)
+
+
+def _get_usage_metadata_from_response(
+    response: dict[str, Any],
+) -> Optional[UsageMetadata]:
+    """Extract usage metadata from Ollama response."""
+    input_tokens = response.get("prompt_eval_count")
+    output_tokens = response.get("eval_count")
+    if input_tokens is not None and output_tokens is not None:
+        return UsageMetadata(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+        )
+    return None
+
+
+class BaseChatOllamaV1(BaseChatModelV1):
+    """Base class for Ollama v1 chat models."""
+
+    model: str
+    """Model name to use."""
+
+    reasoning: Optional[bool] = None
+    """Controls the reasoning/thinking mode for supported models.
+
+    - ``True``: Enables reasoning mode. The model's reasoning process will be
+      captured and returned as a ``ReasoningContentBlock`` in the response
+      message content. The main response content will not include the reasoning tags.
+    - ``False``: Disables reasoning mode. The model will not perform any reasoning,
+      and the response will not include any reasoning content.
+    - ``None`` (Default): The model will use its default reasoning behavior. Note
+      however, if the model's default behavior *is* to perform reasoning, think tags
+      (``<think>`` and ``</think>``) will be present within the main response content
+      unless you set ``reasoning`` to ``True``.
+    """
+
+    validate_model_on_init: bool = False
+    """Whether to validate the model exists in Ollama locally on initialization."""
+
+    # Ollama-specific parameters
+    mirostat: Optional[int] = None
+    """Enable Mirostat sampling for controlling perplexity."""
+
+    mirostat_eta: Optional[float] = None
+    """Influences how quickly the algorithm responds to feedback."""
+
+    mirostat_tau: Optional[float] = None
+    """Controls the balance between coherence and diversity."""
+
+    num_ctx: Optional[int] = None
+    """Sets the size of the context window."""
+
+    num_gpu: Optional[int] = None
+    """The number of GPUs to use."""
+
+    num_thread: Optional[int] = None
+    """Sets the number of threads to use during computation."""
+
+    num_predict: Optional[int] = None
+    """Maximum number of tokens to predict."""
+
+    repeat_last_n: Optional[int] = None
+    """Sets how far back for the model to look back to prevent repetition."""
+
+    repeat_penalty: Optional[float] = None
+    """Sets how strongly to penalize repetitions."""
+
+    temperature: Optional[float] = None
+    """The temperature of the model."""
+
+    seed: Optional[int] = None
+    """Sets the random number seed to use for generation."""
+
+    stop: Optional[list[str]] = None
+    """Sets the stop tokens to use."""
+
+    tfs_z: Optional[float] = None
+    """Tail free sampling parameter."""
+
+    top_k: Optional[int] = None
+    """Reduces the probability of generating nonsense."""
+
+    top_p: Optional[float] = None
+    """Works together with top-k."""
+
+    format: Optional[Union[Literal["", "json"], JsonSchemaValue]] = None
+    """Specify the format of the output."""
+
+    keep_alive: Optional[Union[int, str]] = None
+    """How long the model will stay loaded into memory."""
+
+    base_url: Optional[str] = None
+    """Base url the model is hosted under."""
+
+    client_kwargs: Optional[dict] = {}
+    """Additional kwargs to pass to the httpx clients."""
+
+    async_client_kwargs: Optional[dict] = {}
+    """Additional kwargs for the async httpx client."""
+
+    sync_client_kwargs: Optional[dict] = {}
+    """Additional kwargs for the sync httpx client."""
+
+    _client: Client = PrivateAttr()
+    _async_client: AsyncClient = PrivateAttr()
+
+    @model_validator(mode="after")
+    def _set_clients(self) -> Self:
+        """Set clients to use for ollama."""
+        client_kwargs = self.client_kwargs or {}
+
+        sync_client_kwargs = client_kwargs
+        if self.sync_client_kwargs:
+            sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
+
+        async_client_kwargs = client_kwargs
+        if self.async_client_kwargs:
+            async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
+
+        self._client = Client(host=self.base_url, **sync_client_kwargs)
+        self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
+        if self.validate_model_on_init:
+            validate_model(self._client, self.model)
+        return self
+
+    def _get_ls_params(
+        self, stop: Optional[list[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        ls_params = LangSmithParams(
+            ls_provider="ollama",
+            ls_model_name=self.model,
+            ls_model_type="chat",
+            ls_temperature=params.get("temperature", self.temperature),
+        )
+        if ls_stop := stop or params.get("stop", None) or self.stop:
+            ls_params["ls_stop"] = ls_stop
+        return ls_params
+
+    def _get_invocation_params(
+        self, stop: Optional[list[str]] = None, **kwargs: Any
+    ) -> dict[str, Any]:
+        """Get parameters for model invocation."""
+        params = {
+            "model": self.model,
+            "mirostat": self.mirostat,
+            "mirostat_eta": self.mirostat_eta,
+            "mirostat_tau": self.mirostat_tau,
+            "num_ctx": self.num_ctx,
+            "num_gpu": self.num_gpu,
+            "num_thread": self.num_thread,
+            "num_predict": self.num_predict,
+            "repeat_last_n": self.repeat_last_n,
+            "repeat_penalty": self.repeat_penalty,
+            "temperature": self.temperature,
+            "seed": self.seed,
+            "stop": stop or self.stop,
+            "tfs_z": self.tfs_z,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "format": self.format,
+            "keep_alive": self.keep_alive,
+        }
+        params.update(kwargs)
+        return params
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "chat-ollama-v1"
+
+
+class ChatOllamaV1(BaseChatOllamaV1):
+    """Ollama chat model with native v1 content block support.
+
+    This implementation provides native support for structured content blocks
+    and always returns AIMessageV1 format responses.
+
+    Examples:
+        Basic text conversation:
+
+        .. code-block:: python
+
+            from langchain_ollama import ChatOllamaV1
+            from langchain_core.messages.v1 import HumanMessage
+            from langchain_core.messages.content_blocks import TextContentBlock
+
+            llm = ChatOllamaV1(model="llama3")
+            response = llm.invoke([
+                HumanMessage(content=[
+                    TextContentBlock(type="text", text="Hello!")
+                ])
+            ])
+
+            # Response is always structured
+            print(response.content)
+            # [{"type": "text", "text": "Hello! How can I help?"}]
+
+        Multi-modal input:
+
+        .. code-block:: python
+
+            from langchain_core.messages.content_blocks import ImageContentBlock
+
+            response = llm.invoke([
+                HumanMessage(content=[
+                    TextContentBlock(type="text", text="Describe this image:"),
+                    ImageContentBlock(
+                        type="image",
+                        mime_type="image/jpeg",
+                        data="base64_encoded_image",
+                        source_type="base64"
+                    )
+                ])
+            ])
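+
+        Tool calling (illustrative sketch; ``get_weather`` is a stand-in tool
+        defined only for this example, not something shipped with the package):
+
+        .. code-block:: python
+
+            def get_weather(city: str) -> str:
+                '''Get a short weather report for a city.'''
+                return f"It is sunny in {city}."
+
+            llm_with_tools = llm.bind_tools([get_weather])
+            response = llm_with_tools.invoke([
+                HumanMessage(content=[
+                    TextContentBlock(type="text", text="What is the weather in Paris?")
+                ])
+            ])
+
+            # Any requested tool calls show up as structured content blocks
+            tool_calls = [
+                block for block in response.content
+                if block.get("type") == "tool_call"
+            ]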
+    """
+
+    def _chat_params(
+        self,
+        messages: list[MessageV1],
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """Build parameters for Ollama chat API."""
+        # Convert v1 messages to Ollama format
+        ollama_messages = [_convert_from_v1_to_ollama_format(msg) for msg in messages]
+
+        if self.stop is not None and stop is not None:
+            msg = "`stop` found in both the input and default params."
+            raise ValueError(msg)
+        if self.stop is not None:
+            stop = self.stop
+
+        options_dict = kwargs.pop(
+            "options",
+            {
+                "mirostat": self.mirostat,
+                "mirostat_eta": self.mirostat_eta,
+                "mirostat_tau": self.mirostat_tau,
+                "num_ctx": self.num_ctx,
+                "num_gpu": self.num_gpu,
+                "num_thread": self.num_thread,
+                "num_predict": self.num_predict,
+                "repeat_last_n": self.repeat_last_n,
+                "repeat_penalty": self.repeat_penalty,
+                "temperature": self.temperature,
+                "seed": self.seed,
+                "stop": self.stop if stop is None else stop,
+                "tfs_z": self.tfs_z,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+            },
+        )
+
+        params = {
+            "messages": ollama_messages,
+            "stream": kwargs.pop("stream", True),
+            "model": kwargs.pop("model", self.model),
+            "think": kwargs.pop("reasoning", self.reasoning),
+            "format": kwargs.pop("format", self.format),
+            "options": Options(**options_dict),
+            "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
+            **kwargs,
+        }
+
+        if tools := kwargs.get("tools"):
+            params["tools"] = tools
+
+        return params
+
+    def _generate_stream(
+        self,
+        messages: list[MessageV1],
+        stop: Optional[list[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[AIMessageChunkV1]:
+        """Generate streaming response with native v1 chunks."""
+        chat_params = self._chat_params(messages, stop, **kwargs)
+
+        if chat_params["stream"]:
+            for part in self._client.chat(**chat_params):
+                if not isinstance(part, str):
+                    # Skip empty load responses
+                    if (
+                        part.get("done") is True
+                        and part.get("done_reason") == "load"
+                        and not part.get("message", {}).get("content", "").strip()
+                    ):
+                        log.warning(
+                            "Ollama returned empty response with done_reason='load'. "
+                            "Skipping this response."
+                        )
+                        continue
+
+                    chunk = _convert_chunk_to_v1(part)
+
+                    # Add usage metadata for final chunks
+                    if part.get("done") is True:
+                        usage_metadata = _get_usage_metadata_from_response(part)
+                        if usage_metadata:
+                            chunk.usage_metadata = usage_metadata
+
+                    if run_manager:
+                        text_content = "".join(
+                            str(block.get("text", ""))
+                            for block in chunk.content
+                            if block.get("type") == "text"
+                        )
+                        run_manager.on_llm_new_token(
+                            text_content,
+                            chunk=chunk,
+                        )
+                    yield chunk
+        else:
+            # Non-streaming case
+            response = self._client.chat(**chat_params)
+            ai_message = _convert_to_v1_from_ollama_format(response)
+            usage_metadata = _get_usage_metadata_from_response(response)
+            if usage_metadata:
+                ai_message.usage_metadata = usage_metadata
+            # Convert to chunk for yielding
+            chunk = AIMessageChunkV1(
+                content=ai_message.content,
+                response_metadata=ai_message.response_metadata,
+                usage_metadata=ai_message.usage_metadata,
+            )
+            yield chunk
+
+    async def _agenerate_stream(
+        self,
+        messages: list[MessageV1],
+        stop: Optional[list[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[AIMessageChunkV1]:
+        """Generate async streaming response with native v1 chunks."""
+        chat_params = self._chat_params(messages, stop, **kwargs)
+
+        if chat_params["stream"]:
+            async for part in await self._async_client.chat(**chat_params):
+                if not isinstance(part, str):
+                    # Skip empty load responses
+                    if (
+                        part.get("done") is True
+                        and part.get("done_reason") == "load"
+                        and not part.get("message", {}).get("content", "").strip()
+                    ):
+                        log.warning(
+                            "Ollama returned empty response with done_reason='load'. "
+                            "Skipping this response."
+ ) + continue + + chunk = _convert_chunk_to_v1(part) + + # Add usage metadata for final chunks + if part.get("done") is True: + usage_metadata = _get_usage_metadata_from_response(part) + if usage_metadata: + chunk.usage_metadata = usage_metadata + + if run_manager: + text_content = "".join( + str(block.get("text", "")) + for block in chunk.content + if block.get("type") == "text" + ) + await run_manager.on_llm_new_token( + text_content, + chunk=chunk, + ) + yield chunk + else: + # Non-streaming case + response = await self._async_client.chat(**chat_params) + ai_message = _convert_to_v1_from_ollama_format(response) + usage_metadata = _get_usage_metadata_from_response(response) + if usage_metadata: + ai_message.usage_metadata = usage_metadata + # Convert to chunk for yielding + chunk = AIMessageChunkV1( + content=ai_message.content, + response_metadata=ai_message.response_metadata, + usage_metadata=ai_message.usage_metadata, + ) + yield chunk + + def bind_tools( + self, + tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]], + *, + tool_choice: Optional[Union[dict, str, bool]] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, AIMessageV1]: + """Bind tool-like objects to this chat model. + + Args: + tools: A list of tool definitions to bind to this chat model. + tool_choice: Tool choice parameter (currently ignored by Ollama). + kwargs: Additional parameters passed to bind(). + """ + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/partners/ollama/tests/unit_tests/test_base_v1.py b/libs/partners/ollama/tests/unit_tests/test_base_v1.py new file mode 100644 index 00000000000..592d0059bec --- /dev/null +++ b/libs/partners/ollama/tests/unit_tests/test_base_v1.py @@ -0,0 +1,174 @@ +"""Unit tests for ChatOllamaV1.""" + +from langchain_core.messages.content_blocks import ImageContentBlock, TextContentBlock +from langchain_core.messages.v1 import AIMessage as AIMessageV1 +from langchain_core.messages.v1 import HumanMessage as HumanMessageV1 +from langchain_core.messages.v1 import MessageV1 +from langchain_core.messages.v1 import SystemMessage as SystemMessageV1 + +from langchain_ollama._compat import ( + _convert_chunk_to_v1, + _convert_from_v1_to_ollama_format, + _convert_to_v1_from_ollama_format, +) +from langchain_ollama.chat_models_v1 import ChatOllamaV1 + + +class TestMessageConversion: + """Test v1 message conversion utilities.""" + + def test_convert_human_message_v1_text_only(self) -> None: + """Test converting HumanMessageV1 with text content.""" + message = HumanMessageV1( + content=[TextContentBlock(type="text", text="Hello world")] + ) + + result = _convert_from_v1_to_ollama_format(message) + + assert result["role"] == "user" + assert result["content"] == "Hello world" + assert result["images"] == [] + + def test_convert_human_message_v1_with_image(self) -> None: + """Test converting HumanMessageV1 with text and image content.""" + message = HumanMessageV1( + content=[ + TextContentBlock(type="text", text="Describe this image:"), + ImageContentBlock( # type: ignore[typeddict-unknown-key] + type="image", + mime_type="image/jpeg", + data="base64imagedata", + source_type="base64", + ), + ] + ) + + result = _convert_from_v1_to_ollama_format(message) + + assert result["role"] == "user" + assert result["content"] == "Describe this image:" + assert result["images"] == ["base64imagedata"] + + def test_convert_ai_message_v1(self) -> None: + """Test converting AIMessageV1 with text 
+        message = AIMessageV1(
+            content=[TextContentBlock(type="text", text="Hello! How can I help?")]
+        )
+
+        result = _convert_from_v1_to_ollama_format(message)
+
+        assert result["role"] == "assistant"
+        assert result["content"] == "Hello! How can I help?"
+
+    def test_convert_system_message_v1(self) -> None:
+        """Test converting SystemMessageV1."""
+        message = SystemMessageV1(
+            content=[TextContentBlock(type="text", text="You are a helpful assistant.")]
+        )
+
+        result = _convert_from_v1_to_ollama_format(message)
+
+        assert result["role"] == "system"
+        assert result["content"] == "You are a helpful assistant."
+
+    def test_convert_from_ollama_format(self) -> None:
+        """Test converting Ollama response to AIMessageV1."""
+        ollama_response = {
+            "model": "llama3",
+            "created_at": "2024-01-01T00:00:00Z",
+            "message": {
+                "role": "assistant",
+                "content": "Hello! How can I help you today?",
+            },
+            "done": True,
+            "done_reason": "stop",
+            "total_duration": 1000000,
+            "prompt_eval_count": 10,
+            "eval_count": 20,
+        }
+
+        result = _convert_to_v1_from_ollama_format(ollama_response)
+
+        assert isinstance(result, AIMessageV1)
+        assert len(result.content) == 1
+        assert result.content[0]["type"] == "text"
+        assert result.content[0]["text"] == "Hello! How can I help you today?"
+        assert result.response_metadata["model_name"] == "llama3"
+        assert result.response_metadata.get("done") is True  # type: ignore[typeddict-item]
+
+    def test_convert_chunk_to_v1(self) -> None:
+        """Test converting Ollama streaming chunk to AIMessageChunkV1."""
+        chunk = {
+            "model": "llama3",
+            "created_at": "2024-01-01T00:00:00Z",
+            "message": {"role": "assistant", "content": "Hello"},
+            "done": False,
+        }
+
+        result = _convert_chunk_to_v1(chunk)
+
+        assert len(result.content) == 1
+        assert result.content[0]["type"] == "text"
+        assert result.content[0]["text"] == "Hello"
+
+    def test_convert_empty_content(self) -> None:
+        """Test converting empty content blocks."""
+        message = HumanMessageV1(content=[])
+
+        result = _convert_from_v1_to_ollama_format(message)
+
+        assert result["role"] == "user"
+        assert result["content"] == ""
+        assert result["images"] == []
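+
+    def test_convert_tool_call_response(self) -> None:
+        """Sketch: a response with ``tool_calls`` maps to ``tool_call`` blocks.
+
+        ``get_weather`` and its arguments are example data only; the assertions
+        follow ``_convert_to_v1_from_ollama_format`` as implemented above.
+        """
+        ollama_response = {
+            "model": "llama3",
+            "message": {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "get_weather",
+                            "arguments": {"city": "Paris"},
+                        }
+                    }
+                ],
+            },
+            "done": True,
+        }
+
+        result = _convert_to_v1_from_ollama_format(ollama_response)
+
+        # Empty text content is skipped, so only the tool call block remains
+        tool_blocks = [b for b in result.content if b["type"] == "tool_call"]
+        assert len(tool_blocks) == 1
+        assert tool_blocks[0]["name"] == "get_weather"
+        assert tool_blocks[0]["args"] == {"city": "Paris"}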
+
+
+class TestChatOllamaV1:
+    """Test ChatOllamaV1 class."""
+
+    def test_initialization(self) -> None:
+        """Test ChatOllamaV1 initialization."""
+        llm = ChatOllamaV1(model="llama3")
+
+        assert llm.model == "llama3"
+        assert llm._llm_type == "chat-ollama-v1"
+
+    def test_chat_params(self) -> None:
+        """Test _chat_params method."""
+        llm = ChatOllamaV1(model="llama3", temperature=0.7)
+
+        messages: list[MessageV1] = [
+            HumanMessageV1(content=[TextContentBlock(type="text", text="Hello")])
+        ]
+
+        params = llm._chat_params(messages)
+
+        assert params["model"] == "llama3"
+        assert len(params["messages"]) == 1
+        assert params["messages"][0]["role"] == "user"
+        assert params["messages"][0]["content"] == "Hello"
+        assert params["options"].temperature == 0.7
+
+    def test_ls_params(self) -> None:
+        """Test LangSmith parameters."""
+        llm = ChatOllamaV1(model="llama3", temperature=0.5)
+
+        ls_params = llm._get_ls_params()
+
+        assert ls_params["ls_provider"] == "ollama"
+        assert ls_params["ls_model_name"] == "llama3"
+        assert ls_params["ls_model_type"] == "chat"
+        assert ls_params["ls_temperature"] == 0.5
+
+    def test_bind_tools_basic(self) -> None:
+        """Test basic tool binding functionality."""
+        llm = ChatOllamaV1(model="llama3")
+
+        def test_tool(query: str) -> str:
+            """A test tool."""
+            return f"Result for: {query}"
+
+        bound_llm = llm.bind_tools([test_tool])
+
+        # Should return a bound model
+        assert bound_llm is not None
+        # The actual tool binding logic is handled by the parent class
diff --git a/libs/partners/ollama/tests/unit_tests/test_imports.py b/libs/partners/ollama/tests/unit_tests/test_imports.py
index a5afb37b048..b41096b08f0 100644
--- a/libs/partners/ollama/tests/unit_tests/test_imports.py
+++ b/libs/partners/ollama/tests/unit_tests/test_imports.py
@@ -3,6 +3,7 @@ from langchain_ollama import __all__
 EXPECTED_ALL = [
     "OllamaLLM",
     "ChatOllama",
+    "ChatOllamaV1",
     "OllamaEmbeddings",
     "__version__",
 ]