feat(ollama): add dimensions to OllamaEmbeddings (#36543)

Fixes #34623

Add a `dimensions` field to `OllamaEmbeddings` to allow users to specify the
output embedding size for models that support variable dimensions. The
field is passed directly to the Ollama client's `embed()` call for both
sync and async methods.

**How I verified it works:**
- Ran unit tests: `python -m pytest tests/unit_tests/ -v`
- Ran integration tests against a live Ollama instance:
`OLLAMA_HOST=http://ollama:11434 python -m pytest
tests/integration_tests/ -v`
- Confirmed that passing `dimensions=768` no longer raises an
  `extra_forbidden` Pydantic validation error and returns embeddings of
  the expected size.

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Dat Nguyen
2026-04-06 21:50:54 -04:00
committed by GitHub
parent 050b779d97
commit e71e6564b1
2 changed files with 82 additions and 3 deletions

View File

@@ -6,7 +6,13 @@ from typing import Any
from langchain_core.embeddings import Embeddings
from ollama import AsyncClient, Client
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
from pydantic import (
BaseModel,
ConfigDict,
PrivateAttr,
field_validator,
model_validator,
)
from typing_extensions import Self
from langchain_ollama._utils import (
@@ -124,6 +130,20 @@ class OllamaEmbeddings(BaseModel, Embeddings):
model: str
"""Model name to use."""
dimensions: int | None = None
"""Number of dimensions for the output embedding vectors.
If not provided, the model's default embedding dimensionality is used.
"""
@field_validator("dimensions")
@classmethod
def _validate_dimensions(cls, v: int | None) -> int | None:
    """Check that `dimensions`, when given, is a positive integer.

    Args:
        v: The candidate value for the `dimensions` field, or `None` when
            the caller did not set it.

    Returns:
        `v` unchanged when it is `None` or at least 1.

    Raises:
        ValueError: If `v` is an integer smaller than 1.
    """
    # Guard clause: an unset value or any value >= 1 passes through as-is.
    if v is None or v >= 1:
        return v
    msg = "`dimensions` must be a positive integer."
    raise ValueError(msg)
validate_model_on_init: bool = False
"""Whether to validate the model exists in ollama locally on initialization.
@@ -303,7 +323,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
)
raise RuntimeError(msg)
return self._client.embed(
self.model, texts, options=self._default_params, keep_alive=self.keep_alive
self.model,
texts,
dimensions=self.dimensions,
options=self._default_params,
keep_alive=self.keep_alive,
)["embeddings"]
def embed_query(self, text: str) -> list[float]:
@@ -322,6 +346,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
await self._async_client.embed(
self.model,
texts,
dimensions=self.dimensions,
options=self._default_params,
keep_alive=self.keep_alive,
)