feat(ollama): logprobs support in Ollama (#34218)

Closes #34207 

---

Expose log probabilities from the Ollama Python SDK through
`ChatOllama`. The ollama client already returns a `logprobs` field on
chat responses for supported models, but `ChatOllama` had no way to
request or surface it.

## Changes
- Add `logprobs` and `top_logprobs` fields to `ChatOllama`, forwarded to
the client via `_build_chat_params`. Setting `top_logprobs` while
`logprobs` is left unset auto-enables `logprobs` with a warning; setting
it together with `logprobs=False` raises a `ValueError`.
- Surface per-token logprobs on intermediate streaming chunks (both sync
`_create_chat_stream` and async `_create_async_chat_stream`) via
`response_metadata["logprobs"]`, accumulated into the final response on
`invoke()`
- Bump minimum `ollama` SDK from `>=0.6.0` to `>=0.6.1` — the version
that added logprobs support

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Mohammad Mohtashim
2026-04-07 02:06:51 +05:00
committed by GitHub
parent 642c981d70
commit 0aa482d0cd
4 changed files with 397 additions and 40 deletions

View File

@@ -44,6 +44,7 @@ from __future__ import annotations
import ast
import json
import logging
import warnings
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import Any, Literal, cast
@@ -83,7 +84,7 @@ from langchain_core.utils.function_calling import (
)
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from ollama import AsyncClient, Client, Message
from pydantic import BaseModel, PrivateAttr, model_validator
from pydantic import BaseModel, PrivateAttr, field_validator, model_validator
from pydantic.json_schema import JsonSchemaValue
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self, is_typeddict
@@ -626,6 +627,31 @@ class ChatOllama(BaseChatModel):
same prompt.
"""
logprobs: bool | None = None
"""Whether to return logprobs.
!!! note
When streaming, per-token logprobs are available on each intermediate
chunk (via `response_metadata["logprobs"]`) and are accumulated into the
final aggregated response when using `invoke()`.
"""
top_logprobs: int | None = None
"""Number of most likely tokens to return at each token position, each with
an associated log probability. Must be a positive integer.
If set without `logprobs=True`, `logprobs` will be enabled automatically.
"""
@field_validator("top_logprobs")
@classmethod
def _validate_top_logprobs(cls, v: int | None) -> int | None:
    """Check that `top_logprobs`, when provided, is at least 1.

    Args:
        v: The candidate `top_logprobs` value, or `None` when unset.

    Returns:
        The value unchanged when it is `None` or a positive integer.

    Raises:
        ValueError: If `v` is an integer smaller than 1.
    """
    # An unset value is always acceptable; nothing to range-check.
    if v is None:
        return v
    if v < 1:
        msg = "`top_logprobs` must be a positive integer."
        raise ValueError(msg)
    return v
stop: list[str] | None = None
"""Sets the stop tokens to use."""
@@ -772,6 +798,8 @@ class ChatOllama(BaseChatModel):
"model": kwargs.pop("model", self.model),
"think": kwargs.pop("reasoning", self.reasoning),
"format": kwargs.pop("format", self.format),
"logprobs": kwargs.pop("logprobs", self.logprobs),
"top_logprobs": kwargs.pop("top_logprobs", self.top_logprobs),
"options": options_dict,
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
**kwargs,
@@ -790,6 +818,23 @@ class ChatOllama(BaseChatModel):
@model_validator(mode="after")
def _set_clients(self) -> Self:
"""Set clients to use for ollama."""
if self.top_logprobs is not None and self.logprobs is not True:
if self.logprobs is False:
msg = (
"`top_logprobs` is set but `logprobs` is explicitly `False`. "
"Either set `logprobs=True` to use `top_logprobs`, or remove "
"`top_logprobs`."
)
raise ValueError(msg)
# logprobs is None (unset) — auto-enable as convenience
self.logprobs = True
warnings.warn(
"`top_logprobs` is set but `logprobs` was not explicitly enabled. "
"Setting `logprobs=True` automatically.",
UserWarning,
stacklevel=2,
)
client_kwargs = self.client_kwargs or {}
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
@@ -1096,7 +1141,12 @@ class ChatOllama(BaseChatModel):
generation_info["model_provider"] = "ollama"
_ = generation_info.pop("message", None)
else:
generation_info = None
chunk_logprobs = stream_resp.get("logprobs")
generation_info = (
{"logprobs": chunk_logprobs}
if chunk_logprobs is not None
else None
)
additional_kwargs = {}
if (
@@ -1173,7 +1223,12 @@ class ChatOllama(BaseChatModel):
generation_info["model_provider"] = "ollama"
_ = generation_info.pop("message", None)
else:
generation_info = None
chunk_logprobs = stream_resp.get("logprobs")
generation_info = (
{"logprobs": chunk_logprobs}
if chunk_logprobs is not None
else None
)
additional_kwargs = {}
if (