diff --git a/private_gpt/components/llm/custom/ollama.py b/private_gpt/components/llm/custom/ollama.py
new file mode 100644
index 00000000..af4d3702
--- /dev/null
+++ b/private_gpt/components/llm/custom/ollama.py
@@ -0,0 +1,32 @@
+from llama_index.llms.ollama import Ollama
+from pydantic import Field
+
+
+class CustomOllama(Ollama):
+    """Custom llama_index Ollama subclass whose sole purpose is to pass on the keep_alive parameter."""
+
+    keep_alive: str = Field(
+        default="5m",
+        description="Duration (e.g. '5m') the model should stay loaded in (V)RAM after the last request.",
+    )
+
+    def __init__(self, *args, **kwargs) -> None:
+        keep_alive = kwargs.pop("keep_alive", "5m")  # Pop keep_alive from kwargs, defaulting to "5m" if absent.
+        super().__init__(*args, **kwargs)
+        self.keep_alive = keep_alive
+
+    def chat(self, *args, **kwargs):
+        kwargs["keep_alive"] = self.keep_alive
+        return super().chat(*args, **kwargs)
+
+    def stream_chat(self, *args, **kwargs):
+        kwargs["keep_alive"] = self.keep_alive
+        return super().stream_chat(*args, **kwargs)
+
+    def complete(self, *args, **kwargs):
+        kwargs["keep_alive"] = self.keep_alive
+        return super().complete(*args, **kwargs)
+
+    def stream_complete(self, *args, **kwargs):
+        kwargs["keep_alive"] = self.keep_alive
+        return super().stream_complete(*args, **kwargs)
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index 40c33c2d..9b9c46ce 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -108,7 +108,7 @@ class LLMComponent:
                 )
             case "ollama":
                 try:
-                    from llama_index.llms.ollama import Ollama  # type: ignore
+                    from private_gpt.components.llm.custom.ollama import CustomOllama  # type: ignore
                 except ImportError as e:
                     raise ImportError(
                         "Ollama dependencies not found, install with `poetry install --extras llms-ollama`"
@@ -125,14 +125,14 @@ class LLMComponent:
                     "repeat_penalty": ollama_settings.repeat_penalty,  # ollama llama-cpp
                 }
 
-                self.llm = Ollama(
+                self.llm = CustomOllama(
                     model=ollama_settings.llm_model,
                     base_url=ollama_settings.api_base,
                     temperature=settings.llm.temperature,
                     context_window=settings.llm.context_window,
                     additional_kwargs=settings_kwargs,
                     request_timeout=ollama_settings.request_timeout,
-                    keep_alive = ollama_settings.keep_alive,
+                    keep_alive=ollama_settings.keep_alive,
                 )
             case "azopenai":
                 try:
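For context, a minimal standalone usage sketch of the `CustomOllama` class added above. This is illustrative only and not part of the patch: the model name, `base_url`, and timeout values are assumptions, and it presumes a locally running Ollama server with the named model already pulled.

```python
# Illustrative only -- not part of the patch. Assumes a local Ollama server
# at http://localhost:11434 with the "llama3" model already pulled.
from llama_index.core.llms import ChatMessage

from private_gpt.components.llm.custom.ollama import CustomOllama

llm = CustomOllama(
    model="llama3",
    base_url="http://localhost:11434",
    request_timeout=120.0,
    keep_alive="10m",  # keep the model in (V)RAM for 10 minutes after the last request
)

# chat(), stream_chat(), complete() and stream_complete() all inject
# keep_alive into kwargs before delegating to the parent Ollama class,
# so every request carries the configured keep-alive duration.
response = llm.chat([ChatMessage(role="user", content="Hello!")])
print(response)
```

The subclass pops `keep_alive` out of the constructor kwargs before calling `super().__init__()` and then assigns it to the declared pydantic field, which keeps the parent class unaware of the extra parameter at construction time while still forwarding it on every request.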