mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-29 04:16:02 +00:00
Thank you for contributing to LangChain! **Description:** This PR allows users of `langchain_community.llms.ollama.Ollama` to specify the `auth` parameter, which is then forwarded to all internal calls of `requests.request`. This works in the same way as the existing `headers` parameter. The `auth` parameter enables the use of this class with Ollama instances that are secured by more complex authentication mechanisms which do not rely solely on static headers. One example is an AWS API Gateway secured by the IAM authorizer, which expects signatures that are calculated dynamically for each specific HTTP request. **Issue:** Integrating a remote LLM running through Ollama using `langchain_community.llms.ollama.Ollama` only allows setting static HTTP headers via the `headers` parameter. This does not work if the given Ollama instance is secured with an authentication mechanism that relies on dynamically created HTTP headers, which may, for example, depend on the content of a given request. **Dependencies:** None **Twitter handle:** None --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
227 lines
7.1 KiB
Python
227 lines
7.1 KiB
Python
import requests
|
|
from pytest import MonkeyPatch
|
|
|
|
from langchain_community.llms.ollama import Ollama
|
|
|
|
|
|
def mock_response_stream():  # type: ignore[no-untyped-def]
    """Build a fake streaming ``requests.Response`` for the tests in this module.

    The returned response has status code 200 and a stand-in ``raw`` object
    whose ``read`` hands out one JSON-encoded chunk and returns ``None`` on
    every call after that.
    """
    pending_chunks = [b'{ "response": "Response chunk 1" }']

    class MockRaw:
        def read(self, chunk_size):  # type: ignore[no-untyped-def]
            # ``chunk_size`` is accepted but ignored: each call yields a whole chunk.
            if not pending_chunks:
                return None
            return pending_chunks.pop()

    fake_response = requests.Response()
    fake_response.status_code = 200
    fake_response.raw = MockRaw()
    return fake_response
|
|
|
|
|
|
def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
    """Headers configured on the LLM are sent together with ``Content-Type``."""
    custom_headers = {
        "Authorization": "Bearer TEST-TOKEN-VALUE",
        "Referer": "https://application-host",
    }
    expected_headers = {"Content-Type": "application/json", **custom_headers}

    def fake_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        # Stand-in for ``requests.post``: verify the outgoing request, then
        # answer with the canned streaming response.
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == expected_headers
        assert json is not None
        assert stream is True
        assert timeout == 300

        return mock_response_stream()

    client = Ollama(
        base_url="https://ollama-hostname:8000",
        model="foo",
        headers=custom_headers,
        timeout=300,
    )
    monkeypatch.setattr(requests, "post", fake_post)

    client.invoke("Test prompt")
|
|
|
|
|
|
def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
    """The ``auth`` value configured on the LLM is handed to ``requests.post``."""
    credentials = ("Test-User", "Test-Password")

    def fake_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        # Stand-in for ``requests.post``: verify the outgoing request, then
        # answer with the canned streaming response.
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {"Content-Type": "application/json"}
        assert json is not None
        assert stream is True
        assert timeout == 300
        assert auth == credentials

        return mock_response_stream()

    client = Ollama(
        base_url="https://ollama-hostname:8000",
        model="foo",
        auth=credentials,
        timeout=300,
    )
    monkeypatch.setattr(requests, "post", fake_post)

    client.invoke("Test prompt")
|
|
|
|
|
|
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
    """Without custom headers, only ``Content-Type`` is sent with the request."""
    expected_headers = {"Content-Type": "application/json"}

    def fake_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        # Stand-in for ``requests.post``: verify the outgoing request, then
        # answer with the canned streaming response.
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == expected_headers
        assert json is not None
        assert stream is True
        assert timeout == 300

        return mock_response_stream()

    client = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
    monkeypatch.setattr(requests, "post", fake_post)

    client.invoke("Test prompt")
|
|
|
|
|
|
def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
    """Invoke-time ``model``/``system`` kwargs end up at the payload's top level."""
    # No option is overridden in this test, so every option is expected as None.
    unset_options = dict.fromkeys(
        (
            "mirostat",
            "mirostat_eta",
            "mirostat_tau",
            "num_ctx",
            "num_gpu",
            "num_thread",
            "num_predict",
            "repeat_last_n",
            "repeat_penalty",
            "stop",
            "temperature",
            "tfs_z",
            "top_k",
            "top_p",
        )
    )
    expected_payload = {
        "format": None,
        "images": None,
        "model": "test-model",
        "options": unset_options,
        "prompt": "Test prompt",
        "system": "Test system prompt",
        "template": None,
        "keep_alive": None,
        "raw": None,
    }

    def fake_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        # Stand-in for ``requests.post``: verify the outgoing request, then
        # answer with the canned streaming response.
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {"Content-Type": "application/json"}
        assert json == expected_payload
        assert stream is True
        assert timeout == 300

        return mock_response_stream()

    client = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
    monkeypatch.setattr(requests, "post", fake_post)

    client.invoke("Test prompt", model="test-model", system="Test system prompt")
|
|
|
|
|
|
def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
    """
    Kwargs that are not top-level parameters must be delivered to the endpoint
    inside ``options``.
    """
    # Options left untouched by the call are expected as None ...
    expected_options = dict.fromkeys(
        (
            "mirostat",
            "mirostat_eta",
            "mirostat_tau",
            "num_ctx",
            "num_gpu",
            "num_thread",
            "num_predict",
            "repeat_last_n",
            "repeat_penalty",
            "stop",
            "tfs_z",
            "top_k",
            "top_p",
        )
    )
    # ... while the two kwargs passed to ``invoke`` must show up with their values.
    expected_options["temperature"] = 0.8
    expected_options["unknown"] = "Unknown parameter value"

    expected_payload = {
        "format": None,
        "images": None,
        "model": "foo",
        "options": expected_options,
        "prompt": "Test prompt",
        "system": None,
        "template": None,
        "keep_alive": None,
        "raw": None,
    }

    def fake_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        # Stand-in for ``requests.post``: verify the outgoing request, then
        # answer with the canned streaming response.
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {"Content-Type": "application/json"}
        assert json == expected_payload
        assert stream is True
        assert timeout == 300

        return mock_response_stream()

    client = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
    monkeypatch.setattr(requests, "post", fake_post)

    client.invoke("Test prompt", unknown="Unknown parameter value", temperature=0.8)
|
|
|
|
|
|
def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
    """
    An explicit ``options`` kwarg must be sent to the endpoint as the payload's
    options, and other kwargs that are not top-level parameters must be left
    out of the request.
    """
    explicit_options = {"unknown_option": "Unknown option value"}
    expected_payload = {
        "format": None,
        "images": None,
        "model": "test-another-model",
        "options": {"unknown_option": "Unknown option value"},
        "prompt": "Test prompt",
        "system": None,
        "template": None,
        "keep_alive": None,
        "raw": None,
    }

    def fake_post(url, headers, json, stream, timeout, auth):  # type: ignore[no-untyped-def]
        # Stand-in for ``requests.post``: verify the outgoing request, then
        # answer with the canned streaming response.
        assert url == "https://ollama-hostname:8000/api/generate"
        assert headers == {"Content-Type": "application/json"}
        assert json == expected_payload
        assert stream is True
        assert timeout == 300

        return mock_response_stream()

    client = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
    monkeypatch.setattr(requests, "post", fake_post)

    client.invoke(
        "Test prompt",
        model="test-another-model",
        options=explicit_options,
        unknown="Unknown parameter value",
        temperature=0.8,
    )
|