mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(ollama): logprobs support in Ollama (#34218)
Closes #34207 --- Expose log probabilities from the Ollama Python SDK through `ChatOllama`. The ollama client already returns a `logprobs` field on chat responses for supported models, but `ChatOllama` had no way to request or surface it. ## Changes - Add `logprobs` and `top_logprobs` fields to `ChatOllama`, forwarded to the client via `_build_chat_params`. Setting `top_logprobs` without `logprobs=True` auto-enables it with a warning; setting it with `logprobs=False` raises a `ValueError` - Surface per-token logprobs on intermediate streaming chunks (both sync `_create_chat_stream` and async `_create_async_chat_stream`) via `response_metadata["logprobs"]`, accumulated into the final response on `invoke()` - Bump minimum `ollama` SDK from `>=0.6.0` to `>=0.6.1` — the version that added logprobs support --------- Co-authored-by: Mason Daugherty <mason@langchain.dev> Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
642c981d70
commit
0aa482d0cd
@@ -44,6 +44,7 @@ from __future__ import annotations
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
|
||||
from operator import itemgetter
|
||||
from typing import Any, Literal, cast
|
||||
@@ -83,7 +84,7 @@ from langchain_core.utils.function_calling import (
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||
from ollama import AsyncClient, Client, Message
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel, PrivateAttr, field_validator, model_validator
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import Self, is_typeddict
|
||||
@@ -626,6 +627,31 @@ class ChatOllama(BaseChatModel):
|
||||
same prompt.
|
||||
"""
|
||||
|
||||
logprobs: bool | None = None
|
||||
"""Whether to return logprobs.
|
||||
|
||||
!!! note
|
||||
|
||||
When streaming, per-token logprobs are available on each intermediate
|
||||
chunk (via `response_metadata["logprobs"]`) and are accumulated into the
|
||||
final aggregated response when using `invoke()`.
|
||||
"""
|
||||
|
||||
top_logprobs: int | None = None
|
||||
"""Number of most likely tokens to return at each token position, each with
|
||||
an associated log probability. Must be a positive integer.
|
||||
|
||||
If set without `logprobs=True`, `logprobs` will be enabled automatically.
|
||||
"""
|
||||
|
||||
@field_validator("top_logprobs")
|
||||
@classmethod
|
||||
def _validate_top_logprobs(cls, v: int | None) -> int | None:
|
||||
if v is not None and v < 1:
|
||||
msg = "`top_logprobs` must be a positive integer."
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
||||
stop: list[str] | None = None
|
||||
"""Sets the stop tokens to use."""
|
||||
|
||||
@@ -772,6 +798,8 @@ class ChatOllama(BaseChatModel):
|
||||
"model": kwargs.pop("model", self.model),
|
||||
"think": kwargs.pop("reasoning", self.reasoning),
|
||||
"format": kwargs.pop("format", self.format),
|
||||
"logprobs": kwargs.pop("logprobs", self.logprobs),
|
||||
"top_logprobs": kwargs.pop("top_logprobs", self.top_logprobs),
|
||||
"options": options_dict,
|
||||
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
|
||||
**kwargs,
|
||||
@@ -790,6 +818,23 @@ class ChatOllama(BaseChatModel):
|
||||
@model_validator(mode="after")
|
||||
def _set_clients(self) -> Self:
|
||||
"""Set clients to use for ollama."""
|
||||
if self.top_logprobs is not None and self.logprobs is not True:
|
||||
if self.logprobs is False:
|
||||
msg = (
|
||||
"`top_logprobs` is set but `logprobs` is explicitly `False`. "
|
||||
"Either set `logprobs=True` to use `top_logprobs`, or remove "
|
||||
"`top_logprobs`."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
# logprobs is None (unset) — auto-enable as convenience
|
||||
self.logprobs = True
|
||||
warnings.warn(
|
||||
"`top_logprobs` is set but `logprobs` was not explicitly enabled. "
|
||||
"Setting `logprobs=True` automatically.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
client_kwargs = self.client_kwargs or {}
|
||||
|
||||
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
|
||||
@@ -1096,7 +1141,12 @@ class ChatOllama(BaseChatModel):
|
||||
generation_info["model_provider"] = "ollama"
|
||||
_ = generation_info.pop("message", None)
|
||||
else:
|
||||
generation_info = None
|
||||
chunk_logprobs = stream_resp.get("logprobs")
|
||||
generation_info = (
|
||||
{"logprobs": chunk_logprobs}
|
||||
if chunk_logprobs is not None
|
||||
else None
|
||||
)
|
||||
|
||||
additional_kwargs = {}
|
||||
if (
|
||||
@@ -1173,7 +1223,12 @@ class ChatOllama(BaseChatModel):
|
||||
generation_info["model_provider"] = "ollama"
|
||||
_ = generation_info.pop("message", None)
|
||||
else:
|
||||
generation_info = None
|
||||
chunk_logprobs = stream_resp.get("logprobs")
|
||||
generation_info = (
|
||||
{"logprobs": chunk_logprobs}
|
||||
if chunk_logprobs is not None
|
||||
else None
|
||||
)
|
||||
|
||||
additional_kwargs = {}
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user