diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py index 2dc198fec13..c4747a4ebb7 100644 --- a/libs/community/langchain_community/llms/ollama.py +++ b/libs/community/langchain_community/llms/ollama.py @@ -1,5 +1,5 @@ import json -from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional +from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Union import aiohttp import requests @@ -111,6 +111,18 @@ class _OllamaCommon(BaseLanguageModel): timeout: Optional[int] = None """Timeout for the request stream""" + keep_alive: Optional[Union[int, str]] = None + """How long the model will stay loaded in memory. + + The parameter (default: 5 minutes) can be set to: + 1. a Go duration string (such as "10m" or "24h"); + 2. a number of seconds (such as 3600); + 3. any negative number, which keeps the model loaded \ + in memory indefinitely (e.g. -1 or "-1m"); + 4. 0, which unloads the model immediately after generating a response. + + See the [Ollama documentation](https://github.com/ollama/ollama/blob/main/docs/faq.md#how-do-i-keep-a-model-loaded-in-memory-or-make-it-unload-immediately).""" + headers: Optional[dict] = None """Additional headers to pass to endpoint (e.g. Authorization, Referer). 
This is useful when Ollama is hosted on cloud services that require @@ -141,6 +153,7 @@ class _OllamaCommon(BaseLanguageModel): }, "system": self.system, "template": self.template, + "keep_alive": self.keep_alive, } @property diff --git a/libs/community/tests/unit_tests/llms/test_ollama.py b/libs/community/tests/unit_tests/llms/test_ollama.py index 3b1798fd2e9..8807aab826c 100644 --- a/libs/community/tests/unit_tests/llms/test_ollama.py +++ b/libs/community/tests/unit_tests/llms/test_ollama.py @@ -100,6 +100,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None: "prompt": "Test prompt", "system": "Test system prompt", "template": None, + "keep_alive": None, } assert stream is True assert timeout == 300 @@ -147,6 +148,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None: "prompt": "Test prompt", "system": None, "template": None, + "keep_alive": None, } assert stream is True assert timeout == 300 @@ -178,6 +180,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None: "prompt": "Test prompt", "system": None, "template": None, + "keep_alive": None, } assert stream is True assert timeout == 300