feat(ollama): add basic auth support

This commit is contained in:
Mason Daugherty 2025-07-30 12:11:43 -04:00
parent a9e52ca605
commit 19a0761b99
No known key found for this signature in database
6 changed files with 343 additions and 33 deletions

View File

@ -1,5 +1,11 @@
"""Utility functions for validating Ollama models."""
from __future__ import annotations
import base64
from typing import Optional
from urllib.parse import unquote, urlparse
from httpx import ConnectError
from ollama import Client, ResponseError
@ -37,3 +43,51 @@ def validate_model(client: Client, model_name: str) -> None:
"Please check your Ollama server logs."
)
raise ValueError(msg) from e
def parse_url_with_auth(url: Optional[str]) -> tuple[Optional[str], Optional[dict]]:
"""Parse URL and extract authentication credentials for headers.
Handles URLs of the form: ``https://user:password@host:port/path``
Args:
url: The URL to parse. Can be None.
Returns:
A tuple of ``(cleaned_url, headers_dict)`` where:
- ``cleaned_url`` is the URL without authentication credentials
- ``headers_dict`` contains Authorization header if credentials were found
"""
if not url:
return None, None
parsed = urlparse(url)
# If no authentication info, return as-is
if not parsed.username:
return url, None
# Handle case where password might be empty string or None
password = parsed.password or ""
# Extract credentials and create basic auth header (decode percent-encoding)
username = unquote(parsed.username)
password = unquote(password)
credentials = f"{username}:{password}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers = {"Authorization": f"Basic {encoded_credentials}"}
# Reconstruct URL without authentication
cleaned_netloc = parsed.hostname or ""
if parsed.port:
cleaned_netloc += f":{parsed.port}"
cleaned_url = f"{parsed.scheme}://{cleaned_netloc}"
if parsed.path:
cleaned_url += parsed.path
if parsed.query:
cleaned_url += f"?{parsed.query}"
if parsed.fragment:
cleaned_url += f"#{parsed.fragment}"
return cleaned_url, headers

View File

@ -7,19 +7,10 @@ import json
import logging
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
Any,
Callable,
Literal,
Optional,
Union,
cast,
)
from typing import Any, Callable, Literal, Optional, Union, cast
from uuid import uuid4
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import LanguageModelInput
@ -57,7 +48,7 @@ from pydantic.json_schema import JsonSchemaValue
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self, is_typeddict
from ._utils import validate_model
from ._utils import parse_url_with_auth, validate_model
log = logging.getLogger(__name__)
@ -607,6 +598,15 @@ class ChatOllama(BaseChatModel):
"""Set clients to use for ollama."""
client_kwargs = self.client_kwargs or {}
# Parse URL for basic auth credentials
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
# Merge authentication headers with existing headers
if auth_headers:
headers = client_kwargs.get("headers", {})
headers.update(auth_headers)
client_kwargs = {**client_kwargs, "headers": headers}
sync_client_kwargs = client_kwargs
if self.sync_client_kwargs:
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
@ -615,8 +615,8 @@ class ChatOllama(BaseChatModel):
if self.async_client_kwargs:
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
self._client = Client(host=self.base_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
self._client = Client(host=cleaned_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
if self.validate_model_on_init:
validate_model(self._client, self.model)
return self

View File

@ -6,15 +6,10 @@ from typing import Any, Optional
from langchain_core.embeddings import Embeddings
from ollama import AsyncClient, Client
from pydantic import (
BaseModel,
ConfigDict,
PrivateAttr,
model_validator,
)
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
from typing_extensions import Self
from ._utils import validate_model
from ._utils import parse_url_with_auth, validate_model
class OllamaEmbeddings(BaseModel, Embeddings):
@ -260,6 +255,15 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""Set clients to use for ollama."""
client_kwargs = self.client_kwargs or {}
# Parse URL for basic auth credentials
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
# Merge authentication headers with existing headers
if auth_headers:
headers = client_kwargs.get("headers", {})
headers.update(auth_headers)
client_kwargs = {**client_kwargs, "headers": headers}
sync_client_kwargs = client_kwargs
if self.sync_client_kwargs:
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
@ -268,8 +272,8 @@ class OllamaEmbeddings(BaseModel, Embeddings):
if self.async_client_kwargs:
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
self._client = Client(host=self.base_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
self._client = Client(host=cleaned_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
if self.validate_model_on_init:
validate_model(self._client, self.model)
return self

View File

@ -3,12 +3,7 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import (
Any,
Literal,
Optional,
Union,
)
from typing import Any, Literal, Optional, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@ -20,7 +15,7 @@ from ollama import AsyncClient, Client, Options
from pydantic import PrivateAttr, model_validator
from typing_extensions import Self
from ._utils import validate_model
from ._utils import parse_url_with_auth, validate_model
class OllamaLLM(BaseLLM):
@ -230,6 +225,15 @@ class OllamaLLM(BaseLLM):
"""Set clients to use for ollama."""
client_kwargs = self.client_kwargs or {}
# Parse URL for basic auth credentials
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
# Merge authentication headers with existing headers
if auth_headers:
headers = client_kwargs.get("headers", {})
headers.update(auth_headers)
client_kwargs = {**client_kwargs, "headers": headers}
sync_client_kwargs = client_kwargs
if self.sync_client_kwargs:
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
@ -238,8 +242,8 @@ class OllamaLLM(BaseLLM):
if self.async_client_kwargs:
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
self._client = Client(host=self.base_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
self._client = Client(host=cleaned_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
if self.validate_model_on_init:
validate_model(self._client, self.model)
return self

View File

@ -0,0 +1,248 @@
"""Test URL authentication parsing functionality."""
import base64
from unittest.mock import MagicMock, patch
from langchain_ollama._utils import parse_url_with_auth
from langchain_ollama.chat_models import ChatOllama
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
class TestParseUrlWithAuth:
"""Test the parse_url_with_auth utility function."""
def test_parse_url_with_auth_none_input(self) -> None:
"""Test that None input returns None, None."""
result = parse_url_with_auth(None)
assert result == (None, None)
def test_parse_url_with_auth_no_credentials(self) -> None:
"""Test URLs without authentication credentials."""
url = "https://ollama.example.com:11434"
result = parse_url_with_auth(url)
assert result == (url, None)
def test_parse_url_with_auth_with_credentials(self) -> None:
"""Test URLs with authentication credentials."""
url = "https://user:password@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_with_path_and_query(self) -> None:
"""Test URLs with path and query parameters."""
url = "https://user:pass@ollama.example.com:11434/api/v1?timeout=30"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434/api/v1?timeout=30"
expected_credentials = base64.b64encode(b"user:pass").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_special_characters(self) -> None:
"""Test URLs with special characters in credentials."""
url = "https://user%40domain:p%40ssw0rd@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
# Note: URL parsing handles percent-encoding automatically
expected_credentials = base64.b64encode(b"user@domain:p@ssw0rd").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_only_username(self) -> None:
"""Test URLs with only username (no password)."""
url = "https://user@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_empty_password(self) -> None:
"""Test URLs with empty password."""
url = "https://user:@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
class TestChatOllamaUrlAuth:
"""Test URL authentication integration with ChatOllama."""
@patch("langchain_ollama.chat_models.Client")
@patch("langchain_ollama.chat_models.AsyncClient")
def test_chat_ollama_url_auth_integration(
self, mock_async_client: MagicMock, mock_client: MagicMock
) -> None:
"""Test that ChatOllama properly handles URL authentication."""
url_with_auth = "https://user:password@ollama.example.com:11434"
ChatOllama(
model="llama3",
base_url=url_with_auth,
)
# Verify the clients were called with cleaned URL and auth headers
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
mock_async_client.assert_called_once_with(
host=expected_url, headers=expected_headers
)
@patch("langchain_ollama.chat_models.Client")
@patch("langchain_ollama.chat_models.AsyncClient")
def test_chat_ollama_url_auth_with_existing_headers(
self, mock_async_client: MagicMock, mock_client: MagicMock
) -> None:
"""Test that URL auth headers merge with existing headers."""
url_with_auth = "https://user:password@ollama.example.com:11434"
existing_headers = {"User-Agent": "test-agent", "X-Custom": "value"}
ChatOllama(
model="llama3",
base_url=url_with_auth,
client_kwargs={"headers": existing_headers},
)
# Verify headers are merged
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {
**existing_headers,
"Authorization": f"Basic {expected_credentials}",
}
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
mock_async_client.assert_called_once_with(
host=expected_url, headers=expected_headers
)
@patch("langchain_ollama.chat_models.Client")
@patch("langchain_ollama.chat_models.AsyncClient")
def test_chat_ollama_no_url_auth(
self, mock_async_client: MagicMock, mock_client: MagicMock
) -> None:
"""Test that ChatOllama works normally without URL authentication."""
url_without_auth = "https://ollama.example.com:11434"
ChatOllama(
model="llama3",
base_url=url_without_auth,
)
# Verify no auth headers are added
mock_client.assert_called_once_with(host=url_without_auth)
mock_async_client.assert_called_once_with(host=url_without_auth)
class TestOllamaLLMUrlAuth:
"""Test URL authentication integration with OllamaLLM."""
@patch("langchain_ollama.llms.Client")
@patch("langchain_ollama.llms.AsyncClient")
def test_ollama_llm_url_auth_integration(
self, mock_async_client: MagicMock, mock_client: MagicMock
) -> None:
"""Test that OllamaLLM properly handles URL authentication."""
url_with_auth = "https://user:password@ollama.example.com:11434"
OllamaLLM(
model="llama3",
base_url=url_with_auth,
)
# Verify the clients were called with cleaned URL and auth headers
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
mock_async_client.assert_called_once_with(
host=expected_url, headers=expected_headers
)
class TestOllamaEmbeddingsUrlAuth:
"""Test URL authentication integration with OllamaEmbeddings."""
@patch("langchain_ollama.embeddings.Client")
@patch("langchain_ollama.embeddings.AsyncClient")
def test_ollama_embeddings_url_auth_integration(
self, mock_async_client: MagicMock, mock_client: MagicMock
) -> None:
"""Test that OllamaEmbeddings properly handles URL authentication."""
url_with_auth = "https://user:password@ollama.example.com:11434"
OllamaEmbeddings(
model="llama3",
base_url=url_with_auth,
)
# Verify the clients were called with cleaned URL and auth headers
expected_url = "https://ollama.example.com:11434"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
mock_async_client.assert_called_once_with(
host=expected_url, headers=expected_headers
)
class TestUrlAuthEdgeCases:
"""Test edge cases and error conditions for URL authentication."""
def test_parse_url_with_auth_malformed_url(self) -> None:
"""Test behavior with malformed URLs."""
malformed_url = "not-a-valid-url"
result = parse_url_with_auth(malformed_url)
# Should return the original URL without modification
assert result == (malformed_url, None)
def test_parse_url_with_auth_no_port(self) -> None:
"""Test URLs without explicit port numbers."""
url = "https://user:password@ollama.example.com"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com"
expected_credentials = base64.b64encode(b"user:password").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers
def test_parse_url_with_auth_complex_password(self) -> None:
"""Test with complex passwords containing special characters."""
# Test password with colon, which is the delimiter
url = "https://user:pass:word@ollama.example.com:11434"
cleaned_url, headers = parse_url_with_auth(url)
expected_url = "https://ollama.example.com:11434"
# The parser should handle the first colon as the separator
expected_credentials = base64.b64encode(b"user:pass:word").decode()
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
assert cleaned_url == expected_url
assert headers == expected_headers

View File

@ -305,7 +305,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "0.3.70"
version = "0.3.72"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },