From a89c549cb0860fcf443476b35d0069260c95da42 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Wed, 1 Oct 2025 20:46:37 -0400 Subject: [PATCH] feat(ollama): add basic auth support (#32328) Add support for URL authentication in the format `https://user:password@host:port` for all LangChain Ollama clients. Related to #32327 and #25055 --- .../ollama/langchain_ollama/_utils.py | 73 ++++++ .../ollama/langchain_ollama/chat_models.py | 58 +++-- .../ollama/langchain_ollama/embeddings.py | 52 ++-- libs/partners/ollama/langchain_ollama/llms.py | 52 ++-- .../ollama/tests/unit_tests/test_auth.py | 231 ++++++++++++++++++ libs/partners/ollama/uv.lock | 4 +- 6 files changed, 409 insertions(+), 61 deletions(-) create mode 100644 libs/partners/ollama/tests/unit_tests/test_auth.py diff --git a/libs/partners/ollama/langchain_ollama/_utils.py b/libs/partners/ollama/langchain_ollama/_utils.py index 8d08ed87a41..45e85e9932d 100644 --- a/libs/partners/ollama/langchain_ollama/_utils.py +++ b/libs/partners/ollama/langchain_ollama/_utils.py @@ -1,5 +1,11 @@ """Utility function to validate Ollama models.""" +from __future__ import annotations + +import base64 +from typing import Optional +from urllib.parse import unquote, urlparse + from httpx import ConnectError from ollama import Client, ResponseError @@ -40,3 +46,70 @@ def validate_model(client: Client, model_name: str) -> None: "Please check your Ollama server logs." ) raise ValueError(msg) from e + + +def parse_url_with_auth( + url: Optional[str], +) -> tuple[Optional[str], Optional[dict[str, str]]]: + """Parse URL and extract `userinfo` credentials for headers. + + Handles URLs of the form: `https://user:password@host:port/path` + + Args: + url: The URL to parse. + + Returns: + A tuple of ``(cleaned_url, headers_dict)`` where: + - ``cleaned_url`` is the URL with credentials stripped if any were found, + the unchanged URL if it has none, or ``None`` if it is empty or lacks a scheme or host. + - ``headers_dict`` contains Authorization header if credentials were found. 
+ """ + if not url: + return None, None + + parsed = urlparse(url) + if not parsed.scheme or not parsed.netloc or not parsed.hostname: + return None, None + if not parsed.username: + return url, None + + # Handle case where password might be empty string or None + password = parsed.password or "" + + # Create basic auth header (decode percent-encoding) + username = unquote(parsed.username) + password = unquote(password) + credentials = f"{username}:{password}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + headers = {"Authorization": f"Basic {encoded_credentials}"} + + # Strip credentials from URL + cleaned_netloc = parsed.hostname or "" + if parsed.port: + cleaned_netloc += f":{parsed.port}" + + cleaned_url = f"{parsed.scheme}://{cleaned_netloc}" + if parsed.path: + cleaned_url += parsed.path + if parsed.query: + cleaned_url += f"?{parsed.query}" + if parsed.fragment: + cleaned_url += f"#{parsed.fragment}" + + return cleaned_url, headers + + +def merge_auth_headers( + client_kwargs: dict, + auth_headers: Optional[dict[str, str]], +) -> None: + """Merge authentication headers into client kwargs in-place. + + Args: + client_kwargs: The client kwargs dict to update. + auth_headers: Headers to merge (typically from ``parse_url_with_auth``). 
+ """ + if auth_headers: + headers = client_kwargs.get("headers", {}) + headers.update(auth_headers) + client_kwargs["headers"] = headers diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py index 6d7dc66373a..aebc70b7f3b 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models.py +++ b/libs/partners/ollama/langchain_ollama/chat_models.py @@ -7,19 +7,10 @@ import json import logging from collections.abc import AsyncIterator, Iterator, Mapping, Sequence from operator import itemgetter -from typing import ( - Any, - Callable, - Literal, - Optional, - Union, - cast, -) +from typing import Any, Callable, Literal, Optional, Union, cast from uuid import uuid4 -from langchain_core.callbacks import ( - CallbackManagerForLLMRun, -) +from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun from langchain_core.exceptions import OutputParserException from langchain_core.language_models import LanguageModelInput @@ -57,7 +48,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self, is_typeddict -from ._utils import validate_model +from ._utils import merge_auth_headers, parse_url_with_auth, validate_model log = logging.getLogger(__name__) @@ -592,32 +583,50 @@ class ChatOllama(BaseChatModel): """How long the model will stay loaded into memory.""" base_url: Optional[str] = None - """Base url the model is hosted under.""" + """Base url the model is hosted under. + + If none, defaults to the Ollama client default. + + Supports `userinfo` auth in the format `http://username:password@localhost:11434`. + Useful if your Ollama server is behind a proxy. + + !!! warning + `userinfo` is not secure and should only be used for local testing or + in secure environments. Avoid using it in production or over unsecured + networks. + + !!! 
note + If using `userinfo`, ensure that the Ollama server is configured to + accept and validate these credentials. + + !!! note + `userinfo` headers are passed to both sync and async clients. + + """ client_kwargs: Optional[dict] = {} - """Additional kwargs to pass to the httpx clients. + """Additional kwargs to pass to the httpx clients. Pass headers in here. These arguments are passed to both synchronous and async clients. Use ``sync_client_kwargs`` and ``async_client_kwargs`` to pass different arguments to synchronous and asynchronous clients. - """ async_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with ``client_kwargs`` before - passing to the httpx AsyncClient. + """Additional kwargs to merge with ``client_kwargs`` before passing to httpx client. - `Full list of params. <https://www.python-httpx.org/api/#asyncclient>`__ + These are kwargs unique to the async client; for shared args use ``client_kwargs``. + For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#asyncclient>`__. """ sync_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with ``client_kwargs`` before - passing to the httpx Client. + """Additional kwargs to merge with ``client_kwargs`` before passing to httpx client. - `Full list of params. <https://www.python-httpx.org/api/#client>`__ + These are kwargs unique to the sync client; for shared args use ``client_kwargs``. + For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#client>`__. 
""" _client: Client = PrivateAttr() @@ -682,6 +691,9 @@ class ChatOllama(BaseChatModel): """Set clients to use for ollama.""" client_kwargs = self.client_kwargs or {} + cleaned_url, auth_headers = parse_url_with_auth(self.base_url) + merge_auth_headers(client_kwargs, auth_headers) + sync_client_kwargs = client_kwargs if self.sync_client_kwargs: sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs} @@ -690,8 +702,8 @@ class ChatOllama(BaseChatModel): if self.async_client_kwargs: async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs} - self._client = Client(host=self.base_url, **sync_client_kwargs) - self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs) + self._client = Client(host=cleaned_url, **sync_client_kwargs) + self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs) if self.validate_model_on_init: validate_model(self._client, self.model) return self diff --git a/libs/partners/ollama/langchain_ollama/embeddings.py b/libs/partners/ollama/langchain_ollama/embeddings.py index 97f49818f00..8cdf3a6e169 100644 --- a/libs/partners/ollama/langchain_ollama/embeddings.py +++ b/libs/partners/ollama/langchain_ollama/embeddings.py @@ -6,15 +6,10 @@ from typing import Any, Optional from langchain_core.embeddings import Embeddings from ollama import AsyncClient, Client -from pydantic import ( - BaseModel, - ConfigDict, - PrivateAttr, - model_validator, -) +from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator from typing_extensions import Self -from ._utils import validate_model +from ._utils import merge_auth_headers, parse_url_with_auth, validate_model class OllamaEmbeddings(BaseModel, Embeddings): @@ -134,32 +129,50 @@ class OllamaEmbeddings(BaseModel, Embeddings): """ base_url: Optional[str] = None - """Base url the model is hosted under.""" + """Base url the model is hosted under. + + If none, defaults to the Ollama client default. 
+ + Supports `userinfo` auth in the format `http://username:password@localhost:11434`. + Useful if your Ollama server is behind a proxy. + + !!! warning + `userinfo` is not secure and should only be used for local testing or + in secure environments. Avoid using it in production or over unsecured + networks. + + !!! note + If using `userinfo`, ensure that the Ollama server is configured to + accept and validate these credentials. + + !!! note + `userinfo` headers are passed to both sync and async clients. + + """ client_kwargs: Optional[dict] = {} - """Additional kwargs to pass to the httpx clients. + """Additional kwargs to pass to the httpx clients. Pass headers in here. These arguments are passed to both synchronous and async clients. Use ``sync_client_kwargs`` and ``async_client_kwargs`` to pass different arguments to synchronous and asynchronous clients. - """ async_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with ``client_kwargs`` before passing to the httpx - AsyncClient. + """Additional kwargs to merge with ``client_kwargs`` before passing to httpx client. - For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#asyncclient>`__. + These are kwargs unique to the async client; for shared args use ``client_kwargs``. + For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#asyncclient>`__. """ sync_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with ``client_kwargs`` before - passing to the HTTPX Client. + """Additional kwargs to merge with ``client_kwargs`` before passing to httpx client. - For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__. + These are kwargs unique to the sync client; for shared args use ``client_kwargs``. + For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#client>`__. 
""" _client: Optional[Client] = PrivateAttr(default=None) @@ -261,6 +274,9 @@ class OllamaEmbeddings(BaseModel, Embeddings): """Set clients to use for Ollama.""" client_kwargs = self.client_kwargs or {} + cleaned_url, auth_headers = parse_url_with_auth(self.base_url) + merge_auth_headers(client_kwargs, auth_headers) + sync_client_kwargs = client_kwargs if self.sync_client_kwargs: sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs} @@ -269,8 +285,8 @@ class OllamaEmbeddings(BaseModel, Embeddings): if self.async_client_kwargs: async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs} - self._client = Client(host=self.base_url, **sync_client_kwargs) - self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs) + self._client = Client(host=cleaned_url, **sync_client_kwargs) + self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs) if self.validate_model_on_init: validate_model(self._client, self.model) return self diff --git a/libs/partners/ollama/langchain_ollama/llms.py b/libs/partners/ollama/langchain_ollama/llms.py index 7787d86da4d..06f91502bab 100644 --- a/libs/partners/ollama/langchain_ollama/llms.py +++ b/libs/partners/ollama/langchain_ollama/llms.py @@ -3,12 +3,7 @@ from __future__ import annotations from collections.abc import AsyncIterator, Iterator, Mapping -from typing import ( - Any, - Literal, - Optional, - Union, -) +from typing import Any, Literal, Optional, Union from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -20,7 +15,7 @@ from ollama import AsyncClient, Client, Options from pydantic import PrivateAttr, model_validator from typing_extensions import Self -from ._utils import validate_model +from ._utils import merge_auth_headers, parse_url_with_auth, validate_model class OllamaLLM(BaseLLM): @@ -213,32 +208,50 @@ class OllamaLLM(BaseLLM): """How long the model will stay loaded into memory.""" base_url: Optional[str] = None - """Base url the model is hosted 
under.""" + """Base url the model is hosted under. + + If none, defaults to the Ollama client default. + + Supports `userinfo` auth in the format `http://username:password@localhost:11434`. + Useful if your Ollama server is behind a proxy. + + !!! warning + `userinfo` is not secure and should only be used for local testing or + in secure environments. Avoid using it in production or over unsecured + networks. + + !!! note + If using `userinfo`, ensure that the Ollama server is configured to + accept and validate these credentials. + + !!! note + `userinfo` headers are passed to both sync and async clients. + + """ client_kwargs: Optional[dict] = {} - """Additional kwargs to pass to the httpx clients. + """Additional kwargs to pass to the httpx clients. Pass headers in here. These arguments are passed to both synchronous and async clients. Use ``sync_client_kwargs`` and ``async_client_kwargs`` to pass different arguments to synchronous and asynchronous clients. - """ async_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with ``client_kwargs`` before passing to the HTTPX - AsyncClient. + """Additional kwargs to merge with ``client_kwargs`` before passing to httpx client. - For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#asyncclient>`__. + These are kwargs unique to the async client; for shared args use ``client_kwargs``. + For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#asyncclient>`__. """ sync_client_kwargs: Optional[dict] = {} - """Additional kwargs to merge with ``client_kwargs`` before - passing to the HTTPX Client. + """Additional kwargs to merge with ``client_kwargs`` before passing to httpx client. - For a full list of the params, see the `HTTPX documentation <https://www.python-httpx.org/api/#client>`__. + These are kwargs unique to the sync client; for shared args use ``client_kwargs``. + For a full list of the params, see the `httpx documentation <https://www.python-httpx.org/api/#client>`__. 
""" _client: Optional[Client] = PrivateAttr(default=None) @@ -310,6 +323,9 @@ class OllamaLLM(BaseLLM): """Set clients to use for ollama.""" client_kwargs = self.client_kwargs or {} + cleaned_url, auth_headers = parse_url_with_auth(self.base_url) + merge_auth_headers(client_kwargs, auth_headers) + sync_client_kwargs = client_kwargs if self.sync_client_kwargs: sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs} @@ -318,8 +334,8 @@ class OllamaLLM(BaseLLM): if self.async_client_kwargs: async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs} - self._client = Client(host=self.base_url, **sync_client_kwargs) - self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs) + self._client = Client(host=cleaned_url, **sync_client_kwargs) + self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs) if self.validate_model_on_init: validate_model(self._client, self.model) return self diff --git a/libs/partners/ollama/tests/unit_tests/test_auth.py b/libs/partners/ollama/tests/unit_tests/test_auth.py new file mode 100644 index 00000000000..1db36d46d88 --- /dev/null +++ b/libs/partners/ollama/tests/unit_tests/test_auth.py @@ -0,0 +1,231 @@ +"""Test URL authentication parsing functionality.""" + +import base64 +from unittest.mock import MagicMock, patch + +from langchain_ollama._utils import parse_url_with_auth +from langchain_ollama.chat_models import ChatOllama +from langchain_ollama.embeddings import OllamaEmbeddings +from langchain_ollama.llms import OllamaLLM + +MODEL_NAME = "llama3.1" + + +class TestParseUrlWithAuth: + """Test the parse_url_with_auth utility function.""" + + def test_parse_url_with_auth_none_input(self) -> None: + """Test that None input returns None, None.""" + result = parse_url_with_auth(None) + assert result == (None, None) + + def test_parse_url_with_auth_no_credentials(self) -> None: + """Test URLs without authentication credentials.""" + url = 
"https://ollama.example.com:11434/path?query=param" + result = parse_url_with_auth(url) + assert result == (url, None) + + def test_parse_url_with_auth_with_credentials(self) -> None: + """Test URLs with authentication credentials.""" + url = "https://user:password@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_with_path_and_query(self) -> None: + """Test URLs with auth, path, and query parameters.""" + url = "https://user:pass@ollama.example.com:11434/api/v1?timeout=30" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434/api/v1?timeout=30" + expected_credentials = base64.b64encode(b"user:pass").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_special_characters(self) -> None: + """Test URLs with special characters in credentials.""" + url = "https://user%40domain:p%40ssw0rd@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + # Note: parse_url_with_auth percent-decodes credentials before encoding them + expected_credentials = base64.b64encode(b"user@domain:p@ssw0rd").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_only_username(self) -> None: + """Test URLs with only username (no password).""" + url = "https://user@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + 
expected_credentials = base64.b64encode(b"user:").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_empty_password(self) -> None: + """Test URLs with empty password.""" + url = "https://user:@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + +class TestChatOllamaUrlAuth: + """Test URL authentication integration with ChatOllama.""" + + @patch("langchain_ollama.chat_models.Client") + @patch("langchain_ollama.chat_models.AsyncClient") + def test_chat_ollama_url_auth_integration( + self, mock_async_client: MagicMock, mock_client: MagicMock + ) -> None: + """Test that ChatOllama properly handles URL authentication.""" + url_with_auth = "https://user:password@ollama.example.com:11434" + + ChatOllama( + model=MODEL_NAME, + base_url=url_with_auth, + ) + + # Verify the clients were called with cleaned URL and auth headers + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + mock_client.assert_called_once_with(host=expected_url, headers=expected_headers) + mock_async_client.assert_called_once_with( + host=expected_url, headers=expected_headers + ) + + @patch("langchain_ollama.chat_models.Client") + @patch("langchain_ollama.chat_models.AsyncClient") + def test_chat_ollama_url_auth_with_existing_headers( + self, mock_async_client: MagicMock, mock_client: MagicMock + ) -> None: + """Test that URL auth headers merge with existing headers.""" + url_with_auth = "https://user:password@ollama.example.com:11434" + 
existing_headers = {"User-Agent": "test-agent", "X-Custom": "value"} + + ChatOllama( + model=MODEL_NAME, + base_url=url_with_auth, + client_kwargs={"headers": existing_headers}, + ) + + # Verify headers are merged + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = { + **existing_headers, + "Authorization": f"Basic {expected_credentials}", + } + + mock_client.assert_called_once_with(host=expected_url, headers=expected_headers) + mock_async_client.assert_called_once_with( + host=expected_url, headers=expected_headers + ) + + +class TestOllamaLLMUrlAuth: + """Test URL authentication integration with OllamaLLM.""" + + @patch("langchain_ollama.llms.Client") + @patch("langchain_ollama.llms.AsyncClient") + def test_ollama_llm_url_auth_integration( + self, mock_async_client: MagicMock, mock_client: MagicMock + ) -> None: + """Test that OllamaLLM properly handles URL authentication.""" + url_with_auth = "https://user:password@ollama.example.com:11434" + + OllamaLLM( + model=MODEL_NAME, + base_url=url_with_auth, + ) + + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + mock_client.assert_called_once_with(host=expected_url, headers=expected_headers) + mock_async_client.assert_called_once_with( + host=expected_url, headers=expected_headers + ) + + +class TestOllamaEmbeddingsUrlAuth: + """Test URL authentication integration with OllamaEmbeddings.""" + + @patch("langchain_ollama.embeddings.Client") + @patch("langchain_ollama.embeddings.AsyncClient") + def test_ollama_embeddings_url_auth_integration( + self, mock_async_client: MagicMock, mock_client: MagicMock + ) -> None: + """Test that OllamaEmbeddings properly handles URL authentication.""" + url_with_auth = "https://user:password@ollama.example.com:11434" + + OllamaEmbeddings( + 
model=MODEL_NAME, + base_url=url_with_auth, + ) + + expected_url = "https://ollama.example.com:11434" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + mock_client.assert_called_once_with(host=expected_url, headers=expected_headers) + mock_async_client.assert_called_once_with( + host=expected_url, headers=expected_headers + ) + + +class TestUrlAuthEdgeCases: + """Test edge cases and error conditions for URL authentication.""" + + def test_parse_url_with_auth_malformed_url(self) -> None: + """Test behavior with malformed URLs.""" + malformed_url = "not-a-valid-url" + result = parse_url_with_auth(malformed_url) + # Shouldn't return a URL as it wouldn't parse correctly or reach a server + assert result == (None, None) + + def test_parse_url_with_auth_no_port(self) -> None: + """Test URLs without explicit port numbers.""" + url = "https://user:password@ollama.example.com" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com" + expected_credentials = base64.b64encode(b"user:password").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers + + def test_parse_url_with_auth_complex_password(self) -> None: + """Test with complex passwords containing special characters.""" + # Test password with colon, which is the delimiter + url = "https://user:pass:word@ollama.example.com:11434" + cleaned_url, headers = parse_url_with_auth(url) + + expected_url = "https://ollama.example.com:11434" + # The parser should handle the first colon as the separator + expected_credentials = base64.b64encode(b"user:pass:word").decode() + expected_headers = {"Authorization": f"Basic {expected_credentials}"} + + assert cleaned_url == expected_url + assert headers == expected_headers diff --git a/libs/partners/ollama/uv.lock b/libs/partners/ollama/uv.lock index 
027f4797b42..148d9b2b585 100644 --- a/libs/partners/ollama/uv.lock +++ b/libs/partners/ollama/uv.lock @@ -323,7 +323,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "0.3.76" +version = "0.3.77" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" }, @@ -435,7 +435,7 @@ typing = [ [[package]] name = "langchain-tests" -version = "0.3.21" +version = "0.3.22" source = { editable = "../../standard-tests" } dependencies = [ { name = "httpx" },