mirror of https://github.com/imartinez/privateGPT.git
actually do something with keep_alive parameter
commit e0533a40ed
parent 2123b82a84
private_gpt/components/llm/custom/ollama.py (new file, +32 lines)
private_gpt/components/llm/custom/ollama.py
@@ -0,0 +1,32 @@
+from llama_index.llms.ollama import Ollama
+from pydantic import Field
+
+
+class CustomOllama(Ollama):
+    """Custom llama_index Ollama class with the only intention of passing on the keep_alive parameter."""
+
+    keep_alive: str = Field(
+        default="5m",
+        description="String that describes the time the model should stay in (V)RAM after last request.",
+    )
+
+    def __init__(self, *args, **kwargs) -> None:
+        keep_alive = kwargs.pop('keep_alive', '5m')  # fetch keep_alive from kwargs or use 5m if not found.
+        super().__init__(*args, **kwargs)
+        self.keep_alive = keep_alive
+
+    def chat(self, *args, **kwargs):
+        kwargs["keep_alive"] = self.keep_alive
+        return super().chat(*args, **kwargs)
+
+    def stream_chat(self, *args, **kwargs):
+        kwargs["keep_alive"] = self.keep_alive
+        return super().stream_chat(*args, **kwargs)
+
+    def complete(self, *args, **kwargs):
+        kwargs["keep_alive"] = self.keep_alive
+        return super().complete(*args, **kwargs)
+
+    def stream_complete(self, *args, **kwargs):
+        kwargs["keep_alive"] = self.keep_alive
+        return super().stream_complete(*args, **kwargs)
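
The wrapper's only job is to re-inject keep_alive into the kwargs of every request method; per the commit title, the value was already being passed to Ollama's constructor but had no effect on requests. A minimal standalone sketch of the intended behavior (the model name, URL, timeout, and prompt are illustrative placeholders, not values from this commit):

    from private_gpt.components.llm.custom.ollama import CustomOllama

    # Illustrative values only; in PrivateGPT these come from the settings files.
    llm = CustomOllama(
        model="llama2",
        base_url="http://localhost:11434",
        request_timeout=120.0,
        keep_alive="10m",  # popped in __init__, then re-attached to every request below
    )

    # chat/stream_chat/complete/stream_complete all set kwargs["keep_alive"]
    # before delegating, so the server keeps the model in (V)RAM for 10 minutes
    # after each call.
    print(llm.complete("What does keep_alive do?"))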

private_gpt/components/llm/llm_component.py
@@ -108,7 +108,7 @@ class LLMComponent:
             )
             case "ollama":
                 try:
-                    from llama_index.llms.ollama import Ollama  # type: ignore
+                    from private_gpt.components.llm.custom.ollama import CustomOllama  # type: ignore
                 except ImportError as e:
                     raise ImportError(
                         "Ollama dependencies not found, install with `poetry install --extras llms-ollama`"
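
Worth noting about the hunk above (an inference from the import chain, not stated in the diff): the try/except guard keeps working after the swap because the new module itself does `from llama_index.llms.ollama import Ollama` at module level, so a missing llms-ollama extra still surfaces as an ImportError here. A sketch of the equivalent check:

    # Importing the wrapper transitively imports llama_index.llms.ollama,
    # so the same ImportError is raised when the extra is not installed.
    try:
        from private_gpt.components.llm.custom.ollama import CustomOllama
    except ImportError as e:
        raise ImportError(
            "Ollama dependencies not found, install with `poetry install --extras llms-ollama`"
        ) from e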
@@ -125,14 +125,14 @@ class LLMComponent:
                     "repeat_penalty": ollama_settings.repeat_penalty,  # ollama llama-cpp
                 }

-                self.llm = Ollama(
+                self.llm = CustomOllama(
                     model=ollama_settings.llm_model,
                     base_url=ollama_settings.api_base,
                     temperature=settings.llm.temperature,
                     context_window=settings.llm.context_window,
                     additional_kwargs=settings_kwargs,
                     request_timeout=ollama_settings.request_timeout,
-                    keep_alive = ollama_settings.keep_alive,
+                    keep_alive=ollama_settings.keep_alive,
                 )
             case "azopenai":
                 try: