community: Ollama - Pass headers to post request (#15881)

## Feature
- Set additional headers in constructor
- Headers will be sent in post request

This feature is useful if deploying Ollama on a cloud service such as
hugging face, which requires authentication tokens to be passed in the
request header.

## Tests
- Test if header is passed
- Test if header is not passed
This commit is contained in:
shahrin014 2024-01-12 14:40:35 +09:00 committed by GitHub
parent 5efec068c9
commit bdd90ae2ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 78 additions and 1 deletions

View File

@ -107,6 +107,12 @@ class _OllamaCommon(BaseLanguageModel):
timeout: Optional[int] = None
"""Timeout for the request stream"""
headers: Optional[dict] = None
"""Additional headers to pass to endpoint (e.g. Authorization, Referer).
This is useful when Ollama is hosted on cloud services that require
tokens for authentication.
"""
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Ollama."""
@ -207,7 +213,10 @@ class _OllamaCommon(BaseLanguageModel):
response = requests.post(
url=api_url,
headers={"Content-Type": "application/json"},
headers={
"Content-Type": "application/json",
**(self.headers if isinstance(self.headers, dict) else {}),
},
json=request_payload,
stream=True,
timeout=self.timeout,

View File

@ -0,0 +1,68 @@
import requests
from pytest import MonkeyPatch
from langchain_community.llms.ollama import Ollama
def mock_response_stream():
mock_response = [b'{ "response": "Response chunk 1" }']
class MockRaw:
def read(self, chunk_size):
try:
return mock_response.pop()
except IndexError:
return None
response = requests.Response()
response.status_code = 200
response.raw = MockRaw()
return response
def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
llm = Ollama(
base_url="https://ollama-hostname:8000",
model="foo",
headers={
"Authentication": "Bearer TEST-TOKEN-VALUE",
"Referer": "https://application-host",
},
timeout=300,
)
def mockPost(url, headers, json, stream, timeout):
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",
"Authentication": "Bearer TEST-TOKEN-VALUE",
"Referer": "https://application-host",
}
assert json is not None
assert stream is True
assert timeout == 300
return mock_response_stream()
monkeypatch.setattr(requests, "post", mockPost)
llm("Test prompt")
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mockPost(url, headers, json, stream, timeout):
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",
}
assert json is not None
assert stream is True
assert timeout == 300
return mock_response_stream()
monkeypatch.setattr(requests, "post", mockPost)
llm("Test prompt")