nits & namespace update for ollama

Mason Daugherty 2025-08-06 12:19:28 -04:00
parent 821527b97a
commit e18e2c13ce
No known key found for this signature in database
9 changed files with 154 additions and 101 deletions

View File

@ -16,7 +16,6 @@ service.
from importlib import metadata
from langchain_ollama.chat_models import ChatOllama
from langchain_ollama.chat_models_v1 import ChatOllama as ChatOllamaV1
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
@ -31,7 +30,6 @@ del metadata # optional, avoids polluting the results of dir(__package__)
__all__ = [
"ChatOllama",
"ChatOllamaV1",
"OllamaEmbeddings",
"OllamaLLM",
"__version__",

View File

@ -1,4 +1,4 @@
"""V1 message conversion utilities for Ollama."""
"""LangChain v1 message conversion utilities for Ollama."""
from __future__ import annotations

View File

@ -1,11 +1,11 @@
"""Utility functions for validating Ollama models."""
"""Utility function to validate Ollama models."""
from httpx import ConnectError
from ollama import Client, ResponseError
def validate_model(client: Client, model_name: str) -> None:
"""Validate that a model exists in the Ollama instance.
"""Validate that a model exists in the local Ollama instance.
Args:
client: The Ollama client.
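This helper is what ``ChatOllama`` invokes when ``validate_model_on_init=True`` (the tests later in this diff patch ``validate_model`` alongside ``Client``). A minimal sketch of calling it directly, assuming a locally running Ollama server and an illustrative model name:

.. code-block:: python

    from ollama import Client

    from langchain_ollama._utils import validate_model

    client = Client()
    # Raises (e.g. ``ValueError``) if "llama3.1" is not available on the
    # local Ollama instance; returns None otherwise.
    validate_model(client, "llama3.1")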

View File

@ -0,0 +1,5 @@
from langchain_ollama.v1.chat_models import (
ChatOllama,
)
__all__ = ["ChatOllama"]

View File

@ -0,0 +1,5 @@
from langchain_ollama.v1.chat_models.base import (
ChatOllama,
)
__all__ = ["ChatOllama"]

View File

@ -1,9 +1,9 @@
"""Ollama chat model v1 implementation.
"""v1 Ollama implementation.
This implementation provides native support for v1 messages with structured
content blocks.
Provides native support for v1 messages with standard content blocks.
.. versionadded:: 1.0.0
"""
from __future__ import annotations
@ -45,12 +45,12 @@ from pydantic.json_schema import JsonSchemaValue
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self, is_typeddict
from ._compat import (
from langchain_ollama._compat import (
_convert_chunk_to_v1,
_convert_from_v1_to_ollama_format,
_convert_to_v1_from_ollama_format,
)
from ._utils import validate_model
from langchain_ollama._utils import validate_model
log = logging.getLogger(__name__)
@ -116,7 +116,7 @@ def _parse_arguments_from_tool_call(
Band-aid fix for issue in Ollama with inconsistent tool call argument structure.
Should be removed/changed if fixed upstream.
See https://github.com/ollama/ollama/issues/6155
`See #6155 <https://github.com/ollama/ollama/issues/6155>`__.
"""
if "function" not in raw_tool_call:
@ -142,12 +142,6 @@ def _parse_arguments_from_tool_call(
return parsed_arguments
# Removed from v0:
# - _get_tool_calls_from_response
# _lc_tool_call_to_openai_tool_call
# _get_image_from_data_content_block
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and is_basemodel_subclass(obj)
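For context on the band-aid above: the helper normalizes tool-call arguments that Ollama may return in inconsistent shapes. A rough illustration, with the raw payload shapes assumed rather than taken from this diff:

.. code-block:: python

    # Hypothetical raw tool calls (argument shapes are an assumption):
    as_dict = {"function": {"name": "Multiply", "arguments": {"a": 45, "b": 67}}}
    as_string = {"function": {"name": "Multiply", "arguments": '{"a": 45, "b": 67}'}}

    # In both cases the helper is expected to hand back a plain dict,
    # {"a": 45, "b": 67}, suitable for a tool call's ``args`` field.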
@ -193,7 +187,7 @@ class ChatOllama(BaseChatModel):
Instantiate:
.. code-block:: python
from langchain_ollama import ChatOllama
from langchain_ollama.v1 import ChatOllama
llm = ChatOllama(
model = "llama3",
@ -209,15 +203,13 @@ class ChatOllama(BaseChatModel):
from langchain_core.messages.content_blocks import TextContentBlock
messages = [
HumanMessage(content=[
TextContentBlock(type="text", text="Hello!")
])
HumanMessage("Hello!")
]
llm.invoke(messages)
.. code-block:: python
AIMessage(content=[{'type': 'text', 'text': 'Hello! How can I help you today?'}], response_metadata={'model': 'llama3', 'created_at': '2024-07-04T03:37:50.182604Z', 'done_reason': 'stop', 'done': True, 'total_duration': 3576619666, 'load_duration': 788524916, 'prompt_eval_count': 32, 'prompt_eval_duration': 128125000, 'eval_count': 71, 'eval_duration': 2656556000}, id='run-ba48f958-6402-41a5-b461-5e250a4ebd36-0')
AIMessage(content=[{'type': 'text', 'text': 'Hello! How can I help you today?'}], ...)
Stream:
.. code-block:: python
@ -226,18 +218,16 @@ class ChatOllama(BaseChatModel):
from langchain_core.messages.content_blocks import TextContentBlock
messages = [
HumanMessage(content=[
TextContentBlock(type="text", text="Return the words Hello World!")
])
HumanMessage("Return the words Hello World!")
]
for chunk in llm.stream(messages):
print(chunk.content, end="")
.. code-block:: python
[{'type': 'text', 'text': 'Hello'}]
[{'type': 'text', 'text': ' World'}]
[{'type': 'text', 'text': '!'}]
AIMessageChunk(content=[{'type': 'text', 'text': 'Hello'}], ...)
AIMessageChunk(content=[{'type': 'text', 'text': ' World'}], ...)
AIMessageChunk(content=[{'type': 'text', 'text': '!'}], ...)
Multi-modal input:
.. code-block:: python
@ -249,7 +239,6 @@ class ChatOllama(BaseChatModel):
TextContentBlock(type="text", text="Describe this image:"),
ImageContentBlock(
type="image",
mime_type="image/jpeg",
base64="base64_encoded_image",
)
])
@ -258,7 +247,6 @@ class ChatOllama(BaseChatModel):
Tool Calling:
.. code-block:: python
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field
class Multiply(BaseModel):
@ -267,20 +255,22 @@ class ChatOllama(BaseChatModel):
llm_with_tools = llm.bind_tools([Multiply])
ans = llm_with_tools.invoke([
HumanMessage(content=[
TextContentBlock(type="text", text="What is 45*67")
])
HumanMessage("What is 45*67")
])
ans.tool_calls
.. code-block:: python
[{'name': 'Multiply',
'args': {'a': 45, 'b': 67},
'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
'type': 'tool_call'}]
[
{
'name': 'Multiply',
'args': {'a': 45, 'b': 67},
'id': '420c3f3b-df10-4188-945f-eb3abdb40622',
'type': 'tool_call'
}
]
""" # noqa: E501, pylint: disable=line-too-long
""" # noqa: E501
model: str
"""Model name to use."""
@ -297,6 +287,7 @@ class ChatOllama(BaseChatModel):
however, if the model's default behavior *is* to perform reasoning, think tags
(``<think>`` and ``</think>``) will be present within the main response content
unless you set ``reasoning`` to ``True``.
"""
validate_model_on_init: bool = False
@ -305,75 +296,126 @@ class ChatOllama(BaseChatModel):
# Ollama-specific parameters
mirostat: Optional[int] = None
"""Enable Mirostat sampling for controlling perplexity.
(default: ``0``, ``0`` = disabled, ``1`` = Mirostat, ``2`` = Mirostat 2.0)"""
(Default: ``0``, ``0`` = disabled, ``1`` = Mirostat, ``2`` = Mirostat 2.0)
"""
mirostat_eta: Optional[float] = None
"""Influences how quickly the algorithm responds to feedback
from the generated text. A lower learning rate will result in
slower adjustments, while a higher learning rate will make
the algorithm more responsive. (Default: ``0.1``)"""
"""Influences how quickly the algorithm responds to feedback from generated text.
A lower learning rate will result in slower adjustments, while a higher learning
rate will make the algorithm more responsive.
(Default: ``0.1``)
"""
mirostat_tau: Optional[float] = None
"""Controls the balance between coherence and diversity
of the output. A lower value will result in more focused and
coherent text. (Default: ``5.0``)"""
"""Controls the balance between coherence and diversity of the output.
A lower value will result in more focused and coherent text.
(Default: ``5.0``)
"""
num_ctx: Optional[int] = None
"""Sets the size of the context window used to generate the
next token. (Default: ``2048``) """
"""Sets the size of the context window used to generate the next token.
(Default: ``2048``)
"""
num_gpu: Optional[int] = None
"""The number of GPUs to use. On macOS it defaults to ``1`` to
enable metal support, ``0`` to disable."""
"""The number of GPUs to use.
On macOS it defaults to ``1`` to enable Metal support, ``0`` to disable.
"""
num_thread: Optional[int] = None
"""Sets the number of threads to use during computation.
By default, Ollama will detect this for optimal performance.
It is recommended to set this value to the number of physical
CPU cores your system has (as opposed to the logical number of cores)."""
By default, Ollama will detect this for optimal performance. It is recommended to
set this value to the number of physical CPU cores your system has (as opposed to
the logical number of cores).
"""
num_predict: Optional[int] = None
"""Maximum number of tokens to predict when generating text.
(Default: ``128``, ``-1`` = infinite generation, ``-2`` = fill context)"""
(Default: ``128``, ``-1`` = infinite generation, ``-2`` = fill context)
"""
repeat_last_n: Optional[int] = None
"""Sets how far back for the model to look back to prevent
repetition. (Default: ``64``, ``0`` = disabled, ``-1`` = ``num_ctx``)"""
"""Sets how far back for the model to look back to prevent repetition.
(Default: ``64``, ``0`` = disabled, ``-1`` = ``num_ctx``)
"""
repeat_penalty: Optional[float] = None
"""Sets how strongly to penalize repetitions. A higher value (e.g., ``1.5``)
will penalize repetitions more strongly, while a lower value (e.g., ``0.9``)
will be more lenient. (Default: ``1.1``)"""
"""Sets how strongly to penalize repetitions.
A higher value (e.g., ``1.5``) will penalize repetitions more strongly, while a
lower value (e.g., ``0.9``) will be more lenient.
(Default: ``1.1``)
"""
temperature: Optional[float] = None
"""The temperature of the model. Increasing the temperature will
make the model answer more creatively. (Default: ``0.8``)"""
"""The temperature of the model.
Increasing the temperature will make the model answer more creatively.
(Default: ``0.8``)
"""
seed: Optional[int] = None
"""Sets the random number seed to use for generation. Setting this
to a specific number will make the model generate the same text for
the same prompt."""
"""Sets the random number seed to use for generation.
Setting this to a specific number will make the model generate the same text for the
same prompt.
"""
stop: Optional[list[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float] = None
"""Tail free sampling is used to reduce the impact of less probable
tokens from the output. A higher value (e.g., ``2.0``) will reduce the
impact more, while a value of ``1.0`` disables this setting. (default: ``1``)"""
"""Tail free sampling is used to reduce the impact of less probable tokens from the output.
A higher value (e.g., ``2.0``) will reduce the impact more, while a value of ``1.0`` disables this setting.
(Default: ``1``)
""" # noqa: E501
top_k: Optional[int] = None
"""Reduces the probability of generating nonsense. A higher value (e.g. ``100``)
will give more diverse answers, while a lower value (e.g. ``10``)
will be more conservative. (Default: ``40``)"""
"""Reduces the probability of generating nonsense.
A higher value (e.g. ``100``) will give more diverse answers, while a lower value
(e.g. ``10``) will be more conservative.
(Default: ``40``)
"""
top_p: Optional[float] = None
"""Works together with top-k. A higher value (e.g., ``0.95``) will lead
to more diverse text, while a lower value (e.g., ``0.5``) will
generate more focused and conservative text. (Default: ``0.9``)"""
"""Works together with top-k.
A higher value (e.g., ``0.95``) will lead to more diverse text, while a lower value
(e.g., ``0.5``) will generate more focused and conservative text.
(Default: ``0.9``)
"""
format: Optional[Union[Literal["", "json"], JsonSchemaValue]] = None
"""Specify the format of the output (options: ``'json'``, JSON schema)."""
"""Specify the format of the output (Options: ``'json'``, JSON schema)."""
keep_alive: Optional[Union[int, str]] = None
"""How long the model will stay loaded into memory."""
@ -552,7 +594,7 @@ class ChatOllama(BaseChatModel):
and not part.get("message", {}).get("content", "").strip()
):
log.warning(
"Ollama returned empty response with done_reason='load'. "
"Ollama returned empty response with `done_reason='load'`. "
"Skipping this response."
)
continue
@ -574,7 +616,6 @@ class ChatOllama(BaseChatModel):
# Non-streaming case
response = self._client.chat(**chat_params)
ai_message = _convert_to_v1_from_ollama_format(response)
# Convert to chunk for yielding
chunk = AIMessageChunk(
content=ai_message.content,
response_metadata=ai_message.response_metadata,
@ -602,7 +643,7 @@ class ChatOllama(BaseChatModel):
and not part.get("message", {}).get("content", "").strip()
):
log.warning(
"Ollama returned empty response with done_reason='load'. "
"Ollama returned empty response with `done_reason='load'`. "
"Skipping this response."
)
continue
@ -624,7 +665,6 @@ class ChatOllama(BaseChatModel):
# Non-streaming case
response = await self._async_client.chat(**chat_params)
ai_message = _convert_to_v1_from_ollama_format(response)
# Convert to chunk for yielding
chunk = AIMessageChunk(
content=ai_message.content,
response_metadata=ai_message.response_metadata,
@ -649,6 +689,7 @@ class ChatOllama(BaseChatModel):
Returns:
Complete AI message response.
"""
stream_iter = self._generate_stream(
messages, stop=stop, run_manager=run_manager, **kwargs
@ -672,6 +713,7 @@ class ChatOllama(BaseChatModel):
Returns:
Complete AI message response.
"""
stream_iter = self._agenerate_stream(
messages, stop=stop, run_manager=run_manager, **kwargs
@ -695,6 +737,7 @@ class ChatOllama(BaseChatModel):
Yields:
AI message chunks in v1 format.
"""
yield from self._generate_stream(
messages, stop=stop, run_manager=run_manager, **kwargs
@ -717,6 +760,7 @@ class ChatOllama(BaseChatModel):
Yields:
AI message chunks in v1 format.
"""
async for chunk in self._agenerate_stream(
messages, stop=stop, run_manager=run_manager, **kwargs
@ -735,7 +779,8 @@ class ChatOllama(BaseChatModel):
Args:
tools: A list of tool definitions to bind to this chat model.
tool_choice: Tool choice parameter (currently ignored by Ollama).
kwargs: Additional parameters passed to bind().
kwargs: Additional parameters passed to ``bind()``.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)

View File

@ -12,7 +12,7 @@ import pytest
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from langchain_ollama.chat_models_v1 import ChatOllama
from langchain_ollama.v1.chat_models import ChatOllama
DEFAULT_MODEL_NAME = "llama3.1"

View File

@ -12,7 +12,7 @@ from langchain_tests.integration_tests.chat_models_v1 import ChatModelV1Integrat
from ollama import ResponseError
from pydantic import ValidationError
from langchain_ollama.chat_models_v1 import ChatOllama
from langchain_ollama.v1.chat_models import ChatOllama
DEFAULT_MODEL_NAME = "llama3.1"
@ -251,7 +251,7 @@ class TestChatOllamaV1(ChatModelV1IntegrationTests):
f"Content blocks: {[block.get('type') for block in result.content]}"
)
@patch("langchain_ollama.chat_models_v1.Client.list")
@patch("langchain_ollama.v1.chat_models.Client.list")
def test_init_model_not_found(self, mock_list: MagicMock) -> None:
"""Test that a ValueError is raised when the model is not found."""
mock_list.side_effect = ValueError("Test model not found")
@ -259,7 +259,7 @@ class TestChatOllamaV1(ChatModelV1IntegrationTests):
ChatOllama(model="non-existent-model", validate_model_on_init=True)
assert "Test model not found" in str(excinfo.value)
@patch("langchain_ollama.chat_models_v1.Client.list")
@patch("langchain_ollama.v1.chat_models.Client.list")
def test_init_connection_error(self, mock_list: MagicMock) -> None:
"""Test that a ValidationError is raised on connect failure during init."""
mock_list.side_effect = ConnectError("Test connection error")
@ -268,7 +268,7 @@ class TestChatOllamaV1(ChatModelV1IntegrationTests):
ChatOllama(model="any-model", validate_model_on_init=True)
assert "Failed to connect to Ollama" in str(excinfo.value)
@patch("langchain_ollama.chat_models_v1.Client.list")
@patch("langchain_ollama.v1.chat_models.Client.list")
def test_init_response_error(self, mock_list: MagicMock) -> None:
"""Test that a ResponseError is raised."""
mock_list.side_effect = ResponseError("Test response error")

View File

@ -20,7 +20,7 @@ from langchain_ollama._compat import (
_convert_from_v1_to_ollama_format,
_convert_to_v1_from_ollama_format,
)
from langchain_ollama.chat_models_v1 import (
from langchain_ollama.v1.chat_models import (
ChatOllama,
_parse_arguments_from_tool_call,
_parse_json_string,
@ -246,8 +246,8 @@ class TestChatOllama(ChatModelV1UnitTests):
@pytest.fixture
def model(self) -> Generator[ChatOllama, None, None]: # type: ignore[override]
"""Create a ChatOllama instance for testing."""
sync_patcher = patch("langchain_ollama.chat_models_v1.Client")
async_patcher = patch("langchain_ollama.chat_models_v1.AsyncClient")
sync_patcher = patch("langchain_ollama.v1.chat_models.Client")
async_patcher = patch("langchain_ollama.v1.chat_models.AsyncClient")
mock_sync_client_class = sync_patcher.start()
mock_async_client_class = async_patcher.start()
@ -328,8 +328,8 @@ class TestChatOllama(ChatModelV1UnitTests):
def test_initialization(self) -> None:
"""Test `ChatOllama` initialization."""
with (
patch("langchain_ollama.chat_models_v1.Client"),
patch("langchain_ollama.chat_models_v1.AsyncClient"),
patch("langchain_ollama.v1.chat_models.Client"),
patch("langchain_ollama.v1.chat_models.AsyncClient"),
):
llm = ChatOllama(model=MODEL_NAME)
@ -339,8 +339,8 @@ class TestChatOllama(ChatModelV1UnitTests):
def test_chat_params(self) -> None:
"""Test `_chat_params()`."""
with (
patch("langchain_ollama.chat_models_v1.Client"),
patch("langchain_ollama.chat_models_v1.AsyncClient"),
patch("langchain_ollama.v1.chat_models.Client"),
patch("langchain_ollama.v1.chat_models.AsyncClient"),
):
llm = ChatOllama(model=MODEL_NAME, temperature=0.7)
@ -359,8 +359,8 @@ class TestChatOllama(ChatModelV1UnitTests):
def test_ls_params(self) -> None:
"""Test LangSmith parameters."""
with (
patch("langchain_ollama.chat_models_v1.Client"),
patch("langchain_ollama.chat_models_v1.AsyncClient"),
patch("langchain_ollama.v1.chat_models.Client"),
patch("langchain_ollama.v1.chat_models.AsyncClient"),
):
llm = ChatOllama(model=MODEL_NAME, temperature=0.5)
@ -374,8 +374,8 @@ class TestChatOllama(ChatModelV1UnitTests):
def test_bind_tools_basic(self) -> None:
"""Test basic tool binding functionality."""
with (
patch("langchain_ollama.chat_models_v1.Client"),
patch("langchain_ollama.chat_models_v1.AsyncClient"),
patch("langchain_ollama.v1.chat_models.Client"),
patch("langchain_ollama.v1.chat_models.AsyncClient"),
):
llm = ChatOllama(model=MODEL_NAME)
@ -394,8 +394,8 @@ class TestChatOllama(ChatModelV1UnitTests):
# But can be added if needed in the future.
@patch("langchain_ollama.chat_models_v1.validate_model")
@patch("langchain_ollama.chat_models_v1.Client")
@patch("langchain_ollama.v1.chat_models.validate_model")
@patch("langchain_ollama.v1.chat_models.Client")
def test_validate_model_on_init(
mock_client_class: Any, mock_validate_model: Any
) -> None:
@ -501,7 +501,7 @@ def test_load_response_with_empty_content_is_skipped(
}
]
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
with patch("langchain_ollama.v1.chat_models.Client") as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.chat.return_value = iter(load_only_response)
@ -531,7 +531,7 @@ def test_load_response_with_whitespace_content_is_skipped(
}
]
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
with patch("langchain_ollama.v1.chat_models.Client") as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.chat.return_value = iter(load_whitespace_response)
@ -570,7 +570,7 @@ def test_load_followed_by_content_response(
},
]
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
with patch("langchain_ollama.v1.chat_models.Client") as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.chat.return_value = iter(load_then_content_response)
@ -600,7 +600,7 @@ def test_load_response_with_actual_content_is_not_skipped(
}
]
with patch("langchain_ollama.chat_models_v1.Client") as mock_client_class:
with patch("langchain_ollama.v1.chat_models.Client") as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.chat.return_value = iter(load_with_content_response)