Add an optional `headers` field to `_OllamaCommon` and merge it into the
headers of the `requests.post` call (after the default "Content-Type", so
caller-supplied keys win on collision). Adds unit tests that monkeypatch
`requests.post` and check the headers with and without `headers` set.

diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py
index 022422417ed..4e7f1838f75 100644
--- a/libs/community/langchain_community/llms/ollama.py
+++ b/libs/community/langchain_community/llms/ollama.py
@@ -107,6 +107,12 @@ class _OllamaCommon(BaseLanguageModel):
     timeout: Optional[int] = None
     """Timeout for the request stream"""
 
+    headers: Optional[dict] = None
+    """Additional headers to pass to endpoint (e.g. Authorization, Referer).
+    This is useful when Ollama is hosted on cloud services that require
+    tokens for authentication.
+    """
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -207,7 +213,10 @@ class _OllamaCommon(BaseLanguageModel):
 
         response = requests.post(
             url=api_url,
-            headers={"Content-Type": "application/json"},
+            headers={
+                "Content-Type": "application/json",
+                **(self.headers if isinstance(self.headers, dict) else {}),
+            },
             json=request_payload,
             stream=True,
             timeout=self.timeout,
diff --git a/libs/community/tests/unit_tests/llms/test_ollama.py b/libs/community/tests/unit_tests/llms/test_ollama.py
new file mode 100644
index 00000000000..ff18daacfa6
--- /dev/null
+++ b/libs/community/tests/unit_tests/llms/test_ollama.py
@@ -0,0 +1,68 @@
+import requests
+from pytest import MonkeyPatch
+
+from langchain_community.llms.ollama import Ollama
+
+
+def mock_response_stream():
+    mock_response = [b'{ "response": "Response chunk 1" }']
+
+    class MockRaw:
+        def read(self, chunk_size):
+            try:
+                return mock_response.pop()
+            except IndexError:
+                return None
+
+    response = requests.Response()
+    response.status_code = 200
+    response.raw = MockRaw()
+    return response
+
+
+def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
+    llm = Ollama(
+        base_url="https://ollama-hostname:8000",
+        model="foo",
+        headers={
+            "Authentication": "Bearer TEST-TOKEN-VALUE",
+            "Referer": "https://application-host",
+        },
+        timeout=300,
+    )
+
+    def mockPost(url, headers, json, stream, timeout):
+        assert url == "https://ollama-hostname:8000/api/generate/"
+        assert headers == {
+            "Content-Type": "application/json",
+            "Authentication": "Bearer TEST-TOKEN-VALUE",
+            "Referer": "https://application-host",
+        }
+        assert json is not None
+        assert stream is True
+        assert timeout == 300
+
+        return mock_response_stream()
+
+    monkeypatch.setattr(requests, "post", mockPost)
+
+    llm("Test prompt")
+
+
+def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
+    llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
+
+    def mockPost(url, headers, json, stream, timeout):
+        assert url == "https://ollama-hostname:8000/api/generate/"
+        assert headers == {
+            "Content-Type": "application/json",
+        }
+        assert json is not None
+        assert stream is True
+        assert timeout == 300
+
+        return mock_response_stream()
+
+    monkeypatch.setattr(requests, "post", mockPost)
+
+    llm("Test prompt")