feat(ollama): logprobs support in Ollama (#34218)

Closes #34207 

---

Expose log probabilities from the Ollama Python SDK through
`ChatOllama`. The ollama client already returns a `logprobs` field on
chat responses for supported models, but `ChatOllama` had no way to
request or surface it.

## Changes
- Add `logprobs` and `top_logprobs` fields to `ChatOllama`, forwarded to
the client via `_build_chat_params`. Setting `top_logprobs` without
`logprobs=True` auto-enables `logprobs` with a warning; setting
`top_logprobs` alongside `logprobs=False` raises a `ValueError`
- Surface per-token logprobs on intermediate streaming chunks (both sync
`_create_chat_stream` and async `_create_async_chat_stream`) via
`response_metadata["logprobs"]`, accumulated into the final response on
`invoke()`
- Bump minimum `ollama` SDK from `>=0.6.0` to `>=0.6.1` — the version
that added logprobs support

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Mohammad Mohtashim
2026-04-07 02:06:51 +05:00
committed by GitHub
parent 642c981d70
commit 0aa482d0cd
4 changed files with 397 additions and 40 deletions

View File

@@ -44,6 +44,7 @@ from __future__ import annotations
import ast
import json
import logging
import warnings
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import Any, Literal, cast
@@ -83,7 +84,7 @@ from langchain_core.utils.function_calling import (
)
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from ollama import AsyncClient, Client, Message
from pydantic import BaseModel, PrivateAttr, model_validator
from pydantic import BaseModel, PrivateAttr, field_validator, model_validator
from pydantic.json_schema import JsonSchemaValue
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self, is_typeddict
@@ -626,6 +627,31 @@ class ChatOllama(BaseChatModel):
same prompt.
"""
logprobs: bool | None = None
"""Whether to return logprobs.
!!! note
When streaming, per-token logprobs are available on each intermediate
chunk (via `response_metadata["logprobs"]`) and are accumulated into the
final aggregated response when using `invoke()`.
"""
top_logprobs: int | None = None
"""Number of most likely tokens to return at each token position, each with
an associated log probability. Must be a positive integer.
If set without `logprobs=True`, `logprobs` will be enabled automatically.
"""
@field_validator("top_logprobs")
@classmethod
def _validate_top_logprobs(cls, v: int | None) -> int | None:
    """Reject non-positive values for ``top_logprobs``; pass through ``None``."""
    if v is None:
        return v
    if v < 1:
        msg = "`top_logprobs` must be a positive integer."
        raise ValueError(msg)
    return v
stop: list[str] | None = None
"""Sets the stop tokens to use."""
@@ -772,6 +798,8 @@ class ChatOllama(BaseChatModel):
"model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format),
"logprobs": kwargs.pop("logprobs", self.logprobs),
"top_logprobs": kwargs.pop("top_logprobs", self.top_logprobs),
"options": options_dict,
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
**kwargs,
@@ -790,6 +818,23 @@ class ChatOllama(BaseChatModel):
@model_validator(mode="after")
def _set_clients(self) -> Self:
"""Set clients to use for ollama."""
if self.top_logprobs is not None and self.logprobs is not True:
if self.logprobs is False:
msg = (
"`top_logprobs` is set but `logprobs` is explicitly `False`. "
"Either set `logprobs=True` to use `top_logprobs`, or remove "
"`top_logprobs`."
)
raise ValueError(msg)
# logprobs is None (unset) — auto-enable as convenience
self.logprobs = True
warnings.warn(
"`top_logprobs` is set but `logprobs` was not explicitly enabled. "
"Setting `logprobs=True` automatically.",
UserWarning,
stacklevel=2,
)
client_kwargs = self.client_kwargs or {}
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
@@ -1096,7 +1141,12 @@ class ChatOllama(BaseChatModel):
generation_info["model_provider"] = "ollama"
_ = generation_info.pop("message", None)
else:
generation_info = None
chunk_logprobs = stream_resp.get("logprobs")
generation_info = (
{"logprobs": chunk_logprobs}
if chunk_logprobs is not None
else None
)
additional_kwargs = {}
if (
@@ -1173,7 +1223,12 @@ class ChatOllama(BaseChatModel):
generation_info["model_provider"] = "ollama"
_ = generation_info.pop("message", None)
else:
generation_info = None
chunk_logprobs = stream_resp.get("logprobs")
generation_info = (
{"logprobs": chunk_logprobs}
if chunk_logprobs is not None
else None
)
additional_kwargs = {}
if (

View File

@@ -23,7 +23,7 @@ classifiers = [
version = "1.0.1"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
"ollama>=0.6.0,<1.0.0",
"ollama>=0.6.1,<1.0.0",
"langchain-core>=1.2.21,<2.0.0",
]
@@ -114,6 +114,7 @@ asyncio_mode = "auto"
[tool.ruff.lint.extend-per-file-ignores]
"tests/**/*.py" = [
"S101", # Tests need assertions
"S105", # False positive on dict key "token" in logprobs assertions
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"ARG001", # Unused function arguments in tests (e.g. kwargs)
"PLR2004", # Magic value in comparisons

View File

@@ -2,13 +2,11 @@
import json
import logging
from collections.abc import Generator
from contextlib import contextmanager
import warnings
from typing import Any
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import Client, Request, Response
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import ChatMessage, HumanMessage
from langchain_tests.unit_tests import ChatModelUnitTests
@@ -22,17 +20,6 @@ from langchain_ollama.chat_models import (
MODEL_NAME = "llama3.1"
@contextmanager
def _mock_httpx_client_stream(
*_args: Any, **_kwargs: Any
) -> Generator[Response, Any, Any]:
yield Response(
status_code=200,
content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
request=Request(method="POST", url="http://whocares:11434"),
)
dummy_raw_tool_call = {
"function": {"name": "test_func", "arguments": ""},
}
@@ -105,25 +92,37 @@ def test__parse_arguments_from_tool_call_with_function_name_metadata() -> None:
assert response_different == {"functionName": "function_b"}
def test_arbitrary_roles_accepted_in_chatmessages(
monkeypatch: pytest.MonkeyPatch,
) -> None:
def test_arbitrary_roles_accepted_in_chatmessages() -> None:
"""Test that `ChatOllama` accepts arbitrary roles in `ChatMessage`."""
monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream)
llm = ChatOllama(
model=MODEL_NAME,
verbose=True,
format=None,
)
messages = [
ChatMessage(
role="somerandomrole",
content="I'm ok with you adding any role message now!",
),
ChatMessage(role="control", content="thinking"),
ChatMessage(role="user", content="What is the meaning of life?"),
response = [
{
"model": MODEL_NAME,
"created_at": "2025-01-01T00:00:00.000000000Z",
"done": True,
"done_reason": "stop",
"message": {"role": "assistant", "content": "The meaning of life..."},
}
]
llm.invoke(messages)
with patch("langchain_ollama.chat_models.Client") as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.chat.return_value = response
llm = ChatOllama(
model=MODEL_NAME,
verbose=True,
format=None,
)
messages = [
ChatMessage(
role="somerandomrole",
content="I'm ok with you adding any role message now!",
),
ChatMessage(role="control", content="thinking"),
ChatMessage(role="user", content="What is the meaning of life?"),
]
llm.invoke(messages)
@patch("langchain_ollama.chat_models.validate_model")
@@ -449,6 +448,308 @@ def test_reasoning_param_passed_to_client() -> None:
assert call_kwargs["think"] is True
def test_logprobs_params_passed_to_client() -> None:
    """Test that logprobs parameters are correctly passed to the Ollama client."""
    # Minimal non-streaming chat response the mocked client returns for every call.
    response = [
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": "Hello!"},
            "done": True,
            "done_reason": "stop",
        }
    ]
    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.chat.return_value = response

        # Case 1: logprobs=True, top_logprobs=5 in init
        llm = ChatOllama(model=MODEL_NAME, logprobs=True, top_logprobs=5)
        llm.invoke([HumanMessage("Hello")])
        # call_args[1] is the kwargs dict of the most recent client.chat() call.
        call_kwargs = mock_client.chat.call_args[1]
        assert call_kwargs["logprobs"] is True
        assert call_kwargs["top_logprobs"] == 5

        # Case 2: override via invoke kwargs
        llm = ChatOllama(model=MODEL_NAME)
        llm.invoke([HumanMessage("Hello")], logprobs=True, top_logprobs=3)
        call_kwargs = mock_client.chat.call_args[1]
        assert call_kwargs["logprobs"] is True
        assert call_kwargs["top_logprobs"] == 3

        # Case 3: auto-enabled logprobs propagates to client
        llm = ChatOllama(model=MODEL_NAME, top_logprobs=3)
        llm.invoke([HumanMessage("Hello")])
        call_kwargs = mock_client.chat.call_args[1]
        assert call_kwargs["logprobs"] is True
        assert call_kwargs["top_logprobs"] == 3

        # Case 4: defaults are None when not set
        llm = ChatOllama(model=MODEL_NAME)
        llm.invoke([HumanMessage("Hello")])
        call_kwargs = mock_client.chat.call_args[1]
        assert call_kwargs["logprobs"] is None
        assert call_kwargs["top_logprobs"] is None
def test_top_logprobs_validation() -> None:
    """Test that top_logprobs must be a positive integer."""
    with patch("langchain_ollama.chat_models.Client"):
        # Zero and negative values are rejected by the field validator.
        for bad_value in (0, -1):
            with pytest.raises(ValueError, match="`top_logprobs` must be a positive"):
                ChatOllama(model=MODEL_NAME, top_logprobs=bad_value)
        # The smallest valid value is accepted without raising.
        model = ChatOllama(model=MODEL_NAME, logprobs=True, top_logprobs=1)
        assert model.top_logprobs == 1
def test_top_logprobs_without_logprobs_auto_enables() -> None:
    """Test that setting top_logprobs without logprobs auto-enables logprobs."""
    with patch("langchain_ollama.chat_models.Client"):
        # top_logprobs alone: logprobs flips to True and exactly one warning fires.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            model = ChatOllama(model=MODEL_NAME, top_logprobs=5)
            assert model.logprobs is True
            assert len(caught) == 1
            assert "Setting `logprobs=True` automatically" in str(caught[0].message)

        # Explicit logprobs=True: no top_logprobs-related warning is emitted.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            ChatOllama(model=MODEL_NAME, logprobs=True, top_logprobs=5)
            related = [x for x in caught if "top_logprobs" in str(x.message)]
            assert not related
def test_top_logprobs_with_logprobs_false_raises() -> None:
    """Setting top_logprobs with logprobs=False is a contradictory config."""
    # The validator treats an explicit False as intentional and refuses to
    # silently override it.
    with patch("langchain_ollama.chat_models.Client"):
        with pytest.raises(ValueError, match=r"logprobs.*explicitly.*False"):
            ChatOllama(model=MODEL_NAME, logprobs=False, top_logprobs=5)
def test_logprobs_accumulated_from_stream_into_response_metadata() -> None:
    """Logprobs from intermediate streaming chunks are accumulated into the
    final response_metadata when using invoke()."""
    # Two content chunks each carry one token's logprobs entry; the final
    # `done` chunk carries none.
    stream_responses = [
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": "The"},
            "done": False,
            "logprobs": [
                {"token": "The", "logprob": -0.5, "bytes": [84, 104, 101]},
            ],
        },
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": " sky"},
            "done": False,
            "logprobs": [
                {"token": " sky", "logprob": -0.1, "bytes": [32, 115, 107, 121]},
            ],
        },
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": ""},
            "done": True,
            "done_reason": "stop",
        },
    ]
    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        # Returning an iterator makes the mocked client look like a stream.
        mock_client.chat.return_value = iter(stream_responses)

        llm = ChatOllama(model=MODEL_NAME, logprobs=True)
        result = llm.invoke([HumanMessage("What color is the sky?")])

        # Both per-chunk entries must appear on the final message, in order.
        logprobs = result.response_metadata["logprobs"]
        assert len(logprobs) == 2
        assert logprobs[0]["token"] == "The"
        assert logprobs[0]["logprob"] == -0.5
        assert logprobs[1]["token"] == " sky"
        assert logprobs[1]["logprob"] == -0.1
def test_logprobs_on_individual_streaming_chunks() -> None:
    """Each streaming chunk should carry its own per-token logprobs in
    response_metadata when logprobs are enabled."""
    stream_responses = [
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": "Hi"},
            "done": False,
            "logprobs": [
                {"token": "Hi", "logprob": -0.3, "bytes": [72, 105]},
            ],
        },
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": "!"},
            "done": False,
            "logprobs": [
                {"token": "!", "logprob": -0.01, "bytes": [33]},
            ],
        },
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": ""},
            "done": True,
            "done_reason": "stop",
        },
    ]
    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.chat.return_value = iter(stream_responses)

        llm = ChatOllama(model=MODEL_NAME, logprobs=True)
        chunks = list(llm.stream([HumanMessage("Hello")]))

        # Each intermediate chunk surfaces exactly its own token's logprobs.
        assert chunks[0].response_metadata["logprobs"] == [
            {"token": "Hi", "logprob": -0.3, "bytes": [72, 105]},
        ]
        assert chunks[1].response_metadata["logprobs"] == [
            {"token": "!", "logprob": -0.01, "bytes": [33]},
        ]
        # The final `done` chunk had no logprobs, so none are attached.
        assert "logprobs" not in chunks[2].response_metadata
async def test_logprobs_on_individual_async_streaming_chunks() -> None:
    """Async streaming chunks should carry per-token logprobs in
    response_metadata when logprobs are enabled."""
    stream_responses = [
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": "Hi"},
            "done": False,
            "logprobs": [
                {"token": "Hi", "logprob": -0.3, "bytes": [72, 105]},
            ],
        },
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": "!"},
            "done": False,
            "logprobs": [
                {"token": "!", "logprob": -0.01, "bytes": [33]},
            ],
        },
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": ""},
            "done": True,
            "done_reason": "stop",
        },
    ]

    # Async generator standing in for the AsyncClient's streaming response.
    async def async_stream_responses() -> Any:
        for resp in stream_responses:
            yield resp

    with patch("langchain_ollama.chat_models.AsyncClient") as mock_client_class:
        mock_client = AsyncMock()
        mock_client_class.return_value = mock_client
        mock_client.chat.return_value = async_stream_responses()

        llm = ChatOllama(model=MODEL_NAME, logprobs=True)
        chunks = [chunk async for chunk in llm.astream([HumanMessage("Hello")])]

        # Mirror of the sync test: per-chunk logprobs, none on the final chunk.
        assert chunks[0].response_metadata["logprobs"] == [
            {"token": "Hi", "logprob": -0.3, "bytes": [72, 105]},
        ]
        assert chunks[1].response_metadata["logprobs"] == [
            {"token": "!", "logprob": -0.01, "bytes": [33]},
        ]
        assert "logprobs" not in chunks[2].response_metadata
def test_logprobs_empty_list_preserved() -> None:
    """An empty logprobs list `[]` should be preserved, not treated as absent."""
    first_chunk = {
        "model": MODEL_NAME,
        "created_at": "2025-01-01T00:00:00.000000000Z",
        "message": {"role": "assistant", "content": "Hi"},
        "done": False,
        "logprobs": [],
    }
    final_chunk = {
        "model": MODEL_NAME,
        "created_at": "2025-01-01T00:00:00.000000000Z",
        "message": {"role": "assistant", "content": ""},
        "done": True,
        "done_reason": "stop",
    }
    with patch("langchain_ollama.chat_models.Client") as client_cls:
        fake_client = MagicMock()
        client_cls.return_value = fake_client
        fake_client.chat.return_value = iter([first_chunk, final_chunk])

        model = ChatOllama(model=MODEL_NAME, logprobs=True)
        streamed = list(model.stream([HumanMessage("Hello")]))
        # `[]` is not None, so it must surface on the chunk metadata as-is.
        assert streamed[0].response_metadata["logprobs"] == []
def test_logprobs_none_when_not_requested() -> None:
    """When logprobs are not requested, response_metadata should not contain
    logprobs (or it should be None)."""
    chunk_payloads = [
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": "Hello!"},
            "done": False,
        },
        {
            "model": MODEL_NAME,
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "message": {"role": "assistant", "content": ""},
            "done": True,
            "done_reason": "stop",
        },
    ]
    with patch("langchain_ollama.chat_models.Client") as client_cls:
        fake_client = MagicMock()
        client_cls.return_value = fake_client
        fake_client.chat.return_value = iter(chunk_payloads)

        model = ChatOllama(model=MODEL_NAME)
        result = model.invoke([HumanMessage("Hello")])
        # Without logprobs enabled the key is absent (or explicitly None).
        assert result.response_metadata.get("logprobs") is None
def test_create_chat_stream_raises_when_client_none() -> None:
"""Test that _create_chat_stream raises RuntimeError when client is None."""
with patch("langchain_ollama.chat_models.Client") as mock_client_class:

View File

@@ -382,7 +382,7 @@ typing = [
[package.metadata]
requires-dist = [
{ name = "langchain-core", editable = "../../core" },
{ name = "ollama", specifier = ">=0.6.0,<1.0.0" },
{ name = "ollama", specifier = ">=0.6.1,<1.0.0" },
]
[package.metadata.requires-dev]
@@ -643,15 +643,15 @@ wheels = [
[[package]]
name = "ollama"
version = "0.6.0"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
{ name = "pydantic" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d6/47/f9ee32467fe92744474a8c72e138113f3b529fc266eea76abfdec9a33f3b/ollama-0.6.0.tar.gz", hash = "sha256:da2b2d846b5944cfbcee1ca1e6ee0585f6c9d45a2fe9467cbcd096a37383da2f", size = 50811, upload-time = "2025-09-24T22:46:02.417Z" }
sdist = { url = "https://files.pythonhosted.org/packages/9d/5a/652dac4b7affc2b37b95386f8ae78f22808af09d720689e3d7a86b6ed98e/ollama-0.6.1.tar.gz", hash = "sha256:478c67546836430034b415ed64fa890fd3d1ff91781a9d548b3325274e69d7c6", size = 51620, upload-time = "2025-11-13T23:02:17.416Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/c1/edc9f41b425ca40b26b7c104c5f6841a4537bb2552bfa6ca66e81405bb95/ollama-0.6.0-py3-none-any.whl", hash = "sha256:534511b3ccea2dff419ae06c3b58d7f217c55be7897c8ce5868dfb6b219cf7a0", size = 14130, upload-time = "2025-09-24T22:46:01.19Z" },
{ url = "https://files.pythonhosted.org/packages/47/4f/4a617ee93d8208d2bcf26b2d8b9402ceaed03e3853c754940e2290fed063/ollama-0.6.1-py3-none-any.whl", hash = "sha256:fc4c984b345735c5486faeee67d8a265214a31cbb828167782dc642ce0a2bf8c", size = 14354, upload-time = "2025-11-13T23:02:16.292Z" },
]
[[package]]