diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py index 4e7f1838f75..078edcf0928 100644 --- a/libs/community/langchain_community/llms/ollama.py +++ b/libs/community/langchain_community/llms/ollama.py @@ -190,8 +190,9 @@ class _OllamaCommon(BaseLanguageModel): params = self._default_params - if "model" in kwargs: - params["model"] = kwargs["model"] + for key in self._default_params: + if key in kwargs: + params[key] = kwargs[key] if "options" in kwargs: params["options"] = kwargs["options"] @@ -199,7 +200,7 @@ class _OllamaCommon(BaseLanguageModel): params["options"] = { **params["options"], "stop": stop, - **kwargs, + **{k: v for k, v in kwargs.items() if k not in self._default_params}, } if payload.get("messages"): @@ -253,8 +254,9 @@ class _OllamaCommon(BaseLanguageModel): params = self._default_params - if "model" in kwargs: - params["model"] = kwargs["model"] + for key in self._default_params: + if key in kwargs: + params[key] = kwargs[key] if "options" in kwargs: params["options"] = kwargs["options"] @@ -262,7 +264,7 @@ class _OllamaCommon(BaseLanguageModel): params["options"] = { **params["options"], "stop": stop, - **kwargs, + **{k: v for k, v in kwargs.items() if k not in self._default_params}, } if payload.get("messages"): diff --git a/libs/community/tests/unit_tests/llms/test_ollama.py b/libs/community/tests/unit_tests/llms/test_ollama.py index ff18daacfa6..bf2229b4fcd 100644 --- a/libs/community/tests/unit_tests/llms/test_ollama.py +++ b/libs/community/tests/unit_tests/llms/test_ollama.py @@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: timeout=300, ) - def mockPost(url, headers, json, stream, timeout): + def mock_post(url, headers, json, stream, timeout): assert url == "https://ollama-hostname:8000/api/generate/" assert headers == { "Content-Type": "application/json", @@ -44,7 +44,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: return mock_response_stream() - monkeypatch.setattr(requests, "post", mockPost) + monkeypatch.setattr(requests, "post", mock_post) llm("Test prompt") @@ -52,7 +52,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mockPost(url, headers, json, stream, timeout): + def mock_post(url, headers, json, stream, timeout): assert url == "https://ollama-hostname:8000/api/generate/" assert headers == { "Content-Type": "application/json", @@ -63,6 +63,131 @@ def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: return mock_response_stream() - monkeypatch.setattr(requests, "post", mockPost) + monkeypatch.setattr(requests, "post", mock_post) llm("Test prompt") + + +def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None: + """Test that top level params are sent to the endpoint as top level params""" + llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) + + def mock_post(url, headers, json, stream, timeout): + assert url == "https://ollama-hostname:8000/api/generate/" + assert headers == { + "Content-Type": "application/json", + } + assert json == { + "format": None, + "images": None, + "model": "test-model", + "options": { + "mirostat": None, + "mirostat_eta": None, + "mirostat_tau": None, + "num_ctx": None, + "num_gpu": None, + "num_thread": None, + "repeat_last_n": None, + "repeat_penalty": None, + "stop": [], + "temperature": None, + "tfs_z": None, + "top_k": None, + "top_p": None, + }, + "prompt": "Test prompt", + "system": "Test system prompt", + "template": None, + } + assert stream is True + assert timeout == 300 + + return mock_response_stream() + + monkeypatch.setattr(requests, "post", mock_post) + + llm("Test prompt", model="test-model", system="Test system prompt") + + +def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None: + """ + Test that params that are not top level params will be sent to the endpoint + as options + """ + llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) + + def mock_post(url, headers, json, stream, timeout): + assert url == "https://ollama-hostname:8000/api/generate/" + assert headers == { + "Content-Type": "application/json", + } + assert json == { + "format": None, + "images": None, + "model": "foo", + "options": { + "mirostat": None, + "mirostat_eta": None, + "mirostat_tau": None, + "num_ctx": None, + "num_gpu": None, + "num_thread": None, + "repeat_last_n": None, + "repeat_penalty": None, + "stop": [], + "temperature": 0.8, + "tfs_z": None, + "top_k": None, + "top_p": None, + "unknown": "Unknown parameter value", + }, + "prompt": "Test prompt", + "system": None, + "template": None, + } + assert stream is True + assert timeout == 300 + + return mock_response_stream() + + monkeypatch.setattr(requests, "post", mock_post) + + llm("Test prompt", unknown="Unknown parameter value", temperature=0.8) + + +def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None: + """ + Test that if options provided it will be sent to the endpoint as options, + ignoring other params that are not top level params. + """ + llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) + + def mock_post(url, headers, json, stream, timeout): + assert url == "https://ollama-hostname:8000/api/generate/" + assert headers == { + "Content-Type": "application/json", + } + assert json == { + "format": None, + "images": None, + "model": "test-another-model", + "options": {"unknown_option": "Unknown option value"}, + "prompt": "Test prompt", + "system": None, + "template": None, + } + assert stream is True + assert timeout == 300 + + return mock_response_stream() + + monkeypatch.setattr(requests, "post", mock_post) + + llm( + "Test prompt", + model="test-another-model", + options={"unknown_option": "Unknown option value"}, + unknown="Unknown parameter value", + temperature=0.8, + )