langchain/libs/community/tests/unit_tests/llms/test_ollama.py
rick-SOPTIM cd563fb628
community[minor]: passthrough auth parameter on requests to Ollama-LLMs (#24068)

**Description:**
This PR allows users of `langchain_community.llms.ollama.Ollama` to
specify the `auth` parameter, which is then forwarded to all internal
calls of `requests.request`. This works in the same way as the existing
`headers` parameter. The `auth` parameter enables using the class with
Ollama instances that are secured by more complex authentication
mechanisms which do not rely solely on static headers. One example is
an AWS API Gateway secured by the IAM authorizer, which expects
signatures calculated dynamically from the specific HTTP request.
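
For illustration, a minimal sketch of how the new parameter can be passed. The endpoint, model name, and credentials are placeholders; since the value is handed straight to `requests`, anything `requests` accepts for `auth` should work, e.g. a `(user, password)` tuple or a `requests.auth.AuthBase` instance:

```python
from langchain_community.llms.ollama import Ollama

# Placeholder endpoint and credentials, for illustration only.
llm = Ollama(
    base_url="https://ollama-hostname:8000",
    model="foo",
    auth=("user", "password"),  # forwarded to the underlying requests calls
)
llm.invoke("Hello")
```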

**Issue:**

Integrating a remote LLM served through Ollama using
`langchain_community.llms.ollama.Ollama` previously only allowed setting
static HTTP headers via the `headers` parameter. This does not work if
the given Ollama instance is secured with an authentication mechanism
that relies on dynamically created HTTP headers, which may, for example,
depend on the content of a given request.
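
As a sketch of the kind of mechanism meant here, a `requests.auth.AuthBase` subclass can compute headers per request at send time. `compute_signature` below is a hypothetical stand-in for a real signer such as AWS SigV4:

```python
from requests import PreparedRequest
from requests.auth import AuthBase


class PerRequestSigner(AuthBase):
    """Attach an Authorization header derived from the request itself."""

    def __call__(self, r: PreparedRequest) -> PreparedRequest:
        # compute_signature is hypothetical; a real implementation would
        # sign the method, URL, and body (e.g. with AWS SigV4).
        r.headers["Authorization"] = compute_signature(r.method, r.url, r.body)
        return r
```

An instance of such a class can then be passed as `auth=PerRequestSigner()`.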

**Dependencies:**

None

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2024-07-25 15:48:35 +00:00

import requests
from pytest import MonkeyPatch

from langchain_community.llms.ollama import Ollama


def mock_response_stream():  # type: ignore[no-untyped-def]
    """Build a mocked streaming response that yields a single JSON chunk."""
    mock_response = [b'{ "response": "Response chunk 1" }']

    class MockRaw:
        # Pops chunks until the list is exhausted, mimicking a streamed body.
        def read(self, chunk_size):  # type: ignore[no-untyped-def]
            try:
                return mock_response.pop()
            except IndexError:
                return None

    response = requests.Response()
    response.status_code = 200
    response.raw = MockRaw()
    return response


def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
    """Test that provided headers are merged into the outgoing request."""
    llm = Ollama(
        base_url="https://ollama-hostname:8000",
        model="foo",
        headers={
            "Authorization": "Bearer TEST-TOKEN-VALUE",
            "Referer": "https://application-host",
        },
        timeout=300,
    )

    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {
            "Content-Type": "application/json",
            "Authorization": "Bearer TEST-TOKEN-VALUE",
            "Referer": "https://application-host",
        }
        assert json is not None
        assert stream is True
        assert timeout == 300
        return mock_response_stream()

    monkeypatch.setattr(requests, "post", mock_post)

    llm.invoke("Test prompt")


def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
    """Test that the auth parameter is forwarded to requests.post."""
    llm = Ollama(
        base_url="https://ollama-hostname:8000",
        model="foo",
        auth=("Test-User", "Test-Password"),
        timeout=300,
    )

    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {
            "Content-Type": "application/json",
        }
        assert json is not None
        assert stream is True
        assert timeout == 300
        assert auth == ("Test-User", "Test-Password")
        return mock_response_stream()

    monkeypatch.setattr(requests, "post", mock_post)

    llm.invoke("Test prompt")


def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
    """Test that only default headers are sent when none are provided."""
    llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {
            "Content-Type": "application/json",
        }
        assert json is not None
        assert stream is True
        assert timeout == 300
        return mock_response_stream()

    monkeypatch.setattr(requests, "post", mock_post)

    llm.invoke("Test prompt")


def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
    """Test that top level params are sent to the endpoint as top level params"""
    llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {
            "Content-Type": "application/json",
        }
        assert json == {
            "format": None,
            "images": None,
            "model": "test-model",
            "options": {
                "mirostat": None,
                "mirostat_eta": None,
                "mirostat_tau": None,
                "num_ctx": None,
                "num_gpu": None,
                "num_thread": None,
                "num_predict": None,
                "repeat_last_n": None,
                "repeat_penalty": None,
                "stop": None,
                "temperature": None,
                "tfs_z": None,
                "top_k": None,
                "top_p": None,
            },
            "prompt": "Test prompt",
            "system": "Test system prompt",
            "template": None,
            "keep_alive": None,
            "raw": None,
        }
        assert stream is True
        assert timeout == 300
        return mock_response_stream()

    monkeypatch.setattr(requests, "post", mock_post)

    llm.invoke("Test prompt", model="test-model", system="Test system prompt")


def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
    """
    Test that params that are not top level params will be sent to the endpoint
    as options
    """
    llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {
            "Content-Type": "application/json",
        }
        assert json == {
            "format": None,
            "images": None,
            "model": "foo",
            "options": {
                "mirostat": None,
                "mirostat_eta": None,
                "mirostat_tau": None,
                "num_ctx": None,
                "num_gpu": None,
                "num_thread": None,
                "num_predict": None,
                "repeat_last_n": None,
                "repeat_penalty": None,
                "stop": None,
                "temperature": 0.8,
                "tfs_z": None,
                "top_k": None,
                "top_p": None,
                "unknown": "Unknown parameter value",
            },
            "prompt": "Test prompt",
            "system": None,
            "template": None,
            "keep_alive": None,
            "raw": None,
        }
        assert stream is True
        assert timeout == 300
        return mock_response_stream()

    monkeypatch.setattr(requests, "post", mock_post)

    llm.invoke("Test prompt", unknown="Unknown parameter value", temperature=0.8)


def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
    """
    Test that if options provided it will be sent to the endpoint as options,
    ignoring other params that are not top level params.
    """
    llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)

    def mock_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {
            "Content-Type": "application/json",
        }
        assert json == {
            "format": None,
            "images": None,
            "model": "test-another-model",
            "options": {"unknown_option": "Unknown option value"},
            "prompt": "Test prompt",
            "system": None,
            "template": None,
            "keep_alive": None,
            "raw": None,
        }
        assert stream is True
        assert timeout == 300
        return mock_response_stream()

    monkeypatch.setattr(requests, "post", mock_post)

    llm.invoke(
        "Test prompt",
        model="test-another-model",
        options={"unknown_option": "Unknown option value"},
        unknown="Unknown parameter value",
        temperature=0.8,
    )