From 86321a949f06a41a286a204f447de8491facbeaa Mon Sep 17 00:00:00 2001 From: shahrin014 Date: Tue, 16 Jan 2024 03:17:58 +0900 Subject: [PATCH] community: Ollama - Parameter structure to follow official documentation (#16035) ## Feature - Follow parameter structure as per official documentation - top level parameters (e.g. model, system, template) will be passed as top level parameters - other parameters will be sent in options unless options is provided ![image](https://github.com/langchain-ai/langchain/assets/17451563/d14715d9-9701-4ee3-b44b-89fffea62389) ## Tests - Test if top level parameters handled properly - Test if parameters that are not top level parameters are handled as options - Test if options is provided, it will be passed as is --- .../langchain_community/llms/ollama.py | 14 +- .../tests/unit_tests/llms/test_ollama.py | 133 +++++++++++++++++- 2 files changed, 137 insertions(+), 10 deletions(-) diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py index 4e7f1838f75..078edcf0928 100644 --- a/libs/community/langchain_community/llms/ollama.py +++ b/libs/community/langchain_community/llms/ollama.py @@ -190,8 +190,9 @@ class _OllamaCommon(BaseLanguageModel): params = self._default_params - if "model" in kwargs: - params["model"] = kwargs["model"] + for key in self._default_params: + if key in kwargs: + params[key] = kwargs[key] if "options" in kwargs: params["options"] = kwargs["options"] @@ -199,7 +200,7 @@ class _OllamaCommon(BaseLanguageModel): params["options"] = { **params["options"], "stop": stop, - **kwargs, + **{k: v for k, v in kwargs.items() if k not in self._default_params}, } if payload.get("messages"): @@ -253,8 +254,9 @@ class _OllamaCommon(BaseLanguageModel): params = self._default_params - if "model" in kwargs: - params["model"] = kwargs["model"] + for key in self._default_params: + if key in kwargs: + params[key] = kwargs[key] if "options" in kwargs: params["options"] = kwargs["options"] @@ -262,7 +264,7 @@ class _OllamaCommon(BaseLanguageModel): params["options"] = { **params["options"], "stop": stop, - **kwargs, + **{k: v for k, v in kwargs.items() if k not in self._default_params}, } if payload.get("messages"): diff --git a/libs/community/tests/unit_tests/llms/test_ollama.py b/libs/community/tests/unit_tests/llms/test_ollama.py index ff18daacfa6..bf2229b4fcd 100644 --- a/libs/community/tests/unit_tests/llms/test_ollama.py +++ b/libs/community/tests/unit_tests/llms/test_ollama.py @@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: timeout=300, ) - def mockPost(url, headers, json, stream, timeout): + def mock_post(url, headers, json, stream, timeout): assert url == "https://ollama-hostname:8000/api/generate/" assert headers == { "Content-Type": "application/json", @@ -44,7 +44,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: return mock_response_stream() - monkeypatch.setattr(requests, "post", mockPost) + monkeypatch.setattr(requests, "post", mock_post) llm("Test prompt") @@ -52,7 +52,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mockPost(url, headers, json, stream, timeout): + def mock_post(url, headers, json, stream, timeout): assert url == "https://ollama-hostname:8000/api/generate/" assert headers == { "Content-Type": "application/json", @@ -63,6 +63,131 @@ def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: return mock_response_stream() - monkeypatch.setattr(requests, "post", mockPost) + monkeypatch.setattr(requests, "post", mock_post) llm("Test prompt") + + +def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None: + """Test that top level params are sent to the endpoint as top level params""" + llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) + + def mock_post(url, headers, json, stream, timeout): + assert url == "https://ollama-hostname:8000/api/generate/" + assert headers == { + "Content-Type": "application/json", + } + assert json == { + "format": None, + "images": None, + "model": "test-model", + "options": { + "mirostat": None, + "mirostat_eta": None, + "mirostat_tau": None, + "num_ctx": None, + "num_gpu": None, + "num_thread": None, + "repeat_last_n": None, + "repeat_penalty": None, + "stop": [], + "temperature": None, + "tfs_z": None, + "top_k": None, + "top_p": None, + }, + "prompt": "Test prompt", + "system": "Test system prompt", + "template": None, + } + assert stream is True + assert timeout == 300 + + return mock_response_stream() + + monkeypatch.setattr(requests, "post", mock_post) + + llm("Test prompt", model="test-model", system="Test system prompt") + + +def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None: + """ + Test that params that are not top level params will be sent to the endpoint + as options + """ + llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) + + def mock_post(url, headers, json, stream, timeout): + assert url == "https://ollama-hostname:8000/api/generate/" + assert headers == { + "Content-Type": "application/json", + } + assert json == { + "format": None, + "images": None, + "model": "foo", + "options": { + "mirostat": None, + "mirostat_eta": None, + "mirostat_tau": None, + "num_ctx": None, + "num_gpu": None, + "num_thread": None, + "repeat_last_n": None, + "repeat_penalty": None, + "stop": [], + "temperature": 0.8, + "tfs_z": None, + "top_k": None, + "top_p": None, + "unknown": "Unknown parameter value", + }, + "prompt": "Test prompt", + "system": None, + "template": None, + } + assert stream is True + assert timeout == 300 + + return mock_response_stream() + + monkeypatch.setattr(requests, "post", mock_post) + + llm("Test prompt", unknown="Unknown parameter value", temperature=0.8) + + +def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None: + """ + Test that if options provided it will be sent to the endpoint as options, + ignoring other params that are not top level params. + """ + llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) + + def mock_post(url, headers, json, stream, timeout): + assert url == "https://ollama-hostname:8000/api/generate/" + assert headers == { + "Content-Type": "application/json", + } + assert json == { + "format": None, + "images": None, + "model": "test-another-model", + "options": {"unknown_option": "Unknown option value"}, + "prompt": "Test prompt", + "system": None, + "template": None, + } + assert stream is True + assert timeout == 300 + + return mock_response_stream() + + monkeypatch.setattr(requests, "post", mock_post) + + llm( + "Test prompt", + model="test-another-model", + options={"unknown_option": "Unknown option value"}, + unknown="Unknown parameter value", + temperature=0.8, + )