implement

commit 80971b69d0 (parent fee695ce6d)
@@ -15,6 +15,7 @@ service.
 from importlib import metadata

 from langchain_ollama.chat_models import ChatOllama
+from langchain_ollama.chat_models_v1 import ChatOllamaV1
 from langchain_ollama.embeddings import OllamaEmbeddings
 from langchain_ollama.llms import OllamaLLM

@@ -27,6 +28,7 @@ del metadata  # optional, avoids polluting the results of dir(__package__)

 __all__ = [
     "ChatOllama",
+    "ChatOllamaV1",
     "OllamaEmbeddings",
     "OllamaLLM",
     "__version__",
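With the new export in place, the model is importable from the package root (a fuller usage sketch follows the new chat_models_v1.py file below):

    from langchain_ollama import ChatOllamaV1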
libs/partners/ollama/langchain_ollama/_compat.py (new file, 266 lines)
"""V1 message conversion utilities for Ollama."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from langchain_core.messages import content_blocks as types
|
||||||
|
from langchain_core.messages.content_blocks import (
|
||||||
|
ImageContentBlock,
|
||||||
|
ReasoningContentBlock,
|
||||||
|
TextContentBlock,
|
||||||
|
ToolCall,
|
||||||
|
)
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
|
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||||
|
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||||
|
from langchain_core.messages.v1 import MessageV1, ResponseMetadata
|
||||||
|
from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
|
||||||
|
from langchain_core.messages.v1 import ToolMessage as ToolMessageV1
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_from_v1_to_ollama_format(message: MessageV1) -> dict[str, Any]:
|
||||||
|
"""Convert v1 message to Ollama API format."""
|
||||||
|
if isinstance(message, HumanMessageV1):
|
||||||
|
return _convert_human_message_v1(message)
|
||||||
|
if isinstance(message, AIMessageV1):
|
||||||
|
return _convert_ai_message_v1(message)
|
||||||
|
if isinstance(message, SystemMessageV1):
|
||||||
|
return _convert_system_message_v1(message)
|
||||||
|
if isinstance(message, ToolMessageV1):
|
||||||
|
return _convert_tool_message_v1(message)
|
||||||
|
msg = f"Unsupported message type: {type(message)}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_content_blocks_to_ollama_format(
|
||||||
|
content: list[types.ContentBlock],
|
||||||
|
) -> tuple[str, list[str], list[dict[str, Any]]]:
|
||||||
|
"""Convert v1 content blocks to Ollama API format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (text_content, images, tool_calls)
|
||||||
|
"""
|
||||||
|
text_content = ""
|
||||||
|
images = []
|
||||||
|
tool_calls = []
|
||||||
|
|
||||||
|
for block in content:
|
||||||
|
block_type = block.get("type")
|
||||||
|
if block_type == "text":
|
||||||
|
text_block = cast(TextContentBlock, block)
|
||||||
|
text_content += text_block["text"]
|
||||||
|
elif block_type == "image":
|
||||||
|
image_block = cast(ImageContentBlock, block)
|
||||||
|
if image_block.get("source_type") == "base64":
|
||||||
|
images.append(image_block.get("data", ""))
|
||||||
|
else:
|
||||||
|
msg = "Only base64 image data is supported by Ollama"
|
||||||
|
raise ValueError(msg)
|
||||||
|
elif block_type == "audio":
|
||||||
|
msg = "Audio content blocks are not supported by Ollama"
|
||||||
|
raise ValueError(msg)
|
||||||
|
elif block_type == "tool_call":
|
||||||
|
tool_call_block = cast(ToolCall, block)
|
||||||
|
tool_calls.append(
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"id": tool_call_block["id"],
|
||||||
|
"function": {
|
||||||
|
"name": tool_call_block["name"],
|
||||||
|
"arguments": tool_call_block["args"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Skip other content block types that aren't supported
|
||||||
|
|
||||||
|
return text_content, images, tool_calls # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_human_message_v1(message: HumanMessageV1) -> dict[str, Any]:
|
||||||
|
"""Convert HumanMessageV1 to Ollama format."""
|
||||||
|
text_content, images, _ = _convert_content_blocks_to_ollama_format(message.content)
|
||||||
|
|
||||||
|
msg: dict[str, Any] = {
|
||||||
|
"role": "user",
|
||||||
|
"content": text_content,
|
||||||
|
"images": images,
|
||||||
|
}
|
||||||
|
if message.name:
|
||||||
|
# Ollama doesn't have direct name support, include in content
|
||||||
|
msg["content"] = f"[{message.name}]: {text_content}"
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_ai_message_v1(message: AIMessageV1) -> dict[str, Any]:
|
||||||
|
"""Convert AIMessageV1 to Ollama format."""
|
||||||
|
text_content, _, tool_calls = _convert_content_blocks_to_ollama_format(
|
||||||
|
message.content
|
||||||
|
)
|
||||||
|
|
||||||
|
msg: dict[str, Any] = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": text_content,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool_calls:
|
||||||
|
msg["tool_calls"] = tool_calls
|
||||||
|
|
||||||
|
if message.name:
|
||||||
|
# Ollama doesn't have direct name support, include in content
|
||||||
|
msg["content"] = f"[{message.name}]: {text_content}"
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_system_message_v1(message: SystemMessageV1) -> dict[str, Any]:
|
||||||
|
"""Convert SystemMessageV1 to Ollama format."""
|
||||||
|
text_content, _, _ = _convert_content_blocks_to_ollama_format(message.content)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": "system",
|
||||||
|
"content": text_content,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_tool_message_v1(message: ToolMessageV1) -> dict[str, Any]:
|
||||||
|
"""Convert ToolMessageV1 to Ollama format."""
|
||||||
|
text_content, _, _ = _convert_content_blocks_to_ollama_format(message.content)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"content": text_content,
|
||||||
|
"tool_call_id": message.tool_call_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_v1_from_ollama_format(response: dict[str, Any]) -> AIMessageV1:
|
||||||
|
"""Convert Ollama API response to AIMessageV1."""
|
||||||
|
content: list[types.ContentBlock] = []
|
||||||
|
|
||||||
|
# Handle text content
|
||||||
|
if "message" in response and "content" in response["message"]:
|
||||||
|
text_content = response["message"]["content"]
|
||||||
|
if text_content:
|
||||||
|
content.append(TextContentBlock(type="text", text=text_content))
|
||||||
|
|
||||||
|
# Handle reasoning content first (should come before main response)
|
||||||
|
if "message" in response and "thinking" in response["message"]:
|
||||||
|
thinking_content = response["message"]["thinking"]
|
||||||
|
if thinking_content:
|
||||||
|
content.append(
|
||||||
|
ReasoningContentBlock(
|
||||||
|
type="reasoning",
|
||||||
|
reasoning=thinking_content,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle tool calls
|
||||||
|
if "message" in response and "tool_calls" in response["message"]:
|
||||||
|
tool_calls = response["message"]["tool_calls"]
|
||||||
|
content.extend(
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
type="tool_call",
|
||||||
|
id=tool_call.get("id", str(uuid4())),
|
||||||
|
name=tool_call["function"]["name"],
|
||||||
|
args=tool_call["function"]["arguments"],
|
||||||
|
)
|
||||||
|
for tool_call in tool_calls
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response metadata
|
||||||
|
response_metadata = ResponseMetadata()
|
||||||
|
if "model" in response:
|
||||||
|
response_metadata["model_name"] = response["model"]
|
||||||
|
if "created_at" in response:
|
||||||
|
response_metadata["created_at"] = response["created_at"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "done" in response:
|
||||||
|
response_metadata["done"] = response["done"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "done_reason" in response:
|
||||||
|
response_metadata["done_reason"] = response["done_reason"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "total_duration" in response:
|
||||||
|
response_metadata["total_duration"] = response["total_duration"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "load_duration" in response:
|
||||||
|
response_metadata["load_duration"] = response["load_duration"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "prompt_eval_count" in response:
|
||||||
|
response_metadata["prompt_eval_count"] = response["prompt_eval_count"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "prompt_eval_duration" in response:
|
||||||
|
response_metadata["prompt_eval_duration"] = response["prompt_eval_duration"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "eval_count" in response:
|
||||||
|
response_metadata["eval_count"] = response["eval_count"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "eval_duration" in response:
|
||||||
|
response_metadata["eval_duration"] = response["eval_duration"] # type: ignore[typeddict-unknown-key]
|
||||||
|
|
||||||
|
return AIMessageV1(
|
||||||
|
content=content,
|
||||||
|
response_metadata=response_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_chunk_to_v1(chunk: dict[str, Any]) -> AIMessageChunkV1:
|
||||||
|
"""Convert Ollama streaming chunk to AIMessageChunkV1."""
|
||||||
|
content: list[types.ContentBlock] = []
|
||||||
|
|
||||||
|
# Handle reasoning content first in chunks
|
||||||
|
if "message" in chunk and "thinking" in chunk["message"]:
|
||||||
|
thinking_content = chunk["message"]["thinking"]
|
||||||
|
if thinking_content:
|
||||||
|
content.append(
|
||||||
|
ReasoningContentBlock(
|
||||||
|
type="reasoning",
|
||||||
|
reasoning=thinking_content,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle streaming text content
|
||||||
|
if "message" in chunk and "content" in chunk["message"]:
|
||||||
|
text_content = chunk["message"]["content"]
|
||||||
|
if text_content:
|
||||||
|
content.append(TextContentBlock(type="text", text=text_content))
|
||||||
|
|
||||||
|
# Handle streaming tool calls
|
||||||
|
if "message" in chunk and "tool_calls" in chunk["message"]:
|
||||||
|
tool_calls = chunk["message"]["tool_calls"]
|
||||||
|
content.extend(
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
type="tool_call",
|
||||||
|
id=tool_call.get("id", str(uuid4())),
|
||||||
|
name=tool_call.get("function", {}).get("name", ""),
|
||||||
|
args=tool_call.get("function", {}).get("arguments", {}),
|
||||||
|
)
|
||||||
|
for tool_call in tool_calls
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response metadata for final chunks
|
||||||
|
response_metadata = None
|
||||||
|
if chunk.get("done") is True:
|
||||||
|
response_metadata = ResponseMetadata()
|
||||||
|
if "model" in chunk:
|
||||||
|
response_metadata["model_name"] = chunk["model"]
|
||||||
|
if "created_at" in chunk:
|
||||||
|
response_metadata["created_at"] = chunk["created_at"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "done_reason" in chunk:
|
||||||
|
response_metadata["done_reason"] = chunk["done_reason"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "total_duration" in chunk:
|
||||||
|
response_metadata["total_duration"] = chunk["total_duration"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "load_duration" in chunk:
|
||||||
|
response_metadata["load_duration"] = chunk["load_duration"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "prompt_eval_count" in chunk:
|
||||||
|
response_metadata["prompt_eval_count"] = chunk["prompt_eval_count"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "prompt_eval_duration" in chunk:
|
||||||
|
response_metadata["prompt_eval_duration"] = chunk["prompt_eval_duration"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "eval_count" in chunk:
|
||||||
|
response_metadata["eval_count"] = chunk["eval_count"] # type: ignore[typeddict-unknown-key]
|
||||||
|
if "eval_duration" in chunk:
|
||||||
|
response_metadata["eval_duration"] = chunk["eval_duration"] # type: ignore[typeddict-unknown-key]
|
||||||
|
|
||||||
|
return AIMessageChunkV1(
|
||||||
|
content=content,
|
||||||
|
response_metadata=response_metadata or ResponseMetadata(),
|
||||||
|
)
|
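To make the conversion contract concrete, here is a minimal sketch of how these private helpers translate between v1 messages and Ollama's dict format. It is illustrative only: the helpers are internal, and the response payload shown is a hand-written assumption rather than output from a real server.

    from langchain_core.messages.content_blocks import TextContentBlock
    from langchain_core.messages.v1 import HumanMessage as HumanMessageV1

    from langchain_ollama._compat import (
        _convert_from_v1_to_ollama_format,
        _convert_to_v1_from_ollama_format,
    )

    # v1 message -> Ollama request dict
    request = _convert_from_v1_to_ollama_format(
        HumanMessageV1(content=[TextContentBlock(type="text", text="Hi there")])
    )
    # request == {"role": "user", "content": "Hi there", "images": []}

    # Ollama response dict -> AIMessageV1 (payload below is a hand-written example)
    response = {
        "model": "llama3",
        "message": {"role": "assistant", "content": "Hello!"},
        "done": True,
    }
    ai_message = _convert_to_v1_from_ollama_format(response)
    # ai_message.content == [{"type": "text", "text": "Hello!"}]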
libs/partners/ollama/langchain_ollama/chat_models_v1.py (new file, 443 lines)
"""Ollama chat model v1 implementation.
|
||||||
|
|
||||||
|
This implementation provides native support for v1 messages with structured
|
||||||
|
content blocks and always returns AIMessageV1 format responses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||||
|
from typing import Any, Callable, Literal, Optional, Union
|
||||||
|
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
|
||||||
|
from langchain_core.language_models import LanguageModelInput
|
||||||
|
from langchain_core.language_models.chat_models import LangSmithParams
|
||||||
|
from langchain_core.language_models.v1.chat_models import BaseChatModelV1
|
||||||
|
from langchain_core.messages.ai import UsageMetadata
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
|
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||||
|
from langchain_core.messages.v1 import MessageV1
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
from ollama import AsyncClient, Client, Options
|
||||||
|
from pydantic import PrivateAttr, model_validator
|
||||||
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ._compat import (
|
||||||
|
_convert_chunk_to_v1,
|
||||||
|
_convert_from_v1_to_ollama_format,
|
||||||
|
_convert_to_v1_from_ollama_format,
|
||||||
|
)
|
||||||
|
from ._utils import validate_model
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_usage_metadata_from_response(
|
||||||
|
response: dict[str, Any],
|
||||||
|
) -> Optional[UsageMetadata]:
|
||||||
|
"""Extract usage metadata from Ollama response."""
|
||||||
|
input_tokens = response.get("prompt_eval_count")
|
||||||
|
output_tokens = response.get("eval_count")
|
||||||
|
if input_tokens is not None and output_tokens is not None:
|
||||||
|
return UsageMetadata(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
total_tokens=input_tokens + output_tokens,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseChatOllamaV1(BaseChatModelV1):
|
||||||
|
"""Base class for Ollama v1 chat models."""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
"""Model name to use."""
|
||||||
|
|
||||||
|
reasoning: Optional[bool] = None
|
||||||
|
"""Controls the reasoning/thinking mode for supported models.
|
||||||
|
|
||||||
|
- ``True``: Enables reasoning mode. The model's reasoning process will be
|
||||||
|
captured and returned as a ``ReasoningContentBlock`` in the response
|
||||||
|
message content. The main response content will not include the reasoning tags.
|
||||||
|
- ``False``: Disables reasoning mode. The model will not perform any reasoning,
|
||||||
|
and the response will not include any reasoning content.
|
||||||
|
- ``None`` (Default): The model will use its default reasoning behavior. Note
|
||||||
|
however, if the model's default behavior *is* to perform reasoning, think tags
|
||||||
|
(``<think>`` and ``</think>``) will be present within the main response content
|
||||||
|
unless you set ``reasoning`` to ``True``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
validate_model_on_init: bool = False
|
||||||
|
"""Whether to validate the model exists in Ollama locally on initialization."""
|
||||||
|
|
||||||
|
# Ollama-specific parameters
|
||||||
|
mirostat: Optional[int] = None
|
||||||
|
"""Enable Mirostat sampling for controlling perplexity."""
|
||||||
|
|
||||||
|
mirostat_eta: Optional[float] = None
|
||||||
|
"""Influences how quickly the algorithm responds to feedback."""
|
||||||
|
|
||||||
|
mirostat_tau: Optional[float] = None
|
||||||
|
"""Controls the balance between coherence and diversity."""
|
||||||
|
|
||||||
|
num_ctx: Optional[int] = None
|
||||||
|
"""Sets the size of the context window."""
|
||||||
|
|
||||||
|
num_gpu: Optional[int] = None
|
||||||
|
"""The number of GPUs to use."""
|
||||||
|
|
||||||
|
num_thread: Optional[int] = None
|
||||||
|
"""Sets the number of threads to use during computation."""
|
||||||
|
|
||||||
|
num_predict: Optional[int] = None
|
||||||
|
"""Maximum number of tokens to predict."""
|
||||||
|
|
||||||
|
repeat_last_n: Optional[int] = None
|
||||||
|
"""Sets how far back for the model to look back to prevent repetition."""
|
||||||
|
|
||||||
|
repeat_penalty: Optional[float] = None
|
||||||
|
"""Sets how strongly to penalize repetitions."""
|
||||||
|
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
"""The temperature of the model."""
|
||||||
|
|
||||||
|
seed: Optional[int] = None
|
||||||
|
"""Sets the random number seed to use for generation."""
|
||||||
|
|
||||||
|
stop: Optional[list[str]] = None
|
||||||
|
"""Sets the stop tokens to use."""
|
||||||
|
|
||||||
|
tfs_z: Optional[float] = None
|
||||||
|
"""Tail free sampling parameter."""
|
||||||
|
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
"""Reduces the probability of generating nonsense."""
|
||||||
|
|
||||||
|
top_p: Optional[float] = None
|
||||||
|
"""Works together with top-k."""
|
||||||
|
|
||||||
|
format: Optional[Union[Literal["", "json"], JsonSchemaValue]] = None
|
||||||
|
"""Specify the format of the output."""
|
||||||
|
|
||||||
|
keep_alive: Optional[Union[int, str]] = None
|
||||||
|
"""How long the model will stay loaded into memory."""
|
||||||
|
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
"""Base url the model is hosted under."""
|
||||||
|
|
||||||
|
client_kwargs: Optional[dict] = {}
|
||||||
|
"""Additional kwargs to pass to the httpx clients."""
|
||||||
|
|
||||||
|
async_client_kwargs: Optional[dict] = {}
|
||||||
|
"""Additional kwargs for the async httpx client."""
|
||||||
|
|
||||||
|
sync_client_kwargs: Optional[dict] = {}
|
||||||
|
"""Additional kwargs for the sync httpx client."""
|
||||||
|
|
||||||
|
_client: Client = PrivateAttr()
|
||||||
|
_async_client: AsyncClient = PrivateAttr()
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _set_clients(self) -> Self:
|
||||||
|
"""Set clients to use for ollama."""
|
||||||
|
client_kwargs = self.client_kwargs or {}
|
||||||
|
|
||||||
|
sync_client_kwargs = client_kwargs
|
||||||
|
if self.sync_client_kwargs:
|
||||||
|
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
||||||
|
|
||||||
|
async_client_kwargs = client_kwargs
|
||||||
|
if self.async_client_kwargs:
|
||||||
|
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
||||||
|
|
||||||
|
self._client = Client(host=self.base_url, **sync_client_kwargs)
|
||||||
|
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
|
||||||
|
if self.validate_model_on_init:
|
||||||
|
validate_model(self._client, self.model)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _get_ls_params(
|
||||||
|
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||||
|
) -> LangSmithParams:
|
||||||
|
"""Get standard params for tracing."""
|
||||||
|
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||||
|
ls_params = LangSmithParams(
|
||||||
|
ls_provider="ollama",
|
||||||
|
ls_model_name=self.model,
|
||||||
|
ls_model_type="chat",
|
||||||
|
ls_temperature=params.get("temperature", self.temperature),
|
||||||
|
)
|
||||||
|
if ls_stop := stop or params.get("stop", None) or self.stop:
|
||||||
|
ls_params["ls_stop"] = ls_stop
|
||||||
|
return ls_params
|
||||||
|
|
||||||
|
def _get_invocation_params(
|
||||||
|
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Get parameters for model invocation."""
|
||||||
|
params = {
|
||||||
|
"model": self.model,
|
||||||
|
"mirostat": self.mirostat,
|
||||||
|
"mirostat_eta": self.mirostat_eta,
|
||||||
|
"mirostat_tau": self.mirostat_tau,
|
||||||
|
"num_ctx": self.num_ctx,
|
||||||
|
"num_gpu": self.num_gpu,
|
||||||
|
"num_thread": self.num_thread,
|
||||||
|
"num_predict": self.num_predict,
|
||||||
|
"repeat_last_n": self.repeat_last_n,
|
||||||
|
"repeat_penalty": self.repeat_penalty,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"seed": self.seed,
|
||||||
|
"stop": stop or self.stop,
|
||||||
|
"tfs_z": self.tfs_z,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"format": self.format,
|
||||||
|
"keep_alive": self.keep_alive,
|
||||||
|
}
|
||||||
|
params.update(kwargs)
|
||||||
|
return params
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of chat model."""
|
||||||
|
return "chat-ollama-v1"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatOllamaV1(BaseChatOllamaV1):
|
||||||
|
"""Ollama chat model with native v1 content block support.
|
||||||
|
|
||||||
|
This implementation provides native support for structured content blocks
|
||||||
|
and always returns AIMessageV1 format responses.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Basic text conversation:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_ollama import ChatOllamaV1
|
||||||
|
from langchain_core.messages.v1 import HumanMessage
|
||||||
|
from langchain_core.messages.content_blocks import TextContentBlock
|
||||||
|
|
||||||
|
llm = ChatOllamaV1(model="llama3")
|
||||||
|
response = llm.invoke([
|
||||||
|
HumanMessage(content=[
|
||||||
|
TextContentBlock(type="text", text="Hello!")
|
||||||
|
])
|
||||||
|
])
|
||||||
|
|
||||||
|
# Response is always structured
|
||||||
|
print(response.content)
|
||||||
|
# [{"type": "text", "text": "Hello! How can I help?"}]
|
||||||
|
|
||||||
|
Multi-modal input:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_core.messages.content_blocks import ImageContentBlock
|
||||||
|
|
||||||
|
response = llm.invoke([
|
||||||
|
HumanMessage(content=[
|
||||||
|
TextContentBlock(type="text", text="Describe this image:"),
|
||||||
|
ImageContentBlock(
|
||||||
|
type="image",
|
||||||
|
mime_type="image/jpeg",
|
||||||
|
data="base64_encoded_image",
|
||||||
|
source_type="base64"
|
||||||
|
)
|
||||||
|
])
|
||||||
|
])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _chat_params(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build parameters for Ollama chat API."""
|
||||||
|
# Convert v1 messages to Ollama format
|
||||||
|
ollama_messages = [_convert_from_v1_to_ollama_format(msg) for msg in messages]
|
||||||
|
|
||||||
|
if self.stop is not None and stop is not None:
|
||||||
|
msg = "`stop` found in both the input and default params."
|
||||||
|
raise ValueError(msg)
|
||||||
|
if self.stop is not None:
|
||||||
|
stop = self.stop
|
||||||
|
|
||||||
|
options_dict = kwargs.pop(
|
||||||
|
"options",
|
||||||
|
{
|
||||||
|
"mirostat": self.mirostat,
|
||||||
|
"mirostat_eta": self.mirostat_eta,
|
||||||
|
"mirostat_tau": self.mirostat_tau,
|
||||||
|
"num_ctx": self.num_ctx,
|
||||||
|
"num_gpu": self.num_gpu,
|
||||||
|
"num_thread": self.num_thread,
|
||||||
|
"num_predict": self.num_predict,
|
||||||
|
"repeat_last_n": self.repeat_last_n,
|
||||||
|
"repeat_penalty": self.repeat_penalty,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"seed": self.seed,
|
||||||
|
"stop": self.stop if stop is None else stop,
|
||||||
|
"tfs_z": self.tfs_z,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"messages": ollama_messages,
|
||||||
|
"stream": kwargs.pop("stream", True),
|
||||||
|
"model": kwargs.pop("model", self.model),
|
||||||
|
"think": kwargs.pop("reasoning", self.reasoning),
|
||||||
|
"format": kwargs.pop("format", self.format),
|
||||||
|
"options": Options(**options_dict),
|
||||||
|
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools := kwargs.get("tools"):
|
||||||
|
params["tools"] = tools
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _generate_stream(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[AIMessageChunkV1]:
|
||||||
|
"""Generate streaming response with native v1 chunks."""
|
||||||
|
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||||
|
|
||||||
|
if chat_params["stream"]:
|
||||||
|
for part in self._client.chat(**chat_params):
|
||||||
|
if not isinstance(part, str):
|
||||||
|
# Skip empty load responses
|
||||||
|
if (
|
||||||
|
part.get("done") is True
|
||||||
|
and part.get("done_reason") == "load"
|
||||||
|
and not part.get("message", {}).get("content", "").strip()
|
||||||
|
):
|
||||||
|
log.warning(
|
||||||
|
"Ollama returned empty response with done_reason='load'. "
|
||||||
|
"Skipping this response."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = _convert_chunk_to_v1(part)
|
||||||
|
|
||||||
|
# Add usage metadata for final chunks
|
||||||
|
if part.get("done") is True:
|
||||||
|
usage_metadata = _get_usage_metadata_from_response(part)
|
||||||
|
if usage_metadata:
|
||||||
|
chunk.usage_metadata = usage_metadata
|
||||||
|
|
||||||
|
if run_manager:
|
||||||
|
text_content = "".join(
|
||||||
|
str(block.get("text", ""))
|
||||||
|
for block in chunk.content
|
||||||
|
if block.get("type") == "text"
|
||||||
|
)
|
||||||
|
run_manager.on_llm_new_token(
|
||||||
|
text_content,
|
||||||
|
chunk=chunk,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
# Non-streaming case
|
||||||
|
response = self._client.chat(**chat_params)
|
||||||
|
ai_message = _convert_to_v1_from_ollama_format(response)
|
||||||
|
usage_metadata = _get_usage_metadata_from_response(response)
|
||||||
|
if usage_metadata:
|
||||||
|
ai_message.usage_metadata = usage_metadata
|
||||||
|
# Convert to chunk for yielding
|
||||||
|
chunk = AIMessageChunkV1(
|
||||||
|
content=ai_message.content,
|
||||||
|
response_metadata=ai_message.response_metadata,
|
||||||
|
usage_metadata=ai_message.usage_metadata,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def _agenerate_stream(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[AIMessageChunkV1]:
|
||||||
|
"""Generate async streaming response with native v1 chunks."""
|
||||||
|
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||||
|
|
||||||
|
if chat_params["stream"]:
|
||||||
|
async for part in await self._async_client.chat(**chat_params):
|
||||||
|
if not isinstance(part, str):
|
||||||
|
# Skip empty load responses
|
||||||
|
if (
|
||||||
|
part.get("done") is True
|
||||||
|
and part.get("done_reason") == "load"
|
||||||
|
and not part.get("message", {}).get("content", "").strip()
|
||||||
|
):
|
||||||
|
log.warning(
|
||||||
|
"Ollama returned empty response with done_reason='load'. "
|
||||||
|
"Skipping this response."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = _convert_chunk_to_v1(part)
|
||||||
|
|
||||||
|
# Add usage metadata for final chunks
|
||||||
|
if part.get("done") is True:
|
||||||
|
usage_metadata = _get_usage_metadata_from_response(part)
|
||||||
|
if usage_metadata:
|
||||||
|
chunk.usage_metadata = usage_metadata
|
||||||
|
|
||||||
|
if run_manager:
|
||||||
|
text_content = "".join(
|
||||||
|
str(block.get("text", ""))
|
||||||
|
for block in chunk.content
|
||||||
|
if block.get("type") == "text"
|
||||||
|
)
|
||||||
|
await run_manager.on_llm_new_token(
|
||||||
|
text_content,
|
||||||
|
chunk=chunk,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
# Non-streaming case
|
||||||
|
response = await self._async_client.chat(**chat_params)
|
||||||
|
ai_message = _convert_to_v1_from_ollama_format(response)
|
||||||
|
usage_metadata = _get_usage_metadata_from_response(response)
|
||||||
|
if usage_metadata:
|
||||||
|
ai_message.usage_metadata = usage_metadata
|
||||||
|
# Convert to chunk for yielding
|
||||||
|
chunk = AIMessageChunkV1(
|
||||||
|
content=ai_message.content,
|
||||||
|
response_metadata=ai_message.response_metadata,
|
||||||
|
usage_metadata=ai_message.usage_metadata,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def bind_tools(
|
||||||
|
self,
|
||||||
|
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||||
|
*,
|
||||||
|
tool_choice: Optional[Union[dict, str, bool]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, AIMessageV1]:
|
||||||
|
"""Bind tool-like objects to this chat model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: A list of tool definitions to bind to this chat model.
|
||||||
|
tool_choice: Tool choice parameter (currently ignored by Ollama).
|
||||||
|
kwargs: Additional parameters passed to bind().
|
||||||
|
"""
|
||||||
|
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||||
|
return super().bind(tools=formatted_tools, **kwargs)
|
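For orientation, a minimal usage sketch of the new model with reasoning enabled. It assumes a local Ollama server with a reasoning-capable model already pulled; the model name "llama3" is illustrative, and the printed block shapes follow the converters in _compat.py above.

    from langchain_core.messages.content_blocks import TextContentBlock
    from langchain_core.messages.v1 import HumanMessage as HumanMessageV1

    from langchain_ollama import ChatOllamaV1

    # Model name is illustrative; requires a running Ollama server.
    llm = ChatOllamaV1(model="llama3", reasoning=True, temperature=0.2)

    response = llm.invoke(
        [HumanMessageV1(content=[TextContentBlock(type="text", text="Why is the sky blue?")])]
    )

    # Content is always a list of structured blocks; reasoning (if any) is
    # returned as its own "reasoning" block alongside the "text" block.
    for block in response.content:
        print(block["type"], block.get("text") or block.get("reasoning"))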
libs/partners/ollama/tests/unit_tests/test_base_v1.py (new file, 174 lines)
"""Unit tests for ChatOllamaV1."""
|
||||||
|
|
||||||
|
from langchain_core.messages.content_blocks import ImageContentBlock, TextContentBlock
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
|
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||||
|
from langchain_core.messages.v1 import MessageV1
|
||||||
|
from langchain_core.messages.v1 import SystemMessage as SystemMessageV1
|
||||||
|
|
||||||
|
from langchain_ollama._compat import (
|
||||||
|
_convert_chunk_to_v1,
|
||||||
|
_convert_from_v1_to_ollama_format,
|
||||||
|
_convert_to_v1_from_ollama_format,
|
||||||
|
)
|
||||||
|
from langchain_ollama.chat_models_v1 import ChatOllamaV1
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageConversion:
|
||||||
|
"""Test v1 message conversion utilities."""
|
||||||
|
|
||||||
|
def test_convert_human_message_v1_text_only(self) -> None:
|
||||||
|
"""Test converting HumanMessageV1 with text content."""
|
||||||
|
message = HumanMessageV1(
|
||||||
|
content=[TextContentBlock(type="text", text="Hello world")]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _convert_from_v1_to_ollama_format(message)
|
||||||
|
|
||||||
|
assert result["role"] == "user"
|
||||||
|
assert result["content"] == "Hello world"
|
||||||
|
assert result["images"] == []
|
||||||
|
|
||||||
|
def test_convert_human_message_v1_with_image(self) -> None:
|
||||||
|
"""Test converting HumanMessageV1 with text and image content."""
|
||||||
|
message = HumanMessageV1(
|
||||||
|
content=[
|
||||||
|
TextContentBlock(type="text", text="Describe this image:"),
|
||||||
|
ImageContentBlock( # type: ignore[typeddict-unknown-key]
|
||||||
|
type="image",
|
||||||
|
mime_type="image/jpeg",
|
||||||
|
data="base64imagedata",
|
||||||
|
source_type="base64",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _convert_from_v1_to_ollama_format(message)
|
||||||
|
|
||||||
|
assert result["role"] == "user"
|
||||||
|
assert result["content"] == "Describe this image:"
|
||||||
|
assert result["images"] == ["base64imagedata"]
|
||||||
|
|
||||||
|
def test_convert_ai_message_v1(self) -> None:
|
||||||
|
"""Test converting AIMessageV1 with text content."""
|
||||||
|
message = AIMessageV1(
|
||||||
|
content=[TextContentBlock(type="text", text="Hello! How can I help?")]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _convert_from_v1_to_ollama_format(message)
|
||||||
|
|
||||||
|
assert result["role"] == "assistant"
|
||||||
|
assert result["content"] == "Hello! How can I help?"
|
||||||
|
|
||||||
|
def test_convert_system_message_v1(self) -> None:
|
||||||
|
"""Test converting SystemMessageV1."""
|
||||||
|
message = SystemMessageV1(
|
||||||
|
content=[TextContentBlock(type="text", text="You are a helpful assistant.")]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _convert_from_v1_to_ollama_format(message)
|
||||||
|
|
||||||
|
assert result["role"] == "system"
|
||||||
|
assert result["content"] == "You are a helpful assistant."
|
||||||
|
|
||||||
|
def test_convert_from_ollama_format(self) -> None:
|
||||||
|
"""Test converting Ollama response to AIMessageV1."""
|
||||||
|
ollama_response = {
|
||||||
|
"model": "llama3",
|
||||||
|
"created_at": "2024-01-01T00:00:00Z",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello! How can I help you today?",
|
||||||
|
},
|
||||||
|
"done": True,
|
||||||
|
"done_reason": "stop",
|
||||||
|
"total_duration": 1000000,
|
||||||
|
"prompt_eval_count": 10,
|
||||||
|
"eval_count": 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = _convert_to_v1_from_ollama_format(ollama_response)
|
||||||
|
|
||||||
|
assert isinstance(result, AIMessageV1)
|
||||||
|
assert len(result.content) == 1
|
||||||
|
assert result.content[0]["type"] == "text"
|
||||||
|
assert result.content[0]["text"] == "Hello! How can I help you today?"
|
||||||
|
assert result.response_metadata["model_name"] == "llama3"
|
||||||
|
assert result.response_metadata.get("done") is True # type: ignore[typeddict-item]
|
||||||
|
|
||||||
|
def test_convert_chunk_to_v1(self) -> None:
|
||||||
|
"""Test converting Ollama streaming chunk to AIMessageChunkV1."""
|
||||||
|
chunk = {
|
||||||
|
"model": "llama3",
|
||||||
|
"created_at": "2024-01-01T00:00:00Z",
|
||||||
|
"message": {"role": "assistant", "content": "Hello"},
|
||||||
|
"done": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = _convert_chunk_to_v1(chunk)
|
||||||
|
|
||||||
|
assert len(result.content) == 1
|
||||||
|
assert result.content[0]["type"] == "text"
|
||||||
|
assert result.content[0]["text"] == "Hello"
|
||||||
|
|
||||||
|
def test_convert_empty_content(self) -> None:
|
||||||
|
"""Test converting empty content blocks."""
|
||||||
|
message = HumanMessageV1(content=[])
|
||||||
|
|
||||||
|
result = _convert_from_v1_to_ollama_format(message)
|
||||||
|
|
||||||
|
assert result["role"] == "user"
|
||||||
|
assert result["content"] == ""
|
||||||
|
assert result["images"] == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatOllamaV1:
|
||||||
|
"""Test ChatOllamaV1 class."""
|
||||||
|
|
||||||
|
def test_initialization(self) -> None:
|
||||||
|
"""Test ChatOllamaV1 initialization."""
|
||||||
|
llm = ChatOllamaV1(model="llama3")
|
||||||
|
|
||||||
|
assert llm.model == "llama3"
|
||||||
|
assert llm._llm_type == "chat-ollama-v1"
|
||||||
|
|
||||||
|
def test_chat_params(self) -> None:
|
||||||
|
"""Test _chat_params method."""
|
||||||
|
llm = ChatOllamaV1(model="llama3", temperature=0.7)
|
||||||
|
|
||||||
|
messages: list[MessageV1] = [
|
||||||
|
HumanMessageV1(content=[TextContentBlock(type="text", text="Hello")])
|
||||||
|
]
|
||||||
|
|
||||||
|
params = llm._chat_params(messages)
|
||||||
|
|
||||||
|
assert params["model"] == "llama3"
|
||||||
|
assert len(params["messages"]) == 1
|
||||||
|
assert params["messages"][0]["role"] == "user"
|
||||||
|
assert params["messages"][0]["content"] == "Hello"
|
||||||
|
assert params["options"].temperature == 0.7
|
||||||
|
|
||||||
|
def test_ls_params(self) -> None:
|
||||||
|
"""Test LangSmith parameters."""
|
||||||
|
llm = ChatOllamaV1(model="llama3", temperature=0.5)
|
||||||
|
|
||||||
|
ls_params = llm._get_ls_params()
|
||||||
|
|
||||||
|
assert ls_params["ls_provider"] == "ollama"
|
||||||
|
assert ls_params["ls_model_name"] == "llama3"
|
||||||
|
assert ls_params["ls_model_type"] == "chat"
|
||||||
|
assert ls_params["ls_temperature"] == 0.5
|
||||||
|
|
||||||
|
def test_bind_tools_basic(self) -> None:
|
||||||
|
"""Test basic tool binding functionality."""
|
||||||
|
llm = ChatOllamaV1(model="llama3")
|
||||||
|
|
||||||
|
def test_tool(query: str) -> str:
|
||||||
|
"""A test tool."""
|
||||||
|
return f"Result for: {query}"
|
||||||
|
|
||||||
|
bound_llm = llm.bind_tools([test_tool])
|
||||||
|
|
||||||
|
# Should return a bound model
|
||||||
|
assert bound_llm is not None
|
||||||
|
# The actual tool binding logic is handled by the parent class
|
@@ -3,6 +3,7 @@ from langchain_ollama import __all__
 EXPECTED_ALL = [
     "OllamaLLM",
     "ChatOllama",
+    "ChatOllamaV1",
     "OllamaEmbeddings",
     "__version__",
 ]