feat(ollama): add basic auth support (#32328)

support for URL authentication in the format
`https://user:password@host:port` for all LangChain Ollama clients.

Related to #32327 and #25055
This commit is contained in:
Mason Daugherty
2025-10-01 20:46:37 -04:00
committed by GitHub
parent a336afaecd
commit a89c549cb0
6 changed files with 409 additions and 61 deletions

View File

@@ -1,5 +1,11 @@
"""Utility function to validate Ollama models."""
from __future__ import annotations
import base64
from typing import Optional
from urllib.parse import unquote, urlparse
from httpx import ConnectError
from ollama import Client, ResponseError
@@ -40,3 +46,70 @@ def validate_model(client: Client, model_name: str) -> None:
"Please check your Ollama server logs."
)
raise ValueError(msg) from e
def parse_url_with_auth(
url: Optional[str],
) -> tuple[Optional[str], Optional[dict[str, str]]]:
"""Parse URL and extract `userinfo` credentials for headers.
Handles URLs of the form: `https://user:password@host:port/path`
Args:
url: The URL to parse.
Returns:
A tuple of ``(cleaned_url, headers_dict)`` where:
- ``cleaned_url`` is the URL without authentication credentials if any were
found. Otherwise, returns the original URL.
- ``headers_dict`` contains Authorization header if credentials were found.
"""
if not url:
return None, None
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc or not parsed.hostname:
return None, None
if not parsed.username:
return url, None
# Handle case where password might be empty string or None
password = parsed.password or ""
# Create basic auth header (decode percent-encoding)
username = unquote(parsed.username)
password = unquote(password)
credentials = f"{username}:{password}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers = {"Authorization": f"Basic {encoded_credentials}"}
# Strip credentials from URL
cleaned_netloc = parsed.hostname or ""
if parsed.port:
cleaned_netloc += f":{parsed.port}"
cleaned_url = f"{parsed.scheme}://{cleaned_netloc}"
if parsed.path:
cleaned_url += parsed.path
if parsed.query:
cleaned_url += f"?{parsed.query}"
if parsed.fragment:
cleaned_url += f"#{parsed.fragment}"
return cleaned_url, headers
def merge_auth_headers(
client_kwargs: dict,
auth_headers: Optional[dict[str, str]],
) -> None:
"""Merge authentication headers into client kwargs in-place.
Args:
client_kwargs: The client kwargs dict to update.
auth_headers: Headers to merge (typically from ``parse_url_with_auth``).
"""
if auth_headers:
headers = client_kwargs.get("headers", {})
headers.update(auth_headers)
client_kwargs["headers"] = headers

View File

@@ -7,19 +7,10 @@ import json
import logging
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
Any,
Callable,
Literal,
Optional,
Union,
cast,
)
from typing import Any, Callable, Literal, Optional, Union, cast
from uuid import uuid4
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import LanguageModelInput
@@ -57,7 +48,7 @@ from pydantic.json_schema import JsonSchemaValue
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self, is_typeddict
from ._utils import validate_model
from ._utils import merge_auth_headers, parse_url_with_auth, validate_model
log = logging.getLogger(__name__)
@@ -592,32 +583,50 @@ class ChatOllama(BaseChatModel):
"""How long the model will stay loaded into memory."""
base_url: Optional[str] = None
"""Base url the model is hosted under."""
"""Base url the model is hosted under.
If none, defaults to the Ollama client default.
Supports `userinfo` auth in the format `http://username:password@localhost:11434`.
Useful if your Ollama server is behind a proxy.
!!! warning
`userinfo` is not secure and should only be used for local testing or
in secure environments. Avoid using it in production or over unsecured
networks.
!!! note
If using `userinfo`, ensure that the Ollama server is configured to
accept and validate these credentials.
!!! note
`userinfo` headers are passed to both sync and async clients.
"""
client_kwargs: Optional[dict] = {}
"""Additional kwargs to pass to the httpx clients.
"""Additional kwargs to pass to the httpx clients. Pass headers in here.
These arguments are passed to both synchronous and async clients.
Use ``sync_client_kwargs`` and ``async_client_kwargs`` to pass different arguments
to synchronous and asynchronous clients.
"""
async_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with ``client_kwargs`` before
passing to the httpx AsyncClient.
"""Additional kwargs to merge with ``client_kwargs`` before passing to httpx client.
`Full list of params. <https://www.python-httpx.org/api/#asyncclient>`__
These are clients unique to the async client; for shared args use ``client_kwargs``.
For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#asyncclient>`__.
"""
sync_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with ``client_kwargs`` before
passing to the httpx Client.
"""Additional kwargs to merge with ``client_kwargs`` before passing to httpx client.
`Full list of params. <https://www.python-httpx.org/api/#client>`__
These are clients unique to the sync client; for shared args use ``client_kwargs``.
For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#client>`__.
"""
_client: Client = PrivateAttr()
@@ -682,6 +691,9 @@ class ChatOllama(BaseChatModel):
"""Set clients to use for ollama."""
client_kwargs = self.client_kwargs or {}
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
merge_auth_headers(client_kwargs, auth_headers)
sync_client_kwargs = client_kwargs
if self.sync_client_kwargs:
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
@@ -690,8 +702,8 @@ class ChatOllama(BaseChatModel):
if self.async_client_kwargs:
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
self._client = Client(host=self.base_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
self._client = Client(host=cleaned_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
if self.validate_model_on_init:
validate_model(self._client, self.model)
return self

View File

@@ -6,15 +6,10 @@ from typing import Any, Optional
from langchain_core.embeddings import Embeddings
from ollama import AsyncClient, Client
from pydantic import (
BaseModel,
ConfigDict,
PrivateAttr,
model_validator,
)
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
from typing_extensions import Self
from ._utils import validate_model
from ._utils import merge_auth_headers, parse_url_with_auth, validate_model
class OllamaEmbeddings(BaseModel, Embeddings):
@@ -134,32 +129,50 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""
base_url: Optional[str] = None
"""Base url the model is hosted under."""
"""Base url the model is hosted under.
If none, defaults to the Ollama client default.
Supports `userinfo` auth in the format `http://username:password@localhost:11434`.
Useful if your Ollama server is behind a proxy.
!!! warning
`userinfo` is not secure and should only be used for local testing or
in secure environments. Avoid using it in production or over unsecured
networks.
!!! note
If using `userinfo`, ensure that the Ollama server is configured to
accept and validate these credentials.
!!! note
`userinfo` headers are passed to both sync and async clients.
"""
client_kwargs: Optional[dict] = {}
"""Additional kwargs to pass to the httpx clients.
"""Additional kwargs to pass to the httpx clients. Pass headers in here.
These arguments are passed to both synchronous and async clients.
Use ``sync_client_kwargs`` and ``async_client_kwargs`` to pass different arguments
to synchronous and asynchronous clients.
"""
async_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with ``client_kwargs`` before passing to the httpx
AsyncClient.
"""Additional kwargs to merge with ``client_kwargs`` before passing to httpx client.
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#asyncclient>`__.
These are clients unique to the async client; for shared args use ``client_kwargs``.
For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#asyncclient>`__.
"""
sync_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with ``client_kwargs`` before
passing to the HTTPX Client.
"""Additional kwargs to merge with ``client_kwargs`` before passing to httpx client.
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
These are clients unique to the sync client; for shared args use ``client_kwargs``.
For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#client>`__.
"""
_client: Optional[Client] = PrivateAttr(default=None)
@@ -261,6 +274,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""Set clients to use for Ollama."""
client_kwargs = self.client_kwargs or {}
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
merge_auth_headers(client_kwargs, auth_headers)
sync_client_kwargs = client_kwargs
if self.sync_client_kwargs:
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
@@ -269,8 +285,8 @@ class OllamaEmbeddings(BaseModel, Embeddings):
if self.async_client_kwargs:
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
self._client = Client(host=self.base_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
self._client = Client(host=cleaned_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
if self.validate_model_on_init:
validate_model(self._client, self.model)
return self

View File

@@ -3,12 +3,7 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import (
Any,
Literal,
Optional,
Union,
)
from typing import Any, Literal, Optional, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@@ -20,7 +15,7 @@ from ollama import AsyncClient, Client, Options
from pydantic import PrivateAttr, model_validator
from typing_extensions import Self
from ._utils import validate_model
from ._utils import merge_auth_headers, parse_url_with_auth, validate_model
class OllamaLLM(BaseLLM):
@@ -213,32 +208,50 @@ class OllamaLLM(BaseLLM):
"""How long the model will stay loaded into memory."""
base_url: Optional[str] = None
"""Base url the model is hosted under."""
"""Base url the model is hosted under.
If none, defaults to the Ollama client default.
Supports `userinfo` auth in the format `http://username:password@localhost:11434`.
Useful if your Ollama server is behind a proxy.
!!! warning
`userinfo` is not secure and should only be used for local testing or
in secure environments. Avoid using it in production or over unsecured
networks.
!!! note
If using `userinfo`, ensure that the Ollama server is configured to
accept and validate these credentials.
!!! note
`userinfo` headers are passed to both sync and async clients.
"""
client_kwargs: Optional[dict] = {}
"""Additional kwargs to pass to the httpx clients.
"""Additional kwargs to pass to the httpx clients. Pass headers in here.
These arguments are passed to both synchronous and async clients.
Use ``sync_client_kwargs`` and ``async_client_kwargs`` to pass different arguments
to synchronous and asynchronous clients.
"""
async_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with ``client_kwargs`` before passing to the HTTPX
AsyncClient.
"""Additional kwargs to merge with ``client_kwargs`` before passing to httpx client.
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#asyncclient>`__.
These are clients unique to the async client; for shared args use ``client_kwargs``.
For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#asyncclient>`__.
"""
sync_client_kwargs: Optional[dict] = {}
"""Additional kwargs to merge with ``client_kwargs`` before
passing to the HTTPX Client.
"""Additional kwargs to merge with ``client_kwargs`` before passing to httpx client.
For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__.
These are clients unique to the sync client; for shared args use ``client_kwargs``.
For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#client>`__.
"""
_client: Optional[Client] = PrivateAttr(default=None)
@@ -310,6 +323,9 @@ class OllamaLLM(BaseLLM):
"""Set clients to use for ollama."""
client_kwargs = self.client_kwargs or {}
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
merge_auth_headers(client_kwargs, auth_headers)
sync_client_kwargs = client_kwargs
if self.sync_client_kwargs:
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
@@ -318,8 +334,8 @@ class OllamaLLM(BaseLLM):
if self.async_client_kwargs:
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
self._client = Client(host=self.base_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
self._client = Client(host=cleaned_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
if self.validate_model_on_init:
validate_model(self._client, self.model)
return self

View File

@@ -0,0 +1,231 @@
"""Test URL authentication parsing functionality."""
import base64
from unittest.mock import MagicMock, patch
from langchain_ollama._utils import parse_url_with_auth
from langchain_ollama.chat_models import ChatOllama
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
MODEL_NAME = "llama3.1"
class TestParseUrlWithAuth:
"""Test the parse_url_with_auth utility function."""
def test_parse_url_with_auth_none_input(self) -> None:
"""Test that None input returns None, None."""
result = parse_url_with_auth(None)
assert result == (None, None)
def test_parse_url_with_auth_no_credentials(self) -> None:
"""Test URLs without authentication credentials."""
url = "https://ollama.example.com:11434/path?query=param"
result = parse_url_with_auth(url)
assert result == (url, None)
def test_parse_url_with_auth_with_credentials(self) -> None:
"""Test URLs with authentication credentials."""
url = "https://user:password@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_with_path_and_query(self) -> None:
"""Test URLs with auth, path, and query parameters."""
url = "https://user:pass@ollama.example.com:11434/api/v1?timeout=30"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434/api/v1?timeout=30"
expected_credentials = base64.b64encode(b"user:pass").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_special_characters(self) -> None:
"""Test URLs with special characters in credentials."""
url = "https://user%40domain:p%40ssw0rd@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
# Note: URL parsing handles percent-encoding automatically
expected_credentials = base64.b64encode(b"user@domain:p@ssw0rd").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_only_username(self) -> None:
"""Test URLs with only username (no password)."""
url = "https://user@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_empty_password(self) -> None:
"""Test URLs with empty password."""
url = "https://user:@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
class TestChatOllamaUrlAuth:
"""Test URL authentication integration with ChatOllama."""
@patch("langchain_ollama.chat_models.Client")
@patch("langchain_ollama.chat_models.AsyncClient")
def test_chat_ollama_url_auth_integration(
self, mock_async_client: MagicMock, mock_client: MagicMock
) -> None:
"""Test that ChatOllama properly handles URL authentication."""
url_with_auth = "https://user:password@ollama.example.com:11434"
ChatOllama(
model=MODEL_NAME,
base_url=url_with_auth,
)
# Verify the clients were called with cleaned URL and auth headers
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
mock_async_client.assert_called_once_with(
host=expected_url, headers=expected_headers
)
@patch("langchain_ollama.chat_models.Client")
@patch("langchain_ollama.chat_models.AsyncClient")
def test_chat_ollama_url_auth_with_existing_headers(
self, mock_async_client: MagicMock, mock_client: MagicMock
) -> None:
"""Test that URL auth headers merge with existing headers."""
url_with_auth = "https://user:password@ollama.example.com:11434"
existing_headers = {"User-Agent": "test-agent", "X-Custom": "value"}
ChatOllama(
model=MODEL_NAME,
base_url=url_with_auth,
client_kwargs={"headers": existing_headers},
)
# Verify headers are merged
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {
**existing_headers,
"Authorization": f"Basic {expected_credentials}",
}
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
mock_async_client.assert_called_once_with(
host=expected_url, headers=expected_headers
)
class TestOllamaLLMUrlAuth:
"""Test URL authentication integration with OllamaLLM."""
@patch("langchain_ollama.llms.Client")
@patch("langchain_ollama.llms.AsyncClient")
def test_ollama_llm_url_auth_integration(
self, mock_async_client: MagicMock, mock_client: MagicMock
) -> None:
"""Test that OllamaLLM properly handles URL authentication."""
url_with_auth = "https://user:password@ollama.example.com:11434"
OllamaLLM(
model=MODEL_NAME,
base_url=url_with_auth,
)
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
mock_async_client.assert_called_once_with(
host=expected_url, headers=expected_headers
)
class TestOllamaEmbeddingsUrlAuth:
"""Test URL authentication integration with OllamaEmbeddings."""
@patch("langchain_ollama.embeddings.Client")
@patch("langchain_ollama.embeddings.AsyncClient")
def test_ollama_embeddings_url_auth_integration(
self, mock_async_client: MagicMock, mock_client: MagicMock
) -> None:
"""Test that OllamaEmbeddings properly handles URL authentication."""
url_with_auth = "https://user:password@ollama.example.com:11434"
OllamaEmbeddings(
model=MODEL_NAME,
base_url=url_with_auth,
)
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
mock_async_client.assert_called_once_with(
host=expected_url, headers=expected_headers
)
class TestUrlAuthEdgeCases:
"""Test edge cases and error conditions for URL authentication."""
def test_parse_url_with_auth_malformed_url(self) -> None:
"""Test behavior with malformed URLs."""
malformed_url = "not-a-valid-url"
result = parse_url_with_auth(malformed_url)
# Shouldn't return a URL as it wouldn't parse correctly or reach a server
assert result == (None, None)
def test_parse_url_with_auth_no_port(self) -> None:
"""Test URLs without explicit port numbers."""
url = "https://user:password@ollama.example.com"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_complex_password(self) -> None:
"""Test with complex passwords containing special characters."""
# Test password with colon, which is the delimiter
url = "https://user:pass:word@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
# The parser should handle the first colon as the separator
expected_credentials = base64.b64encode(b"user:pass:word").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers

View File

@@ -323,7 +323,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "0.3.76"
version = "0.3.77"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },
@@ -435,7 +435,7 @@ typing = [
[[package]]
name = "langchain-tests"
version = "0.3.21"
version = "0.3.22"
source = { editable = "../../standard-tests" }
dependencies = [
{ name = "httpx" },