Actually do something with the keep_alive parameter

This commit is contained in:
rboone 2024-03-27 17:17:59 +01:00
parent 2123b82a84
commit e0533a40ed
2 changed files with 35 additions and 3 deletions

View File

@ -0,0 +1,32 @@
from llama_index.llms.ollama import Ollama
from pydantic import Field
class CustomOllama(Ollama):
    """Thin Ollama subclass whose sole job is to carry a keep_alive setting.

    The stored value is injected into every chat/completion call so the
    Ollama server knows how long to keep the model loaded after a request.
    """

    # Pydantic field backing the setting; default matches the Ollama server default.
    keep_alive: str = Field(
        default="5m",
        description="String that describes the time the model should stay in (V)RAM after last request.",
    )

    def __init__(self, *args, **kwargs) -> None:
        # Strip keep_alive before delegating so the parent constructor never
        # sees it; fall back to "5m" when the caller did not supply one.
        keep_alive_value = kwargs.pop("keep_alive", "5m")
        super().__init__(*args, **kwargs)
        self.keep_alive = keep_alive_value

    def chat(self, *args, **kwargs):
        """Delegate to Ollama.chat with this instance's keep_alive injected."""
        return super().chat(*args, **{**kwargs, "keep_alive": self.keep_alive})

    def stream_chat(self, *args, **kwargs):
        """Delegate to Ollama.stream_chat with this instance's keep_alive injected."""
        return super().stream_chat(*args, **{**kwargs, "keep_alive": self.keep_alive})

    def complete(self, *args, **kwargs):
        """Delegate to Ollama.complete with this instance's keep_alive injected."""
        return super().complete(*args, **{**kwargs, "keep_alive": self.keep_alive})

    def stream_complete(self, *args, **kwargs):
        """Delegate to Ollama.stream_complete with this instance's keep_alive injected."""
        return super().stream_complete(*args, **{**kwargs, "keep_alive": self.keep_alive})

View File

@ -108,7 +108,7 @@ class LLMComponent:
)
case "ollama":
try:
from llama_index.llms.ollama import Ollama # type: ignore
from private_gpt.components.llm.custom.ollama import CustomOllama # type: ignore
except ImportError as e:
raise ImportError(
"Ollama dependencies not found, install with `poetry install --extras llms-ollama`"
@ -125,7 +125,7 @@ class LLMComponent:
"repeat_penalty": ollama_settings.repeat_penalty, # ollama llama-cpp
}
self.llm = Ollama(
self.llm = CustomOllama(
model=ollama_settings.llm_model,
base_url=ollama_settings.api_base,
temperature=settings.llm.temperature,