mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
community[patch]: OllamaEmbeddings - Pass headers to post request (#16880)
## Feature - Set additional headers in constructor - Headers will be sent in post request This feature is useful if deploying Ollama on a cloud service such as hugging face, which requires authentication tokens to be passed in the request header. ## Tests - Test if header is passed - Test if header is not passed Similar to https://github.com/langchain-ai/langchain/pull/15881 --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
e0f137dbe0
commit
f51e6a35ba
@ -105,6 +105,12 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
|||||||
show_progress: bool = False
|
show_progress: bool = False
|
||||||
"""Whether to show a tqdm progress bar. Must have `tqdm` installed."""
|
"""Whether to show a tqdm progress bar. Must have `tqdm` installed."""
|
||||||
|
|
||||||
|
headers: Optional[dict] = None
|
||||||
|
"""Additional headers to pass to endpoint (e.g. Authorization, Referer).
|
||||||
|
This is useful when Ollama is hosted on cloud services that require
|
||||||
|
tokens for authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_params(self) -> Dict[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
"""Get the default parameters for calling Ollama."""
|
"""Get the default parameters for calling Ollama."""
|
||||||
@ -151,6 +157,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
**(self.headers or {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
61
libs/community/tests/unit_tests/embeddings/test_ollama.py
Normal file
61
libs/community/tests/unit_tests/embeddings/test_ollama.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import requests
|
||||||
|
from pytest import MonkeyPatch
|
||||||
|
|
||||||
|
from langchain_community.embeddings.ollama import OllamaEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
class MockResponse:
|
||||||
|
status_code = 200
|
||||||
|
|
||||||
|
def json(self) -> dict:
|
||||||
|
return {"embedding": [1, 2, 3]}
|
||||||
|
|
||||||
|
|
||||||
|
def mock_response() -> MockResponse:
|
||||||
|
return MockResponse()
|
||||||
|
|
||||||
|
|
||||||
|
def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
|
||||||
|
embedder = OllamaEmbeddings(
|
||||||
|
base_url="https://ollama-hostname:8000",
|
||||||
|
model="foo",
|
||||||
|
headers={
|
||||||
|
"Authorization": "Bearer TEST-TOKEN-VALUE",
|
||||||
|
"Referer": "https://application-host",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_post(url: str, headers: dict, json: str) -> MockResponse:
|
||||||
|
assert url == "https://ollama-hostname:8000/api/embeddings"
|
||||||
|
assert headers == {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": "Bearer TEST-TOKEN-VALUE",
|
||||||
|
"Referer": "https://application-host",
|
||||||
|
}
|
||||||
|
assert json is not None
|
||||||
|
|
||||||
|
return mock_response()
|
||||||
|
|
||||||
|
monkeypatch.setattr(requests, "post", mock_post)
|
||||||
|
|
||||||
|
embedder.embed_query("Test prompt")
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
|
||||||
|
embedder = OllamaEmbeddings(
|
||||||
|
base_url="https://ollama-hostname:8000",
|
||||||
|
model="foo",
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_post(url: str, headers: dict, json: str) -> MockResponse:
|
||||||
|
assert url == "https://ollama-hostname:8000/api/embeddings"
|
||||||
|
assert headers == {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
assert json is not None
|
||||||
|
|
||||||
|
return mock_response()
|
||||||
|
|
||||||
|
monkeypatch.setattr(requests, "post", mock_post)
|
||||||
|
|
||||||
|
embedder.embed_query("Test prompt")
|
@ -25,7 +25,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
|
|||||||
base_url="https://ollama-hostname:8000",
|
base_url="https://ollama-hostname:8000",
|
||||||
model="foo",
|
model="foo",
|
||||||
headers={
|
headers={
|
||||||
"Authentication": "Bearer TEST-TOKEN-VALUE",
|
"Authorization": "Bearer TEST-TOKEN-VALUE",
|
||||||
"Referer": "https://application-host",
|
"Referer": "https://application-host",
|
||||||
},
|
},
|
||||||
timeout=300,
|
timeout=300,
|
||||||
@ -35,7 +35,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
|
|||||||
assert url == "https://ollama-hostname:8000/api/generate"
|
assert url == "https://ollama-hostname:8000/api/generate"
|
||||||
assert headers == {
|
assert headers == {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authentication": "Bearer TEST-TOKEN-VALUE",
|
"Authorization": "Bearer TEST-TOKEN-VALUE",
|
||||||
"Referer": "https://application-host",
|
"Referer": "https://application-host",
|
||||||
}
|
}
|
||||||
assert json is not None
|
assert json is not None
|
||||||
|
Loading…
Reference in New Issue
Block a user