mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 06:13:36 +00:00
feat(ollama): add basic auth support
This commit is contained in:
parent
a9e52ca605
commit
19a0761b99
@ -1,5 +1,11 @@
|
|||||||
"""Utility functions for validating Ollama models."""
|
"""Utility functions for validating Ollama models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import unquote, urlparse
|
||||||
|
|
||||||
from httpx import ConnectError
|
from httpx import ConnectError
|
||||||
from ollama import Client, ResponseError
|
from ollama import Client, ResponseError
|
||||||
|
|
||||||
@ -37,3 +43,51 @@ def validate_model(client: Client, model_name: str) -> None:
|
|||||||
"Please check your Ollama server logs."
|
"Please check your Ollama server logs."
|
||||||
)
|
)
|
||||||
raise ValueError(msg) from e
|
raise ValueError(msg) from e
|
||||||
|
|
||||||
|
|
||||||
|
def parse_url_with_auth(url: Optional[str]) -> tuple[Optional[str], Optional[dict]]:
|
||||||
|
"""Parse URL and extract authentication credentials for headers.
|
||||||
|
|
||||||
|
Handles URLs of the form: ``https://user:password@host:port/path``
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to parse. Can be None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of ``(cleaned_url, headers_dict)`` where:
|
||||||
|
- ``cleaned_url`` is the URL without authentication credentials
|
||||||
|
- ``headers_dict`` contains Authorization header if credentials were found
|
||||||
|
"""
|
||||||
|
if not url:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
# If no authentication info, return as-is
|
||||||
|
if not parsed.username:
|
||||||
|
return url, None
|
||||||
|
|
||||||
|
# Handle case where password might be empty string or None
|
||||||
|
password = parsed.password or ""
|
||||||
|
|
||||||
|
# Extract credentials and create basic auth header (decode percent-encoding)
|
||||||
|
username = unquote(parsed.username)
|
||||||
|
password = unquote(password)
|
||||||
|
credentials = f"{username}:{password}"
|
||||||
|
encoded_credentials = base64.b64encode(credentials.encode()).decode()
|
||||||
|
headers = {"Authorization": f"Basic {encoded_credentials}"}
|
||||||
|
|
||||||
|
# Reconstruct URL without authentication
|
||||||
|
cleaned_netloc = parsed.hostname or ""
|
||||||
|
if parsed.port:
|
||||||
|
cleaned_netloc += f":{parsed.port}"
|
||||||
|
|
||||||
|
cleaned_url = f"{parsed.scheme}://{cleaned_netloc}"
|
||||||
|
if parsed.path:
|
||||||
|
cleaned_url += parsed.path
|
||||||
|
if parsed.query:
|
||||||
|
cleaned_url += f"?{parsed.query}"
|
||||||
|
if parsed.fragment:
|
||||||
|
cleaned_url += f"#{parsed.fragment}"
|
||||||
|
|
||||||
|
return cleaned_url, headers
|
||||||
|
@ -7,19 +7,10 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import (
|
from typing import Any, Callable, Literal, Optional, Union, cast
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
CallbackManagerForLLMRun,
|
|
||||||
)
|
|
||||||
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
|
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.language_models import LanguageModelInput
|
from langchain_core.language_models import LanguageModelInput
|
||||||
@ -57,7 +48,7 @@ from pydantic.json_schema import JsonSchemaValue
|
|||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from typing_extensions import Self, is_typeddict
|
from typing_extensions import Self, is_typeddict
|
||||||
|
|
||||||
from ._utils import validate_model
|
from ._utils import parse_url_with_auth, validate_model
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -607,6 +598,15 @@ class ChatOllama(BaseChatModel):
|
|||||||
"""Set clients to use for ollama."""
|
"""Set clients to use for ollama."""
|
||||||
client_kwargs = self.client_kwargs or {}
|
client_kwargs = self.client_kwargs or {}
|
||||||
|
|
||||||
|
# Parse URL for basic auth credentials
|
||||||
|
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
|
||||||
|
|
||||||
|
# Merge authentication headers with existing headers
|
||||||
|
if auth_headers:
|
||||||
|
headers = client_kwargs.get("headers", {})
|
||||||
|
headers.update(auth_headers)
|
||||||
|
client_kwargs = {**client_kwargs, "headers": headers}
|
||||||
|
|
||||||
sync_client_kwargs = client_kwargs
|
sync_client_kwargs = client_kwargs
|
||||||
if self.sync_client_kwargs:
|
if self.sync_client_kwargs:
|
||||||
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
||||||
@ -615,8 +615,8 @@ class ChatOllama(BaseChatModel):
|
|||||||
if self.async_client_kwargs:
|
if self.async_client_kwargs:
|
||||||
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
||||||
|
|
||||||
self._client = Client(host=self.base_url, **sync_client_kwargs)
|
self._client = Client(host=cleaned_url, **sync_client_kwargs)
|
||||||
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
|
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
|
||||||
if self.validate_model_on_init:
|
if self.validate_model_on_init:
|
||||||
validate_model(self._client, self.model)
|
validate_model(self._client, self.model)
|
||||||
return self
|
return self
|
||||||
|
@ -6,15 +6,10 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from ollama import AsyncClient, Client
|
from ollama import AsyncClient, Client
|
||||||
from pydantic import (
|
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
|
||||||
BaseModel,
|
|
||||||
ConfigDict,
|
|
||||||
PrivateAttr,
|
|
||||||
model_validator,
|
|
||||||
)
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from ._utils import validate_model
|
from ._utils import parse_url_with_auth, validate_model
|
||||||
|
|
||||||
|
|
||||||
class OllamaEmbeddings(BaseModel, Embeddings):
|
class OllamaEmbeddings(BaseModel, Embeddings):
|
||||||
@ -260,6 +255,15 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Set clients to use for ollama."""
|
"""Set clients to use for ollama."""
|
||||||
client_kwargs = self.client_kwargs or {}
|
client_kwargs = self.client_kwargs or {}
|
||||||
|
|
||||||
|
# Parse URL for basic auth credentials
|
||||||
|
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
|
||||||
|
|
||||||
|
# Merge authentication headers with existing headers
|
||||||
|
if auth_headers:
|
||||||
|
headers = client_kwargs.get("headers", {})
|
||||||
|
headers.update(auth_headers)
|
||||||
|
client_kwargs = {**client_kwargs, "headers": headers}
|
||||||
|
|
||||||
sync_client_kwargs = client_kwargs
|
sync_client_kwargs = client_kwargs
|
||||||
if self.sync_client_kwargs:
|
if self.sync_client_kwargs:
|
||||||
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
||||||
@ -268,8 +272,8 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
|||||||
if self.async_client_kwargs:
|
if self.async_client_kwargs:
|
||||||
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
||||||
|
|
||||||
self._client = Client(host=self.base_url, **sync_client_kwargs)
|
self._client = Client(host=cleaned_url, **sync_client_kwargs)
|
||||||
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
|
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
|
||||||
if self.validate_model_on_init:
|
if self.validate_model_on_init:
|
||||||
validate_model(self._client, self.model)
|
validate_model(self._client, self.model)
|
||||||
return self
|
return self
|
||||||
|
@ -3,12 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||||
from typing import (
|
from typing import Any, Literal, Optional, Union
|
||||||
Any,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@ -20,7 +15,7 @@ from ollama import AsyncClient, Client, Options
|
|||||||
from pydantic import PrivateAttr, model_validator
|
from pydantic import PrivateAttr, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from ._utils import validate_model
|
from ._utils import parse_url_with_auth, validate_model
|
||||||
|
|
||||||
|
|
||||||
class OllamaLLM(BaseLLM):
|
class OllamaLLM(BaseLLM):
|
||||||
@ -230,6 +225,15 @@ class OllamaLLM(BaseLLM):
|
|||||||
"""Set clients to use for ollama."""
|
"""Set clients to use for ollama."""
|
||||||
client_kwargs = self.client_kwargs or {}
|
client_kwargs = self.client_kwargs or {}
|
||||||
|
|
||||||
|
# Parse URL for basic auth credentials
|
||||||
|
cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
|
||||||
|
|
||||||
|
# Merge authentication headers with existing headers
|
||||||
|
if auth_headers:
|
||||||
|
headers = client_kwargs.get("headers", {})
|
||||||
|
headers.update(auth_headers)
|
||||||
|
client_kwargs = {**client_kwargs, "headers": headers}
|
||||||
|
|
||||||
sync_client_kwargs = client_kwargs
|
sync_client_kwargs = client_kwargs
|
||||||
if self.sync_client_kwargs:
|
if self.sync_client_kwargs:
|
||||||
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
sync_client_kwargs = {**sync_client_kwargs, **self.sync_client_kwargs}
|
||||||
@ -238,8 +242,8 @@ class OllamaLLM(BaseLLM):
|
|||||||
if self.async_client_kwargs:
|
if self.async_client_kwargs:
|
||||||
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
async_client_kwargs = {**async_client_kwargs, **self.async_client_kwargs}
|
||||||
|
|
||||||
self._client = Client(host=self.base_url, **sync_client_kwargs)
|
self._client = Client(host=cleaned_url, **sync_client_kwargs)
|
||||||
self._async_client = AsyncClient(host=self.base_url, **async_client_kwargs)
|
self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
|
||||||
if self.validate_model_on_init:
|
if self.validate_model_on_init:
|
||||||
validate_model(self._client, self.model)
|
validate_model(self._client, self.model)
|
||||||
return self
|
return self
|
||||||
|
248
libs/partners/ollama/tests/unit_tests/test_url_auth.py
Normal file
248
libs/partners/ollama/tests/unit_tests/test_url_auth.py
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
"""Test URL authentication parsing functionality."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from langchain_ollama._utils import parse_url_with_auth
|
||||||
|
from langchain_ollama.chat_models import ChatOllama
|
||||||
|
from langchain_ollama.embeddings import OllamaEmbeddings
|
||||||
|
from langchain_ollama.llms import OllamaLLM
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseUrlWithAuth:
|
||||||
|
"""Test the parse_url_with_auth utility function."""
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_none_input(self) -> None:
|
||||||
|
"""Test that None input returns None, None."""
|
||||||
|
result = parse_url_with_auth(None)
|
||||||
|
assert result == (None, None)
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_no_credentials(self) -> None:
|
||||||
|
"""Test URLs without authentication credentials."""
|
||||||
|
url = "https://ollama.example.com:11434"
|
||||||
|
result = parse_url_with_auth(url)
|
||||||
|
assert result == (url, None)
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_with_credentials(self) -> None:
|
||||||
|
"""Test URLs with authentication credentials."""
|
||||||
|
url = "https://user:password@ollama.example.com:11434"
|
||||||
|
cleaned_url, headers = parse_url_with_auth(url)
|
||||||
|
|
||||||
|
expected_url = "https://ollama.example.com:11434"
|
||||||
|
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
assert cleaned_url == expected_url
|
||||||
|
assert headers == expected_headers
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_with_path_and_query(self) -> None:
|
||||||
|
"""Test URLs with path and query parameters."""
|
||||||
|
url = "https://user:pass@ollama.example.com:11434/api/v1?timeout=30"
|
||||||
|
cleaned_url, headers = parse_url_with_auth(url)
|
||||||
|
|
||||||
|
expected_url = "https://ollama.example.com:11434/api/v1?timeout=30"
|
||||||
|
expected_credentials = base64.b64encode(b"user:pass").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
assert cleaned_url == expected_url
|
||||||
|
assert headers == expected_headers
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_special_characters(self) -> None:
|
||||||
|
"""Test URLs with special characters in credentials."""
|
||||||
|
url = "https://user%40domain:p%40ssw0rd@ollama.example.com:11434"
|
||||||
|
cleaned_url, headers = parse_url_with_auth(url)
|
||||||
|
|
||||||
|
expected_url = "https://ollama.example.com:11434"
|
||||||
|
# Note: URL parsing handles percent-encoding automatically
|
||||||
|
expected_credentials = base64.b64encode(b"user@domain:p@ssw0rd").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
assert cleaned_url == expected_url
|
||||||
|
assert headers == expected_headers
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_only_username(self) -> None:
|
||||||
|
"""Test URLs with only username (no password)."""
|
||||||
|
url = "https://user@ollama.example.com:11434"
|
||||||
|
cleaned_url, headers = parse_url_with_auth(url)
|
||||||
|
|
||||||
|
expected_url = "https://ollama.example.com:11434"
|
||||||
|
expected_credentials = base64.b64encode(b"user:").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
assert cleaned_url == expected_url
|
||||||
|
assert headers == expected_headers
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_empty_password(self) -> None:
|
||||||
|
"""Test URLs with empty password."""
|
||||||
|
url = "https://user:@ollama.example.com:11434"
|
||||||
|
cleaned_url, headers = parse_url_with_auth(url)
|
||||||
|
|
||||||
|
expected_url = "https://ollama.example.com:11434"
|
||||||
|
expected_credentials = base64.b64encode(b"user:").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
assert cleaned_url == expected_url
|
||||||
|
assert headers == expected_headers
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatOllamaUrlAuth:
|
||||||
|
"""Test URL authentication integration with ChatOllama."""
|
||||||
|
|
||||||
|
@patch("langchain_ollama.chat_models.Client")
|
||||||
|
@patch("langchain_ollama.chat_models.AsyncClient")
|
||||||
|
def test_chat_ollama_url_auth_integration(
|
||||||
|
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that ChatOllama properly handles URL authentication."""
|
||||||
|
url_with_auth = "https://user:password@ollama.example.com:11434"
|
||||||
|
|
||||||
|
ChatOllama(
|
||||||
|
model="llama3",
|
||||||
|
base_url=url_with_auth,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the clients were called with cleaned URL and auth headers
|
||||||
|
expected_url = "https://ollama.example.com:11434"
|
||||||
|
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
|
||||||
|
mock_async_client.assert_called_once_with(
|
||||||
|
host=expected_url, headers=expected_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("langchain_ollama.chat_models.Client")
|
||||||
|
@patch("langchain_ollama.chat_models.AsyncClient")
|
||||||
|
def test_chat_ollama_url_auth_with_existing_headers(
|
||||||
|
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that URL auth headers merge with existing headers."""
|
||||||
|
url_with_auth = "https://user:password@ollama.example.com:11434"
|
||||||
|
existing_headers = {"User-Agent": "test-agent", "X-Custom": "value"}
|
||||||
|
|
||||||
|
ChatOllama(
|
||||||
|
model="llama3",
|
||||||
|
base_url=url_with_auth,
|
||||||
|
client_kwargs={"headers": existing_headers},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify headers are merged
|
||||||
|
expected_url = "https://ollama.example.com:11434"
|
||||||
|
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||||
|
expected_headers = {
|
||||||
|
**existing_headers,
|
||||||
|
"Authorization": f"Basic {expected_credentials}",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
|
||||||
|
mock_async_client.assert_called_once_with(
|
||||||
|
host=expected_url, headers=expected_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("langchain_ollama.chat_models.Client")
|
||||||
|
@patch("langchain_ollama.chat_models.AsyncClient")
|
||||||
|
def test_chat_ollama_no_url_auth(
|
||||||
|
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that ChatOllama works normally without URL authentication."""
|
||||||
|
url_without_auth = "https://ollama.example.com:11434"
|
||||||
|
|
||||||
|
ChatOllama(
|
||||||
|
model="llama3",
|
||||||
|
base_url=url_without_auth,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify no auth headers are added
|
||||||
|
mock_client.assert_called_once_with(host=url_without_auth)
|
||||||
|
mock_async_client.assert_called_once_with(host=url_without_auth)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOllamaLLMUrlAuth:
|
||||||
|
"""Test URL authentication integration with OllamaLLM."""
|
||||||
|
|
||||||
|
@patch("langchain_ollama.llms.Client")
|
||||||
|
@patch("langchain_ollama.llms.AsyncClient")
|
||||||
|
def test_ollama_llm_url_auth_integration(
|
||||||
|
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that OllamaLLM properly handles URL authentication."""
|
||||||
|
url_with_auth = "https://user:password@ollama.example.com:11434"
|
||||||
|
|
||||||
|
OllamaLLM(
|
||||||
|
model="llama3",
|
||||||
|
base_url=url_with_auth,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the clients were called with cleaned URL and auth headers
|
||||||
|
expected_url = "https://ollama.example.com:11434"
|
||||||
|
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
|
||||||
|
mock_async_client.assert_called_once_with(
|
||||||
|
host=expected_url, headers=expected_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOllamaEmbeddingsUrlAuth:
|
||||||
|
"""Test URL authentication integration with OllamaEmbeddings."""
|
||||||
|
|
||||||
|
@patch("langchain_ollama.embeddings.Client")
|
||||||
|
@patch("langchain_ollama.embeddings.AsyncClient")
|
||||||
|
def test_ollama_embeddings_url_auth_integration(
|
||||||
|
self, mock_async_client: MagicMock, mock_client: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Test that OllamaEmbeddings properly handles URL authentication."""
|
||||||
|
url_with_auth = "https://user:password@ollama.example.com:11434"
|
||||||
|
|
||||||
|
OllamaEmbeddings(
|
||||||
|
model="llama3",
|
||||||
|
base_url=url_with_auth,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the clients were called with cleaned URL and auth headers
|
||||||
|
expected_url = "https://ollama.example.com:11434"
|
||||||
|
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
mock_client.assert_called_once_with(host=expected_url, headers=expected_headers)
|
||||||
|
mock_async_client.assert_called_once_with(
|
||||||
|
host=expected_url, headers=expected_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUrlAuthEdgeCases:
|
||||||
|
"""Test edge cases and error conditions for URL authentication."""
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_malformed_url(self) -> None:
|
||||||
|
"""Test behavior with malformed URLs."""
|
||||||
|
malformed_url = "not-a-valid-url"
|
||||||
|
result = parse_url_with_auth(malformed_url)
|
||||||
|
# Should return the original URL without modification
|
||||||
|
assert result == (malformed_url, None)
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_no_port(self) -> None:
|
||||||
|
"""Test URLs without explicit port numbers."""
|
||||||
|
url = "https://user:password@ollama.example.com"
|
||||||
|
cleaned_url, headers = parse_url_with_auth(url)
|
||||||
|
|
||||||
|
expected_url = "https://ollama.example.com"
|
||||||
|
expected_credentials = base64.b64encode(b"user:password").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
assert cleaned_url == expected_url
|
||||||
|
assert headers == expected_headers
|
||||||
|
|
||||||
|
def test_parse_url_with_auth_complex_password(self) -> None:
|
||||||
|
"""Test with complex passwords containing special characters."""
|
||||||
|
# Test password with colon, which is the delimiter
|
||||||
|
url = "https://user:pass:word@ollama.example.com:11434"
|
||||||
|
cleaned_url, headers = parse_url_with_auth(url)
|
||||||
|
|
||||||
|
expected_url = "https://ollama.example.com:11434"
|
||||||
|
# The parser should handle the first colon as the separator
|
||||||
|
expected_credentials = base64.b64encode(b"user:pass:word").decode()
|
||||||
|
expected_headers = {"Authorization": f"Basic {expected_credentials}"}
|
||||||
|
|
||||||
|
assert cleaned_url == expected_url
|
||||||
|
assert headers == expected_headers
|
@ -305,7 +305,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.3.70"
|
version = "0.3.72"
|
||||||
source = { editable = "../../core" }
|
source = { editable = "../../core" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "jsonpatch" },
|
{ name = "jsonpatch" },
|
||||||
|
Loading…
Reference in New Issue
Block a user