mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 05:45:01 +00:00
feat(ollama): add basic auth support
This commit is contained in:
parent
a9e52ca605
commit
19a0761b99
@ -1,5 +1,11 @@
|
||||
"""Utility functions for validating Ollama models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Optional
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from httpx import ConnectError
|
||||
from ollama import Client, ResponseError
|
||||
|
||||
@ -37,3 +43,51 @@ def validate_model(client: Client, model_name: str) -> None:
|
||||
"Please check your Ollama server logs."
|
||||
)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
|
||||
def parse_url_with_auth(url: Optional[str]) -> tuple[Optional[str], Optional[dict]]:
|
||||
"""Parse URL and extract authentication credentials for headers.
|
||||
|
||||
Handles URLs of the form: ``https://user:password@host:port/path``
|
||||
|
||||
Args:
|
||||
url: The URL to parse. Can be None.
|
||||
|
||||
Returns:
|
||||
A tuple of ``(cleaned_url, headers_dict)`` where:
|
||||
- ``cleaned_url`` is the URL without authentication credentials
|
||||
- ``headers_dict`` contains Authorization header if credentials were found
|
||||
"""
|
||||
if not url:
|
||||
return None, None
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
# If no authentication info, return as-is
|
||||
if not parsed.username:
|
||||
return url, None
|
||||
|
||||
# Handle case where password might be empty string or None
|
||||
password = parsed.password or ""
|
||||
|
||||
# Extract credentials and create basic auth header (decode percent-encoding)
|
||||
username = unquote(parsed.username)
|
||||
password = unquote(password)
|
||||
credentials = f"{username}:{password}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode()
|
||||
headers = {"Authorization": f"Basic {encoded_credentials}"}
|
||||
|
||||
# Reconstruct URL without authentication
|
||||
cleaned_netloc = parsed.hostname or ""
|
||||
if parsed.port:
|
||||
cleaned_netloc += f":{parsed.port}"
|
||||
|
||||
cleaned_url = f"{parsed.scheme}://{cleaned_netloc}"
|
||||
if parsed.path:
|
||||
cleaned_url += parsed.path
|
||||
if parsed.query:
|
||||
cleaned_url += f"?{parsed.query}"
|
||||
if parsed.fragment:
|
||||
cleaned_url += f"#{parsed.fragment}"
|
||||
|
||||
return cleaned_url, headers
|
||||
|
@ -7,19 +7,10 @@ import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import Any, Callable, Literal, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
@ -57,7 +48,7 @@ from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from typing_extensions import Self, is_typeddict
|
||||
|
||||
from ._utils import validate_model
|
||||
from ._utils import parse_url_with_auth, validate_model
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -607,6 +598,15 @@ class ChatOllama(BaseChatModel):
|
||||
"""Set clients to use for ollama."""
|
||||
client_kwargs = self.client_kwargs or {}
|
||||
|
||||
# Parse URL for basic auth credentials
|
||||
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
|
||||
|
||||
# Merge authentication headers with existing headers
|
||||
if auth_headers:
|
||||
headers = client_kwargs.get("headers", {})
|
||||
headers.update(auth_headers)
|
||||
client_kwargs = {**client_kwargs, "headers": headers}
|
||||
|
||||
sync_client_kwargs = client_kwargs
|
||||
if self.sync_client_kwargs:
|
||||
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
||||
@ -615,8 +615,8 @@ class ChatOllama(BaseChatModel):
|
||||
if self.async_client_kwargs:
|
||||
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
||||
|
||||
self._client = Client(host=self.base_url, **sync_client_kwargs)
|
||||
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
|
||||
self._client = Client(host=cleaned_url, **sync_client_kwargs)
|
||||
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
|
||||
if self.validate_model_on_init:
|
||||
validate_model(self._client, self.model)
|
||||
return self
|
||||
|
@ -6,15 +6,10 @@ from typing import Any, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from ollama import AsyncClient, Client
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._utils import validate_model
|
||||
from ._utils import parse_url_with_auth, validate_model
|
||||
|
||||
|
||||
class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
@ -260,6 +255,15 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
"""Set clients to use for ollama."""
|
||||
client_kwargs = self.client_kwargs or {}
|
||||
|
||||
# Parse URL for basic auth credentials
|
||||
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
|
||||
|
||||
# Merge authentication headers with existing headers
|
||||
if auth_headers:
|
||||
headers = client_kwargs.get("headers", {})
|
||||
headers.update(auth_headers)
|
||||
client_kwargs = {**client_kwargs, "headers": headers}
|
||||
|
||||
sync_client_kwargs = client_kwargs
|
||||
if self.sync_client_kwargs:
|
||||
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
||||
@ -268,8 +272,8 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
if self.async_client_kwargs:
|
||||
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
||||
|
||||
self._client = Client(host=self.base_url, **sync_client_kwargs)
|
||||
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
|
||||
self._client = Client(host=cleaned_url, **sync_client_kwargs)
|
||||
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
|
||||
if self.validate_model_on_init:
|
||||
validate_model(self._client, self.model)
|
||||
return self
|
||||
|
@ -3,12 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -20,7 +15,7 @@ from ollama import AsyncClient, Client, Options
|
||||
from pydantic import PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._utils import validate_model
|
||||
from ._utils import parse_url_with_auth, validate_model
|
||||
|
||||
|
||||
class OllamaLLM(BaseLLM):
|
||||
@ -230,6 +225,15 @@ class OllamaLLM(BaseLLM):
|
||||
"""Set clients to use for ollama."""
|
||||
client_kwargs = self.client_kwargs or {}
|
||||
|
||||
# Parse URL for basic auth credentials
|
||||
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
|
||||
|
||||
# Merge authentication headers with existing headers
|
||||
if auth_headers:
|
||||
headers = client_kwargs.get("headers", {})
|
||||
headers.update(auth_headers)
|
||||
client_kwargs = {**client_kwargs, "headers": headers}
|
||||
|
||||
sync_client_kwargs = client_kwargs
|
||||
if self.sync_client_kwargs:
|
||||
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
||||
@ -238,8 +242,8 @@ class OllamaLLM(BaseLLM):
|
||||
if self.async_client_kwargs:
|
||||
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
||||
|
||||
self._client = Client(host=self.base_url, **sync_client_kwargs)
|
||||
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
|
||||
self._client = Client(host=cleaned_url, **sync_client_kwargs)
|
||||
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
|
||||
if self.validate_model_on_init:
|
||||
validate_model(self._client, self.model)
|
||||
return self
|
||||
|
248
libs/partners/ollama/tests/unit_tests/test_url_auth.py
Normal file
248
libs/partners/ollama/tests/unit_tests/test_url_auth.py
Normal file
@ -0,0 +1,248 @@
|
||||
"""Test URL authentication parsing functionality."""
|
||||
|
||||
import base64
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_ollama._utils import parse_url_with_auth
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
from langchain_ollama.embeddings import OllamaEmbeddings
|
||||
from langchain_ollama.llms import OllamaLLM
|
||||
|
||||
|
||||
class TestParseUrlWithAuth:
|
||||
"""Test the parse_url_with_auth utility function."""
|
||||
|
||||
def test_parse_url_with_auth_none_input(self) -> None:
|
||||
"""Test that None input returns None, None."""
|
||||
result = parse_url_with_auth(None)
|
||||
assert result == (None, None)
|
||||
|
||||
def test_parse_url_with_auth_no_credentials(self) -> None:
|
||||
"""Test URLs without authentication credentials."""
|
||||
url = "https://ollama.example.com:11434"
|
||||
result = parse_url_with_auth(url)
|
||||
assert result == (url, None)
|
||||
|
||||
def test_parse_url_with_auth_with_credentials(self) -> None:
|
||||
"""Test URLs with authentication credentials."""
|
||||
url = "https://user:password@ollama.example.com:11434"
|
||||
cleaned_url, headers = parse_url_with_auth(url)
|
||||
|
||||
expected_url = "https://ollama.example.com:11434"
|
||||
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
assert cleaned_url == expected_url
|
||||
assert headers == expected_headers
|
||||
|
||||
def test_parse_url_with_auth_with_path_and_query(self) -> None:
|
||||
"""Test URLs with path and query parameters."""
|
||||
url = "https://user:pass@ollama.example.com:11434/api/v1?timeout=30"
|
||||
cleaned_url, headers = parse_url_with_auth(url)
|
||||
|
||||
expected_url = "https://ollama.example.com:11434/api/v1?timeout=30"
|
||||
expected_credentials = base64.b64encode(b"user:pass").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
assert cleaned_url == expected_url
|
||||
assert headers == expected_headers
|
||||
|
||||
def test_parse_url_with_auth_special_characters(self) -> None:
|
||||
"""Test URLs with special characters in credentials."""
|
||||
url = "https://user%40domain:p%40ssw0rd@ollama.example.com:11434"
|
||||
cleaned_url, headers = parse_url_with_auth(url)
|
||||
|
||||
expected_url = "https://ollama.example.com:11434"
|
||||
# Note: URL parsing handles percent-encoding automatically
|
||||
expected_credentials = base64.b64encode(b"user@domain:p@ssw0rd").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
assert cleaned_url == expected_url
|
||||
assert headers == expected_headers
|
||||
|
||||
def test_parse_url_with_auth_only_username(self) -> None:
|
||||
"""Test URLs with only username (no password)."""
|
||||
url = "https://user@ollama.example.com:11434"
|
||||
cleaned_url, headers = parse_url_with_auth(url)
|
||||
|
||||
expected_url = "https://ollama.example.com:11434"
|
||||
expected_credentials = base64.b64encode(b"user:").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
assert cleaned_url == expected_url
|
||||
assert headers == expected_headers
|
||||
|
||||
def test_parse_url_with_auth_empty_password(self) -> None:
|
||||
"""Test URLs with empty password."""
|
||||
url = "https://user:@ollama.example.com:11434"
|
||||
cleaned_url, headers = parse_url_with_auth(url)
|
||||
|
||||
expected_url = "https://ollama.example.com:11434"
|
||||
expected_credentials = base64.b64encode(b"user:").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
assert cleaned_url == expected_url
|
||||
assert headers == expected_headers
|
||||
|
||||
|
||||
class TestChatOllamaUrlAuth:
|
||||
"""Test URL authentication integration with ChatOllama."""
|
||||
|
||||
@patch("langchain_ollama.chat_models.Client")
|
||||
@patch("langchain_ollama.chat_models.AsyncClient")
|
||||
def test_chat_ollama_url_auth_integration(
|
||||
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that ChatOllama properly handles URL authentication."""
|
||||
url_with_auth = "https://user:password@ollama.example.com:11434"
|
||||
|
||||
ChatOllama(
|
||||
model="llama3",
|
||||
base_url=url_with_auth,
|
||||
)
|
||||
|
||||
# Verify the clients were called with cleaned URL and auth headers
|
||||
expected_url = "https://ollama.example.com:11434"
|
||||
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
|
||||
mock_async_client.assert_called_once_with(
|
||||
host=expected_url, headers=expected_headers
|
||||
)
|
||||
|
||||
@patch("langchain_ollama.chat_models.Client")
|
||||
@patch("langchain_ollama.chat_models.AsyncClient")
|
||||
def test_chat_ollama_url_auth_with_existing_headers(
|
||||
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that URL auth headers merge with existing headers."""
|
||||
url_with_auth = "https://user:password@ollama.example.com:11434"
|
||||
existing_headers = {"User-Agent": "test-agent", "X-Custom": "value"}
|
||||
|
||||
ChatOllama(
|
||||
model="llama3",
|
||||
base_url=url_with_auth,
|
||||
client_kwargs={"headers": existing_headers},
|
||||
)
|
||||
|
||||
# Verify headers are merged
|
||||
expected_url = "https://ollama.example.com:11434"
|
||||
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||
expected_headers = {
|
||||
**existing_headers,
|
||||
"Authorization": f"Basic {expected_credentials}",
|
||||
}
|
||||
|
||||
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
|
||||
mock_async_client.assert_called_once_with(
|
||||
host=expected_url, headers=expected_headers
|
||||
)
|
||||
|
||||
@patch("langchain_ollama.chat_models.Client")
|
||||
@patch("langchain_ollama.chat_models.AsyncClient")
|
||||
def test_chat_ollama_no_url_auth(
|
||||
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that ChatOllama works normally without URL authentication."""
|
||||
url_without_auth = "https://ollama.example.com:11434"
|
||||
|
||||
ChatOllama(
|
||||
model="llama3",
|
||||
base_url=url_without_auth,
|
||||
)
|
||||
|
||||
# Verify no auth headers are added
|
||||
mock_client.assert_called_once_with(host=url_without_auth)
|
||||
mock_async_client.assert_called_once_with(host=url_without_auth)
|
||||
|
||||
|
||||
class TestOllamaLLMUrlAuth:
|
||||
"""Test URL authentication integration with OllamaLLM."""
|
||||
|
||||
@patch("langchain_ollama.llms.Client")
|
||||
@patch("langchain_ollama.llms.AsyncClient")
|
||||
def test_ollama_llm_url_auth_integration(
|
||||
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that OllamaLLM properly handles URL authentication."""
|
||||
url_with_auth = "https://user:password@ollama.example.com:11434"
|
||||
|
||||
OllamaLLM(
|
||||
model="llama3",
|
||||
base_url=url_with_auth,
|
||||
)
|
||||
|
||||
# Verify the clients were called with cleaned URL and auth headers
|
||||
expected_url = "https://ollama.example.com:11434"
|
||||
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
|
||||
mock_async_client.assert_called_once_with(
|
||||
host=expected_url, headers=expected_headers
|
||||
)
|
||||
|
||||
|
||||
class TestOllamaEmbeddingsUrlAuth:
|
||||
"""Test URL authentication integration with OllamaEmbeddings."""
|
||||
|
||||
@patch("langchain_ollama.embeddings.Client")
|
||||
@patch("langchain_ollama.embeddings.AsyncClient")
|
||||
def test_ollama_embeddings_url_auth_integration(
|
||||
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that OllamaEmbeddings properly handles URL authentication."""
|
||||
url_with_auth = "https://user:password@ollama.example.com:11434"
|
||||
|
||||
OllamaEmbeddings(
|
||||
model="llama3",
|
||||
base_url=url_with_auth,
|
||||
)
|
||||
|
||||
# Verify the clients were called with cleaned URL and auth headers
|
||||
expected_url = "https://ollama.example.com:11434"
|
||||
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
|
||||
mock_async_client.assert_called_once_with(
|
||||
host=expected_url, headers=expected_headers
|
||||
)
|
||||
|
||||
|
||||
class TestUrlAuthEdgeCases:
|
||||
"""Test edge cases and error conditions for URL authentication."""
|
||||
|
||||
def test_parse_url_with_auth_malformed_url(self) -> None:
|
||||
"""Test behavior with malformed URLs."""
|
||||
malformed_url = "not-a-valid-url"
|
||||
result = parse_url_with_auth(malformed_url)
|
||||
# Should return the original URL without modification
|
||||
assert result == (malformed_url, None)
|
||||
|
||||
def test_parse_url_with_auth_no_port(self) -> None:
|
||||
"""Test URLs without explicit port numbers."""
|
||||
url = "https://user:password@ollama.example.com"
|
||||
cleaned_url, headers = parse_url_with_auth(url)
|
||||
|
||||
expected_url = "https://ollama.example.com"
|
||||
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
assert cleaned_url == expected_url
|
||||
assert headers == expected_headers
|
||||
|
||||
def test_parse_url_with_auth_complex_password(self) -> None:
|
||||
"""Test with complex passwords containing special characters."""
|
||||
# Test password with colon, which is the delimiter
|
||||
url = "https://user:pass:word@ollama.example.com:11434"
|
||||
cleaned_url, headers = parse_url_with_auth(url)
|
||||
|
||||
expected_url = "https://ollama.example.com:11434"
|
||||
# The parser should handle the first colon as the separator
|
||||
expected_credentials = base64.b64encode(b"user:pass:word").decode()
|
||||
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||
|
||||
assert cleaned_url == expected_url
|
||||
assert headers == expected_headers
|
@ -305,7 +305,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.3.70"
|
||||
version = "0.3.72"
|
||||
source = { editable = "../../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
|
Loading…
Reference in New Issue
Block a user