From 19a0761b9917d77c496ce05176fef33ad553e6d8 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Wed, 30 Jul 2025 12:11:43 -0400 Subject: [PATCH] feat(ollama): add basic auth support --- .../ollama/langchain_ollama/_utils.py | 54 ++++ .../ollama/langchain_ollama/chat_models.py | 28 +- .../ollama/langchain_ollama/embeddings.py | 22 +- libs/partners/ollama/langchain_ollama/llms.py | 22 +- .../ollama/tests/unit_tests/test_url_auth.py | 248 ++++++++++++++++++ libs/partners/ollama/uv.lock | 2 +- 6 files changed, 343 insertions(+), 33 deletions(-) create mode 100644 libs/partners/ollama/tests/unit_tests/test_url_auth.py diff --git a/libs/partners/ollama/langchain_ollama/_utils.py b/libs/partners/ollama/langchain_ollama/_utils.py index f3cd6fe9a4d..97027cfb271 100644 --- a/libs/partners/ollama/langchain_ollama/_utils.py +++ b/libs/partners/ollama/langchain_ollama/_utils.py @@ -1,5 +1,11 @@ """Utility functions for validating Ollama models.""" +from __future__ import annotations + +import base64 +from typing import Optional +from urllib.parse import unquote, urlparse + from httpx import ConnectError from ollama import Client, ResponseError @@ -37,3 +43,51 @@ def validate_model(client: Client, model_name: str) -> None: "Please check your Ollama server logs." ) raise ValueError(msg) from e + + +def parse_url_with_auth(url: Optional[str]) -> tuple[Optional[str], Optional[dict]]: + """Parse URL and extract authentication credentials for headers. + + Handles URLs of the form: ``https://user:password@host:port/path`` + + Args: + url: The URL to parse. Can be None. + + Returns: + A tuple of ``(cleaned_url, headers_dict)`` where: + - ``cleaned_url`` is the URL without authentication credentials + - ``headers_dict`` contains Authorization header if credentials were found + """ + if not url: + return None, None + + parsed = urlparse(url) + + # If no authentication info, return as-is + if not parsed.username: + return url, None + + # Handle case where password might be empty string or None + password = parsed.password or "" + + # Extract credentials and create basic auth header (decode percent-encoding) + username = unquote(parsed.username) + password = unquote(password) + credentials = f"{username}:{password}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + headers = {"Authorization": f"Basic {encoded_credentials}"} + + # Reconstruct URL without authentication + cleaned_netloc = parsed.hostname or "" + if parsed.port: + cleaned_netloc += f":{parsed.port}" + + cleaned_url = f"{parsed.scheme}://{cleaned_netloc}" + if parsed.path: + cleaned_url += parsed.path + if parsed.query: + cleaned_url += f"?{parsed.query}" + if parsed.fragment: + cleaned_url += f"#{parsed.fragment}" + + return cleaned_url, headers diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py index ae836ed5d08..20be441c833 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models.py +++ b/libs/partners/ollama/langchain_ollama/chat_models.py @@ -7,19 +7,10 @@ import json import logging from collections.abc import AsyncIterator, Iterator, Mapping, Sequence from operator import itemgetter -from typing import ( - Any, - Callable, - Literal, - Optional, - Union, - cast, -) +from typing import Any, Callable, Literal, Optional, Union, cast from uuid import uuid4 -from langchain_core.callbacks import ( - CallbackManagerForLLMRun, -) +from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun from langchain_core.exceptions import OutputParserException from langchain_core.language_models import LanguageModelInput @@ -57,7 +48,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self, is_typeddict -from ._utils import validate_model +from ._utils import parse_url_with_auth, validate_model log = logging.getLogger(__name__) @@ -607,6 +598,15 @@ class ChatOllama(BaseChatModel): """Set clients to use for ollama.""" client_kwargs = self.client_kwargs or {} + # Parse URL for basic auth credentials + cleaned_url, auth_headers = parse_url_with_auth(self.base_url) + + # Merge authentication headers with existing headers + if auth_headers: + headers = client_kwargs.get("headers", {}) + headers.update(auth_headers) + client_kwargs = {**client_kwargs, "headers": headers} + sync_client_kwargs = client_kwargs if self.sync_client_kwargs: sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs} @@ -615,8 +615,8 @@ class ChatOllama(BaseChatModel): if self.async_client_kwargs: async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs} - self._client = Client(host=self.base_url, **sync_client_kwargs) - self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs) + self._client = Client(host=cleaned_url, **sync_client_kwargs) + self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs) if self.validate_model_on_init: validate_model(self._client, self.model) return self diff --git a/libs/partners/ollama/langchain_ollama/embeddings.py b/libs/partners/ollama/langchain_ollama/embeddings.py index ac5619a3b06..08de9ca46db 100644 --- a/libs/partners/ollama/langchain_ollama/embeddings.py +++ b/libs/partners/ollama/langchain_ollama/embeddings.py @@ -6,15 +6,10 @@ from typing import Any, Optional from langchain_core.embeddings import Embeddings from ollama import AsyncClient, Client -from pydantic import ( - BaseModel, - ConfigDict, - PrivateAttr, - model_validator, -) +from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator from typing_extensions import Self -from ._utils import validate_model +from ._utils import parse_url_with_auth, validate_model class OllamaEmbeddings(BaseModel, Embeddings): @@ -260,6 +255,15 @@ class OllamaEmbeddings(BaseModel, Embeddings): """Set clients to use for ollama.""" client_kwargs = self.client_kwargs or {} + # Parse URL for basic auth credentials + cleaned_url, auth_headers = parse_url_with_auth(self.base_url) + + # Merge authentication headers with existing headers + if auth_headers: + headers = client_kwargs.get("headers", {}) + headers.update(auth_headers) + client_kwargs = {**client_kwargs, "headers": headers} + sync_client_kwargs = client_kwargs if self.sync_client_kwargs: sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs} @@ -268,8 +272,8 @@ class OllamaEmbeddings(BaseModel, Embeddings): if self.async_client_kwargs: async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs} - self._client = Client(host=self.base_url, **sync_client_kwargs) - self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs) + self._client = Client(host=cleaned_url, **sync_client_kwargs) + self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs) if self.validate_model_on_init: validate_model(self._client, self.model) return self diff --git a/libs/partners/ollama/langchain_ollama/llms.py b/libs/partners/ollama/langchain_ollama/llms.py index b433606340d..b95537b0a20 100644 --- a/libs/partners/ollama/langchain_ollama/llms.py +++ b/libs/partners/ollama/langchain_ollama/llms.py @@ -3,12 +3,7 @@ from __future__ import annotations from collections.abc import AsyncIterator, Iterator, Mapping -from typing import ( - Any, - Literal, - Optional, - Union, -) +from typing import Any, Literal, Optional, Union from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -20,7 +15,7 @@ from ollama import AsyncClient, Client, Options from pydantic import PrivateAttr, model_validator from typing_extensions import Self -from ._utils import validate_model +from ._utils import parse_url_with_auth, validate_model class OllamaLLM(BaseLLM): @@ -230,6 +225,15 @@ class OllamaLLM(BaseLLM): """Set clients to use for ollama.""" client_kwargs = self.client_kwargs or {} + # Parse URL for basic auth credentials + cleaned_url, auth_headers = parse_url_with_auth(self.base_url) + + # Merge authentication headers with existing headers + if auth_headers: + headers = client_kwargs.get("headers", {}) + headers.update(auth_headers) + client_kwargs = {**client_kwargs, "headers": headers} + sync_client_kwargs = client_kwargs if self.sync_client_kwargs: sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs} @@ -238,8 +242,8 @@ class OllamaLLM(BaseLLM): if self.async_client_kwargs: async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs} - self._client = Client(host=self.base_url, **sync_client_kwargs) - self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs) + self._client = Client(host=cleaned_url, **sync_client_kwargs) + self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs) if self.validate_model_on_init: validate_model(self._client, self.model) return self diff --git a/libs/partners/ollama/tests/unit_tests/test_url_auth.py b/libs/partners/ollama/tests/unit_tests/test_url_auth.py new file mode 100644 index 00000000000..2d4f50833b8 --- /dev/null +++ b/libs/partners/ollama/tests/unit_tests/test_url_auth.py @@ -0,0 +1,248 @@ +"""Test URL authentication parsing functionality.""" + +import base64 +from unittest.mock import MagicMock, patch + +from langchain_ollama._utils import parse_url_with_auth +from langchain_ollama.chat_models import ChatOllama +from langchain_ollama.embeddings import OllamaEmbeddings +from langchain_ollama.llms import OllamaLLM + + +class TestParseUrlWithAuth: + """Test the parse_url_with_auth utility function.""" + + def test_parse_url_with_auth_none_input(self) -> None: + """Test that None input returns None, None.""" + result = parse_url_with_auth(None) + assert result == (None, None) + + def test_parse_url_with_auth_no_credentials(self) -> None: + """Test URLs without authentication credentials.""" + url = "https://ollama.example.com:11434" + result = parse_url_with_auth(url) + assert result == (url, None) + + def test_parse_url_with_auth_with_credentials(self) -> None: + """Test URLs with authentication credentials.""" + url = "https://user:password@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_with_path_and_query(self) -> None: + """Test URLs with path and query parameters.""" + url = "https://user:pass@ollama.example.com:11434/api/v1?timeout=30" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434/api/v1?timeout=30" + expected_credentials = base64.b64encode(b"user:pass").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_special_characters(self) -> None: + """Test URLs with special characters in credentials.""" + url = "https://user%40domain:p%40ssw0rd@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + # Note: URL parsing handles percent-encoding automatically + expected_credentials = base64.b64encode(b"user@domain:p@ssw0rd").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_only_username(self) -> None: + """Test URLs with only username (no password).""" + url = "https://user@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_empty_password(self) -> None: + """Test URLs with empty password.""" + url = "https://user:@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + +class TestChatOllamaUrlAuth: + """Test URL authentication integration with ChatOllama.""" + + @patch("langchain_ollama.chat_models.Client") + @patch("langchain_ollama.chat_models.AsyncClient") + def test_chat_ollama_url_auth_integration( + self, mock_async_client: MagicMock, mock_client: MagicMock + ) -> None: + """Test that ChatOllama properly handles URL authentication.""" + url_with_auth = "https://user:password@ollama.example.com:11434" + + ChatOllama( + model="llama3", + base_url=url_with_auth, + ) + + # Verify the clients were called with cleaned URL and auth headers + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + mock_client.assert_called_once_with(host=expected_url, headers=expected_headers) + mock_async_client.assert_called_once_with( + host=expected_url, headers=expected_headers + ) + + @patch("langchain_ollama.chat_models.Client") + @patch("langchain_ollama.chat_models.AsyncClient") + def test_chat_ollama_url_auth_with_existing_headers( + self, mock_async_client: MagicMock, mock_client: MagicMock + ) -> None: + """Test that URL auth headers merge with existing headers.""" + url_with_auth = "https://user:password@ollama.example.com:11434" + existing_headers = {"User-Agent": "test-agent", "X-Custom": "value"} + + ChatOllama( + model="llama3", + base_url=url_with_auth, + client_kwargs={"headers": existing_headers}, + ) + + # Verify headers are merged + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = { + **existing_headers, + "Authorization": f"Basic {expected_credentials}", + } + + mock_client.assert_called_once_with(host=expected_url, headers=expected_headers) + mock_async_client.assert_called_once_with( + host=expected_url, headers=expected_headers + ) + + @patch("langchain_ollama.chat_models.Client") + @patch("langchain_ollama.chat_models.AsyncClient") + def test_chat_ollama_no_url_auth( + self, mock_async_client: MagicMock, mock_client: MagicMock + ) -> None: + """Test that ChatOllama works normally without URL authentication.""" + url_without_auth = "https://ollama.example.com:11434" + + ChatOllama( + model="llama3", + base_url=url_without_auth, + ) + + # Verify no auth headers are added + mock_client.assert_called_once_with(host=url_without_auth) + mock_async_client.assert_called_once_with(host=url_without_auth) + + +class TestOllamaLLMUrlAuth: + """Test URL authentication integration with OllamaLLM.""" + + @patch("langchain_ollama.llms.Client") + @patch("langchain_ollama.llms.AsyncClient") + def test_ollama_llm_url_auth_integration( + self, mock_async_client: MagicMock, mock_client: MagicMock + ) -> None: + """Test that OllamaLLM properly handles URL authentication.""" + url_with_auth = "https://user:password@ollama.example.com:11434" + + OllamaLLM( + model="llama3", + base_url=url_with_auth, + ) + + # Verify the clients were called with cleaned URL and auth headers + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + mock_client.assert_called_once_with(host=expected_url, headers=expected_headers) + mock_async_client.assert_called_once_with( + host=expected_url, headers=expected_headers + ) + + +class TestOllamaEmbeddingsUrlAuth: + """Test URL authentication integration with OllamaEmbeddings.""" + + @patch("langchain_ollama.embeddings.Client") + @patch("langchain_ollama.embeddings.AsyncClient") + def test_ollama_embeddings_url_auth_integration( + self, mock_async_client: MagicMock, mock_client: MagicMock + ) -> None: + """Test that OllamaEmbeddings properly handles URL authentication.""" + url_with_auth = "https://user:password@ollama.example.com:11434" + + OllamaEmbeddings( + model="llama3", + base_url=url_with_auth, + ) + + # Verify the clients were called with cleaned URL and auth headers + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + mock_client.assert_called_once_with(host=expected_url, headers=expected_headers) + mock_async_client.assert_called_once_with( + host=expected_url, headers=expected_headers + ) + + +class TestUrlAuthEdgeCases: + """Test edge cases and error conditions for URL authentication.""" + + def test_parse_url_with_auth_malformed_url(self) -> None: + """Test behavior with malformed URLs.""" + malformed_url = "not-a-valid-url" + result = parse_url_with_auth(malformed_url) + # Should return the original URL without modification + assert result == (malformed_url, None) + + def test_parse_url_with_auth_no_port(self) -> None: + """Test URLs without explicit port numbers.""" + url = "https://user:password@ollama.example.com" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_complex_password(self) -> None: + """Test with complex passwords containing special characters.""" + # Test password with colon, which is the delimiter + url = "https://user:pass:word@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + # The parser should handle the first colon as the separator + expected_credentials = base64.b64encode(b"user:pass:word").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers diff --git a/libs/partners/ollama/uv.lock b/libs/partners/ollama/uv.lock index bb34f95c883..83dd2baeedf 100644 --- a/libs/partners/ollama/uv.lock +++ b/libs/partners/ollama/uv.lock @@ -305,7 +305,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "0.3.70" +version = "0.3.72" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" },