community: Ollama - Parameter structure to follow official documentation (#16035)

## Feature
- Follow parameter structure as per official documentation 
- top level parameters (e.g. model, system, template) will be passed as
top level parameters
  - other parameters will be sent in options unless options is provided

![image](https://github.com/langchain-ai/langchain/assets/17451563/d14715d9-9701-4ee3-b44b-89fffea62389)

## Tests
- Test if top level parameters handled properly
- Test if parameters that are not top level parameters are handled as
options
- Test if options is provided, it will be passed as is
This commit is contained in:
shahrin014 2024-01-16 03:17:58 +09:00 committed by GitHub
parent 60d6a416e6
commit 86321a949f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 137 additions and 10 deletions

View File

@ -190,8 +190,9 @@ class _OllamaCommon(BaseLanguageModel):
params = self._default_params params = self._default_params
if "model" in kwargs: for key in self._default_params:
params["model"] = kwargs["model"] if key in kwargs:
params[key] = kwargs[key]
if "options" in kwargs: if "options" in kwargs:
params["options"] = kwargs["options"] params["options"] = kwargs["options"]
@ -199,7 +200,7 @@ class _OllamaCommon(BaseLanguageModel):
params["options"] = { params["options"] = {
**params["options"], **params["options"],
"stop": stop, "stop": stop,
**kwargs, **{k: v for k, v in kwargs.items() if k not in self._default_params},
} }
if payload.get("messages"): if payload.get("messages"):
@ -253,8 +254,9 @@ class _OllamaCommon(BaseLanguageModel):
params = self._default_params params = self._default_params
if "model" in kwargs: for key in self._default_params:
params["model"] = kwargs["model"] if key in kwargs:
params[key] = kwargs[key]
if "options" in kwargs: if "options" in kwargs:
params["options"] = kwargs["options"] params["options"] = kwargs["options"]
@ -262,7 +264,7 @@ class _OllamaCommon(BaseLanguageModel):
params["options"] = { params["options"] = {
**params["options"], **params["options"],
"stop": stop, "stop": stop,
**kwargs, **{k: v for k, v in kwargs.items() if k not in self._default_params},
} }
if payload.get("messages"): if payload.get("messages"):

View File

@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
timeout=300, timeout=300,
) )
def mockPost(url, headers, json, stream, timeout): def mock_post(url, headers, json, stream, timeout):
assert url == "https://ollama-hostname:8000/api/generate/" assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -44,7 +44,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
return mock_response_stream() return mock_response_stream()
monkeypatch.setattr(requests, "post", mockPost) monkeypatch.setattr(requests, "post", mock_post)
llm("Test prompt") llm("Test prompt")
@ -52,7 +52,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mockPost(url, headers, json, stream, timeout): def mock_post(url, headers, json, stream, timeout):
assert url == "https://ollama-hostname:8000/api/generate/" assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == { assert headers == {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -63,6 +63,131 @@ def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
return mock_response_stream() return mock_response_stream()
monkeypatch.setattr(requests, "post", mockPost) monkeypatch.setattr(requests, "post", mock_post)
llm("Test prompt") llm("Test prompt")
def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
"""Test that top level params are sent to the endpoint as top level params"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout):
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",
}
assert json == {
"format": None,
"images": None,
"model": "test-model",
"options": {
"mirostat": None,
"mirostat_eta": None,
"mirostat_tau": None,
"num_ctx": None,
"num_gpu": None,
"num_thread": None,
"repeat_last_n": None,
"repeat_penalty": None,
"stop": [],
"temperature": None,
"tfs_z": None,
"top_k": None,
"top_p": None,
},
"prompt": "Test prompt",
"system": "Test system prompt",
"template": None,
}
assert stream is True
assert timeout == 300
return mock_response_stream()
monkeypatch.setattr(requests, "post", mock_post)
llm("Test prompt", model="test-model", system="Test system prompt")
def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
"""
Test that params that are not top level params will be sent to the endpoint
as options
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout):
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",
}
assert json == {
"format": None,
"images": None,
"model": "foo",
"options": {
"mirostat": None,
"mirostat_eta": None,
"mirostat_tau": None,
"num_ctx": None,
"num_gpu": None,
"num_thread": None,
"repeat_last_n": None,
"repeat_penalty": None,
"stop": [],
"temperature": 0.8,
"tfs_z": None,
"top_k": None,
"top_p": None,
"unknown": "Unknown parameter value",
},
"prompt": "Test prompt",
"system": None,
"template": None,
}
assert stream is True
assert timeout == 300
return mock_response_stream()
monkeypatch.setattr(requests, "post", mock_post)
llm("Test prompt", unknown="Unknown parameter value", temperature=0.8)
def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
"""
Test that if options provided it will be sent to the endpoint as options,
ignoring other params that are not top level params.
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout):
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",
}
assert json == {
"format": None,
"images": None,
"model": "test-another-model",
"options": {"unknown_option": "Unknown option value"},
"prompt": "Test prompt",
"system": None,
"template": None,
}
assert stream is True
assert timeout == 300
return mock_response_stream()
monkeypatch.setattr(requests, "post", mock_post)
llm(
"Test prompt",
model="test-another-model",
options={"unknown_option": "Unknown option value"},
unknown="Unknown parameter value",
temperature=0.8,
)