From f51e6a35ba869caa714e25967ab77d9993fbc7b5 Mon Sep 17 00:00:00 2001 From: shahrin014 Date: Sat, 30 Mar 2024 03:44:52 +0900 Subject: [PATCH] community[patch]: OllamaEmbeddings - Pass headers to post request (#16880) ## Feature - Set additional headers in constructor - Headers will be sent in post request This feature is useful if deploying Ollama on a cloud service such as hugging face, which requires authentication tokens to be passed in the request header. ## Tests - Test if header is passed - Test if header is not passed Similar to https://github.com/langchain-ai/langchain/pull/15881 --------- Co-authored-by: Bagatur --- .../langchain_community/embeddings/ollama.py | 7 +++ .../unit_tests/embeddings/test_ollama.py | 61 +++++++++++++++++++ .../tests/unit_tests/llms/test_ollama.py | 4 +- 3 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 libs/community/tests/unit_tests/embeddings/test_ollama.py diff --git a/libs/community/langchain_community/embeddings/ollama.py b/libs/community/langchain_community/embeddings/ollama.py index f1c28f1124e..b1fd668b42a 100644 --- a/libs/community/langchain_community/embeddings/ollama.py +++ b/libs/community/langchain_community/embeddings/ollama.py @@ -105,6 +105,12 @@ class OllamaEmbeddings(BaseModel, Embeddings): show_progress: bool = False """Whether to show a tqdm progress bar. Must have `tqdm` installed.""" + headers: Optional[dict] = None + """Additional headers to pass to endpoint (e.g. Authorization, Referer). + This is useful when Ollama is hosted on cloud services that require + tokens for authentication. + """ + @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling Ollama.""" @@ -151,6 +157,7 @@ class OllamaEmbeddings(BaseModel, Embeddings): """ headers = { "Content-Type": "application/json", + **(self.headers or {}), } try: diff --git a/libs/community/tests/unit_tests/embeddings/test_ollama.py b/libs/community/tests/unit_tests/embeddings/test_ollama.py new file mode 100644 index 00000000000..aea354eca8c --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_ollama.py @@ -0,0 +1,61 @@ +import requests +from pytest import MonkeyPatch + +from langchain_community.embeddings.ollama import OllamaEmbeddings + + +class MockResponse: + status_code = 200 + + def json(self) -> dict: + return {"embedding": [1, 2, 3]} + + +def mock_response() -> MockResponse: + return MockResponse() + + +def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: + embedder = OllamaEmbeddings( + base_url="https://ollama-hostname:8000", + model="foo", + headers={ + "Authorization": "Bearer TEST-TOKEN-VALUE", + "Referer": "https://application-host", + }, + ) + + def mock_post(url: str, headers: dict, json: str) -> MockResponse: + assert url == "https://ollama-hostname:8000/api/embeddings" + assert headers == { + "Content-Type": "application/json", + "Authorization": "Bearer TEST-TOKEN-VALUE", + "Referer": "https://application-host", + } + assert json is not None + + return mock_response() + + monkeypatch.setattr(requests, "post", mock_post) + + embedder.embed_query("Test prompt") + + +def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: + embedder = OllamaEmbeddings( + base_url="https://ollama-hostname:8000", + model="foo", + ) + + def mock_post(url: str, headers: dict, json: str) -> MockResponse: + assert url == "https://ollama-hostname:8000/api/embeddings" + assert headers == { + "Content-Type": "application/json", + } + assert json is not None + + return mock_response() + + monkeypatch.setattr(requests, "post", mock_post) + + embedder.embed_query("Test prompt") diff --git a/libs/community/tests/unit_tests/llms/test_ollama.py b/libs/community/tests/unit_tests/llms/test_ollama.py index 8807aab826c..2e88defe6b1 100644 --- a/libs/community/tests/unit_tests/llms/test_ollama.py +++ b/libs/community/tests/unit_tests/llms/test_ollama.py @@ -25,7 +25,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: base_url="https://ollama-hostname:8000", model="foo", headers={ - "Authentication": "Bearer TEST-TOKEN-VALUE", + "Authorization": "Bearer TEST-TOKEN-VALUE", "Referer": "https://application-host", }, timeout=300, @@ -35,7 +35,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json", - "Authentication": "Bearer TEST-TOKEN-VALUE", + "Authorization": "Bearer TEST-TOKEN-VALUE", "Referer": "https://application-host", } assert json is not None