diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py index c95bba60875..ffe8766b802 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models.py +++ b/libs/partners/ollama/langchain_ollama/chat_models.py @@ -44,6 +44,7 @@ from __future__ import annotations import ast import json import logging +import warnings from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence from operator import itemgetter from typing import Any, Literal, cast @@ -83,7 +84,7 @@ from langchain_core.utils.function_calling import ( ) from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass from ollama import AsyncClient, Client, Message -from pydantic import BaseModel, PrivateAttr, model_validator +from pydantic import BaseModel, PrivateAttr, field_validator, model_validator from pydantic.json_schema import JsonSchemaValue from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self, is_typeddict @@ -626,6 +627,31 @@ class ChatOllama(BaseChatModel): same prompt. """ + logprobs: bool | None = None + """Whether to return logprobs. + + !!! note + + When streaming, per-token logprobs are available on each intermediate + chunk (via `response_metadata["logprobs"]`) and are accumulated into the + final aggregated response when using `invoke()`. + """ + + top_logprobs: int | None = None + """Number of most likely tokens to return at each token position, each with + an associated log probability. Must be a positive integer. + + If set without `logprobs=True`, `logprobs` will be enabled automatically. + """ + + @field_validator("top_logprobs") + @classmethod + def _validate_top_logprobs(cls, v: int | None) -> int | None: + if v is not None and v < 1: + msg = "`top_logprobs` must be a positive integer." + raise ValueError(msg) + return v + stop: list[str] | None = None """Sets the stop tokens to use.""" @@ -772,6 +798,8 @@ class ChatOllama(BaseChatModel): "model": kwargs.pop("model", self.model), "think": kwargs.pop("reasoning", self.reasoning), "format": kwargs.pop("format", self.format), + "logprobs": kwargs.pop("logprobs", self.logprobs), + "top_logprobs": kwargs.pop("top_logprobs", self.top_logprobs), "options": options_dict, "keep_alive": kwargs.pop("keep_alive", self.keep_alive), **kwargs, @@ -790,6 +818,23 @@ class ChatOllama(BaseChatModel): @model_validator(mode="after") def _set_clients(self) -> Self: """Set clients to use for ollama.""" + if self.top_logprobs is not None and self.logprobs is not True: + if self.logprobs is False: + msg = ( + "`top_logprobs` is set but `logprobs` is explicitly `False`. " + "Either set `logprobs=True` to use `top_logprobs`, or remove " + "`top_logprobs`." + ) + raise ValueError(msg) + # logprobs is None (unset) — auto-enable as convenience + self.logprobs = True + warnings.warn( + "`top_logprobs` is set but `logprobs` was not explicitly enabled. " + "Setting `logprobs=True` automatically.", + UserWarning, + stacklevel=2, + ) + client_kwargs = self.client_kwargs or {} cleaned_url, auth_headers = parse_url_with_auth(self.base_url) @@ -1096,7 +1141,12 @@ class ChatOllama(BaseChatModel): generation_info["model_provider"] = "ollama" _ = generation_info.pop("message", None) else: - generation_info = None + chunk_logprobs = stream_resp.get("logprobs") + generation_info = ( + {"logprobs": chunk_logprobs} + if chunk_logprobs is not None + else None + ) additional_kwargs = {} if ( @@ -1173,7 +1223,12 @@ class ChatOllama(BaseChatModel): generation_info["model_provider"] = "ollama" _ = generation_info.pop("message", None) else: - generation_info = None + chunk_logprobs = stream_resp.get("logprobs") + generation_info = ( + {"logprobs": chunk_logprobs} + if chunk_logprobs is not None + else None + ) additional_kwargs = {} if ( diff --git a/libs/partners/ollama/pyproject.toml b/libs/partners/ollama/pyproject.toml index 6af8129d5ef..5bcae24d4dd 100644 --- a/libs/partners/ollama/pyproject.toml +++ b/libs/partners/ollama/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ version = "1.0.1" requires-python = ">=3.10.0,<4.0.0" dependencies = [ - "ollama>=0.6.0,<1.0.0", + "ollama>=0.6.1,<1.0.0", "langchain-core>=1.2.21,<2.0.0", ] @@ -114,6 +114,7 @@ asyncio_mode = "auto" [tool.ruff.lint.extend-per-file-ignores] "tests/**/*.py" = [ "S101", # Tests need assertions + "S105", # False positive on dict key "token" in logprobs assertions "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes "ARG001", # Unused function arguments in tests (e.g. kwargs) "PLR2004", # Magic value in comparisons diff --git a/libs/partners/ollama/tests/unit_tests/test_chat_models.py b/libs/partners/ollama/tests/unit_tests/test_chat_models.py index f6f97d3284f..876df51927c 100644 --- a/libs/partners/ollama/tests/unit_tests/test_chat_models.py +++ b/libs/partners/ollama/tests/unit_tests/test_chat_models.py @@ -2,13 +2,11 @@ import json import logging -from collections.abc import Generator -from contextlib import contextmanager +import warnings from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from httpx import Client, Request, Response from langchain_core.exceptions import OutputParserException from langchain_core.messages import ChatMessage, HumanMessage from langchain_tests.unit_tests import ChatModelUnitTests @@ -22,17 +20,6 @@ from langchain_ollama.chat_models import ( MODEL_NAME = "llama3.1" -@contextmanager -def _mock_httpx_client_stream( - *_args: Any, **_kwargs: Any -) -> Generator[Response, Any, Any]: - yield Response( - status_code=200, - content='{"message": {"role": "assistant", "content": "The meaning ..."}}', - request=Request(method="POST", url="http://whocares:11434"), - ) - - dummy_raw_tool_call = { "function": {"name": "test_func", "arguments": ""}, } @@ -105,25 +92,37 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None: assert response_different == {"functionName": "function_b"} -def test_arbitrary_roles_accepted_in_chatmessages( - monkeypatch: pytest.MonkeyPatch, -) -> None: +def test_arbitrary_roles_accepted_in_chatmessages() -> None: """Test that `ChatOllama` accepts arbitrary roles in `ChatMessage`.""" - monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream) - llm = ChatOllama( - model=MODEL_NAME, - verbose=True, - format=None, - ) - messages = [ - ChatMessage( - role="somerandomrole", - content="I'm ok with you adding any role message now!", - ), - ChatMessage(role="control", content="thinking"), - ChatMessage(role="user", content="What is the meaning of life?"), + response = [ + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "done": True, + "done_reason": "stop", + "message": {"role": "assistant", "content": "The meaning of life..."}, + } ] - llm.invoke(messages) + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = response + + llm = ChatOllama( + model=MODEL_NAME, + verbose=True, + format=None, + ) + messages = [ + ChatMessage( + role="somerandomrole", + content="I'm ok with you adding any role message now!", + ), + ChatMessage(role="control", content="thinking"), + ChatMessage(role="user", content="What is the meaning of life?"), + ] + llm.invoke(messages) @patch("langchain_ollama.chat_models.validate_model") @@ -449,6 +448,308 @@ def test_reasoning_param_passed_to_client() -> None: assert call_kwargs["think"] is True +def test_logprobs_params_passed_to_client() -> None: + """Test that logprobs parameters are correctly passed to the Ollama client.""" + response = [ + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": "Hello!"}, + "done": True, + "done_reason": "stop", + } + ] + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = response + + # Case 1: logprobs=True, top_logprobs=5 in init + llm = ChatOllama(model=MODEL_NAME, logprobs=True, top_logprobs=5) + llm.invoke([HumanMessage("Hello")]) + + call_kwargs = mock_client.chat.call_args[1] + assert call_kwargs["logprobs"] is True + assert call_kwargs["top_logprobs"] == 5 + + # Case 2: override via invoke kwargs + llm = ChatOllama(model=MODEL_NAME) + llm.invoke([HumanMessage("Hello")], logprobs=True, top_logprobs=3) + + call_kwargs = mock_client.chat.call_args[1] + assert call_kwargs["logprobs"] is True + assert call_kwargs["top_logprobs"] == 3 + + # Case 3: auto-enabled logprobs propagates to client + llm = ChatOllama(model=MODEL_NAME, top_logprobs=3) + llm.invoke([HumanMessage("Hello")]) + + call_kwargs = mock_client.chat.call_args[1] + assert call_kwargs["logprobs"] is True + assert call_kwargs["top_logprobs"] == 3 + + # Case 4: defaults are None when not set + llm = ChatOllama(model=MODEL_NAME) + llm.invoke([HumanMessage("Hello")]) + + call_kwargs = mock_client.chat.call_args[1] + assert call_kwargs["logprobs"] is None + assert call_kwargs["top_logprobs"] is None + + +def test_top_logprobs_validation() -> None: + """Test that top_logprobs must be a positive integer.""" + with patch("langchain_ollama.chat_models.Client"): + with pytest.raises(ValueError, match="`top_logprobs` must be a positive"): + ChatOllama(model=MODEL_NAME, top_logprobs=0) + + with pytest.raises(ValueError, match="`top_logprobs` must be a positive"): + ChatOllama(model=MODEL_NAME, top_logprobs=-1) + + # Valid values should not raise + llm = ChatOllama(model=MODEL_NAME, logprobs=True, top_logprobs=1) + assert llm.top_logprobs == 1 + + +def test_top_logprobs_without_logprobs_auto_enables() -> None: + """Test that setting top_logprobs without logprobs auto-enables logprobs.""" + with patch("langchain_ollama.chat_models.Client"): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + llm = ChatOllama(model=MODEL_NAME, top_logprobs=5) + assert llm.logprobs is True + assert len(w) == 1 + assert "Setting `logprobs=True` automatically" in str(w[0].message) + + # No warning when logprobs=True explicitly + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ChatOllama(model=MODEL_NAME, logprobs=True, top_logprobs=5) + logprobs_warnings = [x for x in w if "top_logprobs" in str(x.message)] + assert len(logprobs_warnings) == 0 + + +def test_top_logprobs_with_logprobs_false_raises() -> None: + """Setting top_logprobs with logprobs=False is a contradictory config.""" + with ( + patch("langchain_ollama.chat_models.Client"), + pytest.raises(ValueError, match=r"logprobs.*explicitly.*False"), + ): + ChatOllama(model=MODEL_NAME, logprobs=False, top_logprobs=5) + + +def test_logprobs_accumulated_from_stream_into_response_metadata() -> None: + """Logprobs from intermediate streaming chunks are accumulated into the + final response_metadata when using invoke().""" + stream_responses = [ + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": "The"}, + "done": False, + "logprobs": [ + {"token": "The", "logprob": -0.5, "bytes": [84, 104, 101]}, + ], + }, + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": " sky"}, + "done": False, + "logprobs": [ + {"token": " sky", "logprob": -0.1, "bytes": [32, 115, 107, 121]}, + ], + }, + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": ""}, + "done": True, + "done_reason": "stop", + }, + ] + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = iter(stream_responses) + + llm = ChatOllama(model=MODEL_NAME, logprobs=True) + result = llm.invoke([HumanMessage("What color is the sky?")]) + + logprobs = result.response_metadata["logprobs"] + assert len(logprobs) == 2 + assert logprobs[0]["token"] == "The" + assert logprobs[0]["logprob"] == -0.5 + assert logprobs[1]["token"] == " sky" + assert logprobs[1]["logprob"] == -0.1 + + +def test_logprobs_on_individual_streaming_chunks() -> None: + """Each streaming chunk should carry its own per-token logprobs in + response_metadata when logprobs are enabled.""" + stream_responses = [ + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": "Hi"}, + "done": False, + "logprobs": [ + {"token": "Hi", "logprob": -0.3, "bytes": [72, 105]}, + ], + }, + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": "!"}, + "done": False, + "logprobs": [ + {"token": "!", "logprob": -0.01, "bytes": [33]}, + ], + }, + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": ""}, + "done": True, + "done_reason": "stop", + }, + ] + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = iter(stream_responses) + + llm = ChatOllama(model=MODEL_NAME, logprobs=True) + chunks = list(llm.stream([HumanMessage("Hello")])) + + assert chunks[0].response_metadata["logprobs"] == [ + {"token": "Hi", "logprob": -0.3, "bytes": [72, 105]}, + ] + + assert chunks[1].response_metadata["logprobs"] == [ + {"token": "!", "logprob": -0.01, "bytes": [33]}, + ] + + assert "logprobs" not in chunks[2].response_metadata + + +async def test_logprobs_on_individual_async_streaming_chunks() -> None: + """Async streaming chunks should carry per-token logprobs in + response_metadata when logprobs are enabled.""" + stream_responses = [ + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": "Hi"}, + "done": False, + "logprobs": [ + {"token": "Hi", "logprob": -0.3, "bytes": [72, 105]}, + ], + }, + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": "!"}, + "done": False, + "logprobs": [ + {"token": "!", "logprob": -0.01, "bytes": [33]}, + ], + }, + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": ""}, + "done": True, + "done_reason": "stop", + }, + ] + + async def async_stream_responses() -> Any: + for resp in stream_responses: + yield resp + + with patch("langchain_ollama.chat_models.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = async_stream_responses() + + llm = ChatOllama(model=MODEL_NAME, logprobs=True) + chunks = [chunk async for chunk in llm.astream([HumanMessage("Hello")])] + + assert chunks[0].response_metadata["logprobs"] == [ + {"token": "Hi", "logprob": -0.3, "bytes": [72, 105]}, + ] + + assert chunks[1].response_metadata["logprobs"] == [ + {"token": "!", "logprob": -0.01, "bytes": [33]}, + ] + + assert "logprobs" not in chunks[2].response_metadata + + +def test_logprobs_empty_list_preserved() -> None: + """An empty logprobs list `[]` should be preserved, not treated as absent.""" + stream_responses = [ + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": "Hi"}, + "done": False, + "logprobs": [], + }, + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": ""}, + "done": True, + "done_reason": "stop", + }, + ] + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = iter(stream_responses) + + llm = ChatOllama(model=MODEL_NAME, logprobs=True) + chunks = list(llm.stream([HumanMessage("Hello")])) + + assert chunks[0].response_metadata["logprobs"] == [] + + +def test_logprobs_none_when_not_requested() -> None: + """When logprobs are not requested, response_metadata should not contain + logprobs (or it should be None).""" + stream_responses = [ + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": "Hello!"}, + "done": False, + }, + { + "model": MODEL_NAME, + "created_at": "2025-01-01T00:00:00.000000000Z", + "message": {"role": "assistant", "content": ""}, + "done": True, + "done_reason": "stop", + }, + ] + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = iter(stream_responses) + + llm = ChatOllama(model=MODEL_NAME) + result = llm.invoke([HumanMessage("Hello")]) + + assert result.response_metadata.get("logprobs") is None + + def test_create_chat_stream_raises_when_client_none() -> None: """Test that _create_chat_stream raises RuntimeError when client is None.""" with patch("langchain_ollama.chat_models.Client") as mock_client_class: diff --git a/libs/partners/ollama/uv.lock b/libs/partners/ollama/uv.lock index f95220db7ec..961b825afb2 100644 --- a/libs/partners/ollama/uv.lock +++ b/libs/partners/ollama/uv.lock @@ -382,7 +382,7 @@ typing = [ [package.metadata] requires-dist = [ { name = "langchain-core", editable = "../../core" }, - { name = "ollama", specifier = ">=0.6.0,<1.0.0" }, + { name = "ollama", specifier = ">=0.6.1,<1.0.0" }, ] [package.metadata.requires-dev] @@ -643,15 +643,15 @@ wheels = [ [[package]] name = "ollama" -version = "0.6.0" +version = "0.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d6/47/f9ee32467fe92744474a8c72e138113f3b529fc266eea76abfdec9a33f3b/ollama-0.6.0.tar.gz", hash = "sha256:da2b2d846b5944cfbcee1ca1e6ee0585f6c9d45a2fe9467cbcd096a37383da2f", size = 50811, upload-time = "2025-09-24T22:46:02.417Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/5a/652dac4b7affc2b37b95386f8ae78f22808af09d720689e3d7a86b6ed98e/ollama-0.6.1.tar.gz", hash = "sha256:478c67546836430034b415ed64fa890fd3d1ff91781a9d548b3325274e69d7c6", size = 51620, upload-time = "2025-11-13T23:02:17.416Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b5/c1/edc9f41b425ca40b26b7c104c5f6841a4537bb2552bfa6ca66e81405bb95/ollama-0.6.0-py3-none-any.whl", hash = "sha256:534511b3ccea2dff419ae06c3b58d7f217c55be7897c8ce5868dfb6b219cf7a0", size = 14130, upload-time = "2025-09-24T22:46:01.19Z" }, + { url = "https://files.pythonhosted.org/packages/47/4f/4a617ee93d8208d2bcf26b2d8b9402ceaed03e3853c754940e2290fed063/ollama-0.6.1-py3-none-any.whl", hash = "sha256:fc4c984b345735c5486faeee67d8a265214a31cbb828167782dc642ce0a2bf8c", size = 14354, upload-time = "2025-11-13T23:02:16.292Z" }, ] [[package]]