ollama: include kwargs in requests (#28299)

courtesy of @ryanmagnuson
This commit is contained in:
Erick Friis
2024-11-22 14:15:42 -08:00
committed by GitHub
parent 2ee37a1c7b
commit 7277794a59
2 changed files with 68 additions and 108 deletions

View File

@@ -126,13 +126,20 @@ class OllamaLLM(BaseLLM):
The async client to use for making requests.
"""
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Ollama."""
return {
"model": self.model,
"format": self.format,
"options": {
def _generate_params(
self,
prompt: str,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
options_dict = kwargs.pop(
"options",
{
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
@@ -143,14 +150,25 @@ class OllamaLLM(BaseLLM):
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"stop": self.stop,
"stop": self.stop if stop is None else stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
},
"keep_alive": self.keep_alive,
)
params = {
"prompt": prompt,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
**kwargs,
}
return params
@property
def _llm_type(self) -> str:
"""Return type of LLM."""
@@ -179,25 +197,8 @@ class OllamaLLM(BaseLLM):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
params = self._default_params
for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]
params["options"]["stop"] = stop
async for part in await self._async_client.generate(
model=params["model"],
prompt=prompt,
stream=True,
options=Options(**params["options"]),
keep_alive=params["keep_alive"],
format=params["format"],
**self._generate_params(prompt, stop=stop, **kwargs)
): # type: ignore
yield part
@@ -207,25 +208,8 @@ class OllamaLLM(BaseLLM):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
params = self._default_params
for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]
params["options"]["stop"] = stop
yield from self._client.generate(
model=params["model"],
prompt=prompt,
stream=True,
options=Options(**params["options"]),
keep_alive=params["keep_alive"],
format=params["format"],
**self._generate_params(prompt, stop=stop, **kwargs)
)
async def _astream_with_aggregation(