feat(ollama): add basic auth support

This commit is contained in:
Mason Daugherty 2025-07-30 12:11:43 -04:00
parent a9e52ca605
commit 19a0761b99
No known key found for this signature in database
6 changed files with 343 additions and 33 deletions

View File

@ -1,5 +1,11 @@
"""Utility functions for validating Ollama models.""" """Utility functions for validating Ollama models."""
from __future__ import annotations
import base64
from typing import Optional
from urllib.parse import unquote, urlparse
from httpx import ConnectError from httpx import ConnectError
from ollama import Client, ResponseError from ollama import Client, ResponseError
@ -37,3 +43,51 @@ def validate_model(client: Client, model_name: str) -> None:
"Please check your Ollama server logs." "Please check your Ollama server logs."
) )
raise ValueError(msg) from e raise ValueError(msg) from e
def parse_url_with_auth(
    url: Optional[str],
) -> tuple[Optional[str], Optional[dict[str, str]]]:
    """Split basic-auth credentials out of a URL into an ``Authorization`` header.

    Handles URLs of the form: ``https://user:password@host:port/path``

    Args:
        url: The URL to parse. Can be ``None`` or empty.

    Returns:
        A tuple of ``(cleaned_url, headers_dict)`` where:

        - ``cleaned_url`` is the URL without the ``user:password@`` part. It is
          ``None`` when ``url`` is ``None`` or empty, and ``url`` itself when the
          URL carries no credentials.
        - ``headers_dict`` holds a single ``Authorization: Basic ...`` entry when
          credentials were found, otherwise ``None``.
    """
    if not url:
        return None, None

    parsed = urlparse(url)

    # ``username``/``password`` are ``None`` or empty when the URL has no usable
    # userinfo. Checking both keeps password-only URLs (``https://:token@host``)
    # from being returned with the secret still embedded in them.
    if not parsed.username and not parsed.password:
        return url, None

    # Credentials are percent-encoded inside a URL; decode them before building
    # the header so ``p%40ss`` is sent as ``p@ss``.
    username = unquote(parsed.username or "")
    password = unquote(parsed.password or "")
    token = base64.b64encode(f"{username}:{password}".encode()).decode()
    headers = {"Authorization": f"Basic {token}"}

    # Drop only the userinfo (everything up to the last ``@``) from the netloc.
    # Rebuilding from ``hostname``/``port`` would strip the brackets off IPv6
    # literals, and hand-assembling the URL would lose ``;params``.
    host_and_port = parsed.netloc.rpartition("@")[2]
    cleaned_url = parsed._replace(netloc=host_and_port).geturl()

    return cleaned_url, headers

View File

@ -7,19 +7,10 @@ import json
import logging import logging
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import Any, Callable, Literal, Optional, Union, cast
Any,
Callable,
Literal,
Optional,
Union,
cast,
)
from uuid import uuid4 from uuid import uuid4
from langchain_core.callbacks import ( from langchain_core.callbacks import CallbackManagerForLLMRun
CallbackManagerForLLMRun,
)
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import LanguageModelInput from langchain_core.language_models import LanguageModelInput
@ -57,7 +48,7 @@ from pydantic.json_schema import JsonSchemaValue
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self, is_typeddict from typing_extensions import Self, is_typeddict
from ._utils import validate_model from ._utils import parse_url_with_auth, validate_model
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -607,6 +598,15 @@ class ChatOllama(BaseChatModel):
"""Set clients to use for ollama.""" """Set clients to use for ollama."""
client_kwargs = self.client_kwargs or {} client_kwargs = self.client_kwargs or {}
# Parse URL for basic auth credentials
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
# Merge authentication headers with existing headers
if auth_headers:
headers = client_kwargs.get("headers", {})
headers.update(auth_headers)
client_kwargs = {**client_kwargs, "headers": headers}
sync_client_kwargs = client_kwargs sync_client_kwargs = client_kwargs
if self.sync_client_kwargs: if self.sync_client_kwargs:
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs} sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
@ -615,8 +615,8 @@ class ChatOllama(BaseChatModel):
if self.async_client_kwargs: if self.async_client_kwargs:
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs} async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
self._client = Client(host=self.base_url, **sync_client_kwargs) self._client = Client(host=cleaned_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs) self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
if self.validate_model_on_init: if self.validate_model_on_init:
validate_model(self._client, self.model) validate_model(self._client, self.model)
return self return self

View File

@ -6,15 +6,10 @@ from typing import Any, Optional
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from ollama import AsyncClient, Client from ollama import AsyncClient, Client
from pydantic import ( from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
BaseModel,
ConfigDict,
PrivateAttr,
model_validator,
)
from typing_extensions import Self from typing_extensions import Self
from ._utils import validate_model from ._utils import parse_url_with_auth, validate_model
class OllamaEmbeddings(BaseModel, Embeddings): class OllamaEmbeddings(BaseModel, Embeddings):
@ -260,6 +255,15 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""Set clients to use for ollama.""" """Set clients to use for ollama."""
client_kwargs = self.client_kwargs or {} client_kwargs = self.client_kwargs or {}
# Parse URL for basic auth credentials
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
# Merge authentication headers with existing headers
if auth_headers:
headers = client_kwargs.get("headers", {})
headers.update(auth_headers)
client_kwargs = {**client_kwargs, "headers": headers}
sync_client_kwargs = client_kwargs sync_client_kwargs = client_kwargs
if self.sync_client_kwargs: if self.sync_client_kwargs:
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs} sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
@ -268,8 +272,8 @@ class OllamaEmbeddings(BaseModel, Embeddings):
if self.async_client_kwargs: if self.async_client_kwargs:
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs} async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
self._client = Client(host=self.base_url, **sync_client_kwargs) self._client = Client(host=cleaned_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs) self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
if self.validate_model_on_init: if self.validate_model_on_init:
validate_model(self._client, self.model) validate_model(self._client, self.model)
return self return self

View File

@ -3,12 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncIterator, Iterator, Mapping from collections.abc import AsyncIterator, Iterator, Mapping
from typing import ( from typing import Any, Literal, Optional, Union
Any,
Literal,
Optional,
Union,
)
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -20,7 +15,7 @@ from ollama import AsyncClient, Client, Options
from pydantic import PrivateAttr, model_validator from pydantic import PrivateAttr, model_validator
from typing_extensions import Self from typing_extensions import Self
from ._utils import validate_model from ._utils import parse_url_with_auth, validate_model
class OllamaLLM(BaseLLM): class OllamaLLM(BaseLLM):
@ -230,6 +225,15 @@ class OllamaLLM(BaseLLM):
"""Set clients to use for ollama.""" """Set clients to use for ollama."""
client_kwargs = self.client_kwargs or {} client_kwargs = self.client_kwargs or {}
# Parse URL for basic auth credentials
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
# Merge authentication headers with existing headers
if auth_headers:
headers = client_kwargs.get("headers", {})
headers.update(auth_headers)
client_kwargs = {**client_kwargs, "headers": headers}
sync_client_kwargs = client_kwargs sync_client_kwargs = client_kwargs
if self.sync_client_kwargs: if self.sync_client_kwargs:
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs} sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
@ -238,8 +242,8 @@ class OllamaLLM(BaseLLM):
if self.async_client_kwargs: if self.async_client_kwargs:
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs} async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
self._client = Client(host=self.base_url, **sync_client_kwargs) self._client = Client(host=cleaned_url, **sync_client_kwargs)
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs) self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
if self.validate_model_on_init: if self.validate_model_on_init:
validate_model(self._client, self.model) validate_model(self._client, self.model)
return self return self

View File

@ -0,0 +1,248 @@
"""Test URL authentication parsing functionality."""
import base64
from unittest.mock import MagicMock, patch
from langchain_ollama._utils import parse_url_with_auth
from langchain_ollama.chat_models import ChatOllama
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
class TestParseUrlWithAuth:
    """Unit tests for the ``parse_url_with_auth`` helper."""

    @staticmethod
    def _basic_header(raw_credentials: bytes) -> dict:
        """Return the Basic ``Authorization`` header expected for the credentials."""
        token = base64.b64encode(raw_credentials).decode()
        return {"Authorization": f"Basic {token}"}

    def test_parse_url_with_auth_none_input(self) -> None:
        """A ``None`` URL yields ``(None, None)``."""
        assert parse_url_with_auth(None) == (None, None)

    def test_parse_url_with_auth_no_credentials(self) -> None:
        """A URL without userinfo is returned untouched and without headers."""
        plain_url = "https://ollama.example.com:11434"
        assert parse_url_with_auth(plain_url) == (plain_url, None)

    def test_parse_url_with_auth_with_credentials(self) -> None:
        """User and password are moved out of the URL into the header."""
        cleaned_url, headers = parse_url_with_auth(
            "https://user:password@ollama.example.com:11434"
        )

        assert cleaned_url == "https://ollama.example.com:11434"
        assert headers == self._basic_header(b"user:password")

    def test_parse_url_with_auth_with_path_and_query(self) -> None:
        """Path and query string survive the removal of the credentials."""
        cleaned_url, headers = parse_url_with_auth(
            "https://user:pass@ollama.example.com:11434/api/v1?timeout=30"
        )

        assert cleaned_url == "https://ollama.example.com:11434/api/v1?timeout=30"
        assert headers == self._basic_header(b"user:pass")

    def test_parse_url_with_auth_special_characters(self) -> None:
        """Percent-encoded characters in the credentials are decoded."""
        cleaned_url, headers = parse_url_with_auth(
            "https://user%40domain:p%40ssw0rd@ollama.example.com:11434"
        )

        assert cleaned_url == "https://ollama.example.com:11434"
        # The helper percent-decodes the credentials before encoding the header.
        assert headers == self._basic_header(b"user@domain:p@ssw0rd")

    def test_parse_url_with_auth_only_username(self) -> None:
        """A username without a password is sent with an empty password."""
        cleaned_url, headers = parse_url_with_auth(
            "https://user@ollama.example.com:11434"
        )

        assert cleaned_url == "https://ollama.example.com:11434"
        assert headers == self._basic_header(b"user:")

    def test_parse_url_with_auth_empty_password(self) -> None:
        """An explicitly empty password behaves like a missing one."""
        cleaned_url, headers = parse_url_with_auth(
            "https://user:@ollama.example.com:11434"
        )

        assert cleaned_url == "https://ollama.example.com:11434"
        assert headers == self._basic_header(b"user:")
class TestChatOllamaUrlAuth:
    """Check how ``ChatOllama`` forwards URL credentials to its clients."""

    @patch("langchain_ollama.chat_models.Client")
    @patch("langchain_ollama.chat_models.AsyncClient")
    def test_chat_ollama_url_auth_integration(
        self, mock_async_client: MagicMock, mock_client: MagicMock
    ) -> None:
        """Credentials in ``base_url`` reach both clients as a Basic auth header."""
        ChatOllama(
            model="llama3",
            base_url="https://user:password@ollama.example.com:11434",
        )

        # Both clients must receive the credential-free host plus the header.
        token = base64.b64encode(b"user:password").decode()
        auth_header = {"Authorization": f"Basic {token}"}
        host = "https://ollama.example.com:11434"
        for mocked in (mock_client, mock_async_client):
            mocked.assert_called_once_with(host=host, headers=auth_header)

    @patch("langchain_ollama.chat_models.Client")
    @patch("langchain_ollama.chat_models.AsyncClient")
    def test_chat_ollama_url_auth_with_existing_headers(
        self, mock_async_client: MagicMock, mock_client: MagicMock
    ) -> None:
        """The auth header is added alongside headers given in ``client_kwargs``."""
        existing_headers = {"User-Agent": "test-agent", "X-Custom": "value"}

        ChatOllama(
            model="llama3",
            base_url="https://user:password@ollama.example.com:11434",
            client_kwargs={"headers": existing_headers},
        )

        # Expected headers: the caller's entries plus the Authorization entry.
        token = base64.b64encode(b"user:password").decode()
        merged_headers = dict(existing_headers)
        merged_headers["Authorization"] = f"Basic {token}"
        host = "https://ollama.example.com:11434"
        for mocked in (mock_client, mock_async_client):
            mocked.assert_called_once_with(host=host, headers=merged_headers)

    @patch("langchain_ollama.chat_models.Client")
    @patch("langchain_ollama.chat_models.AsyncClient")
    def test_chat_ollama_no_url_auth(
        self, mock_async_client: MagicMock, mock_client: MagicMock
    ) -> None:
        """Without URL credentials the clients get the host and no headers."""
        plain_url = "https://ollama.example.com:11434"

        ChatOllama(model="llama3", base_url=plain_url)

        # No ``headers`` keyword may be passed at all in this case.
        for mocked in (mock_client, mock_async_client):
            mocked.assert_called_once_with(host=plain_url)
class TestOllamaLLMUrlAuth:
    """Check how ``OllamaLLM`` forwards URL credentials to its clients."""

    @patch("langchain_ollama.llms.Client")
    @patch("langchain_ollama.llms.AsyncClient")
    def test_ollama_llm_url_auth_integration(
        self, mock_async_client: MagicMock, mock_client: MagicMock
    ) -> None:
        """Credentials in ``base_url`` reach both clients as a Basic auth header."""
        OllamaLLM(
            model="llama3",
            base_url="https://user:password@ollama.example.com:11434",
        )

        # Both clients must receive the credential-free host plus the header.
        token = base64.b64encode(b"user:password").decode()
        auth_header = {"Authorization": f"Basic {token}"}
        host = "https://ollama.example.com:11434"
        for mocked in (mock_client, mock_async_client):
            mocked.assert_called_once_with(host=host, headers=auth_header)
class TestOllamaEmbeddingsUrlAuth:
    """Check how ``OllamaEmbeddings`` forwards URL credentials to its clients."""

    @patch("langchain_ollama.embeddings.Client")
    @patch("langchain_ollama.embeddings.AsyncClient")
    def test_ollama_embeddings_url_auth_integration(
        self, mock_async_client: MagicMock, mock_client: MagicMock
    ) -> None:
        """Credentials in ``base_url`` reach both clients as a Basic auth header."""
        OllamaEmbeddings(
            model="llama3",
            base_url="https://user:password@ollama.example.com:11434",
        )

        # Both clients must receive the credential-free host plus the header.
        token = base64.b64encode(b"user:password").decode()
        auth_header = {"Authorization": f"Basic {token}"}
        host = "https://ollama.example.com:11434"
        for mocked in (mock_client, mock_async_client):
            mocked.assert_called_once_with(host=host, headers=auth_header)
class TestUrlAuthEdgeCases:
    """Edge cases and unusual inputs for ``parse_url_with_auth``."""

    def test_parse_url_with_auth_malformed_url(self) -> None:
        """A string that is not a real URL is passed through unchanged."""
        malformed_url = "not-a-valid-url"

        # No credentials can be found, so nothing is rewritten.
        assert parse_url_with_auth(malformed_url) == (malformed_url, None)

    def test_parse_url_with_auth_no_port(self) -> None:
        """A URL without an explicit port is cleaned without inventing one."""
        cleaned_url, headers = parse_url_with_auth(
            "https://user:password@ollama.example.com"
        )

        token = base64.b64encode(b"user:password").decode()
        assert cleaned_url == "https://ollama.example.com"
        assert headers == {"Authorization": f"Basic {token}"}

    def test_parse_url_with_auth_complex_password(self) -> None:
        """A colon inside the password is kept as part of the password."""
        # Only the first colon separates the username from the password.
        cleaned_url, headers = parse_url_with_auth(
            "https://user:pass:word@ollama.example.com:11434"
        )

        token = base64.b64encode(b"user:pass:word").decode()
        assert cleaned_url == "https://ollama.example.com:11434"
        assert headers == {"Authorization": f"Basic {token}"}

View File

@ -305,7 +305,7 @@ wheels = [
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.3.70" version = "0.3.72"
source = { editable = "../../core" } source = { editable = "../../core" }
dependencies = [ dependencies = [
{ name = "jsonpatch" }, { name = "jsonpatch" },