ollama: include kwargs in requests (#28299)

courtesy of @ryanmagnuson
Erick Friis 2024-11-22 14:15:42 -08:00 committed by GitHub
parent 2ee37a1c7b
commit 7277794a59
2 changed files with 68 additions and 108 deletions
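
Before this change, per-call kwargs were only consulted for keys that already existed in `_default_params`; anything else was silently dropped. The refactor below threads `**kwargs` through to the client request instead. A minimal usage sketch (the model name and the extra kwarg are illustrative, not taken from this diff):

```python
from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3.1")

# Per-call kwargs such as `keep_alive`, `format`, or a full `options`
# dict are now forwarded to the underlying ollama client call.
llm.invoke("Why is the sky blue?", keep_alive="10m")
```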

libs/partners/ollama/langchain_ollama/chat_models.py

```diff
@@ -341,13 +341,22 @@ class ChatOllama(BaseChatModel):
     The async client to use for making requests.
     """
 
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling Ollama."""
-        return {
-            "model": self.model,
-            "format": self.format,
-            "options": {
+    def _chat_params(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        ollama_messages = self._convert_messages_to_ollama_messages(messages)
+
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            stop = self.stop
+
+        options_dict = kwargs.pop(
+            "options",
+            {
                 "mirostat": self.mirostat,
                 "mirostat_eta": self.mirostat_eta,
                 "mirostat_tau": self.mirostat_tau,
@@ -359,14 +368,31 @@ class ChatOllama(BaseChatModel):
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
                 "seed": self.seed,
-                "stop": self.stop,
+                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
             },
-            "keep_alive": self.keep_alive,
+        )
+
+        tools = kwargs.get("tools")
+        default_stream = not bool(tools)
+
+        params = {
+            "messages": ollama_messages,
+            "stream": kwargs.pop("stream", default_stream),
+            "model": kwargs.pop("model", self.model),
+            "format": kwargs.pop("format", self.format),
+            "options": Options(**options_dict),
+            "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
+            **kwargs,
         }
+
+        if tools:
+            params["tools"] = tools
+
+        return params
 
     @model_validator(mode="after")
     def _set_clients(self) -> Self:
         """Set clients to use for ollama."""
```
```diff
@@ -464,34 +490,9 @@ class ChatOllama(BaseChatModel):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[Union[Mapping[str, Any], str]]:
-        ollama_messages = self._convert_messages_to_ollama_messages(messages)
-
-        stop = stop if stop is not None else self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
-
-        tools = kwargs.get("tools", None)
-        stream = tools is None or len(tools) == 0
-
-        chat_params = {
-            "model": params["model"],
-            "messages": ollama_messages,
-            "stream": stream,
-            "options": Options(**params["options"]),
-            "keep_alive": params["keep_alive"],
-            "format": params["format"],
-        }
-
-        if tools is not None:
-            chat_params["tools"] = tools
-
-        if stream:
+        chat_params = self._chat_params(messages, stop, **kwargs)
+
+        if chat_params["stream"]:
             async for part in await self._async_client.chat(**chat_params):
                 yield part
         else:
```
```diff
@@ -503,34 +504,9 @@ class ChatOllama(BaseChatModel):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[Union[Mapping[str, Any], str]]:
-        ollama_messages = self._convert_messages_to_ollama_messages(messages)
-
-        stop = stop if stop is not None else self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
-
-        tools = kwargs.get("tools", None)
-        stream = tools is None or len(tools) == 0
-
-        chat_params = {
-            "model": params["model"],
-            "messages": ollama_messages,
-            "stream": stream,
-            "options": Options(**params["options"]),
-            "keep_alive": params["keep_alive"],
-            "format": params["format"],
-        }
-
-        if tools is not None:
-            chat_params["tools"] = tools
-
-        if stream:
+        chat_params = self._chat_params(messages, stop, **kwargs)
+
+        if chat_params["stream"]:
             yield from self._client.chat(**chat_params)
         else:
             yield self._client.chat(**chat_params)
```
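
One behavioral detail preserved from the old code: when `tools` are supplied, `stream` defaults to `False` (the tool-call response comes back as a single mapping), but a caller can now override that explicitly via a `stream=` kwarg. A sketch under the same illustrative setup as above (the tool schema is made up):

```python
tool = {"type": "function", "function": {"name": "add", "parameters": {}}}

# Tools present: streaming is off by default, and the tools ride along.
params = llm._chat_params([HumanMessage("hi")], tools=[tool])
assert params["stream"] is False
assert params["tools"] == [tool]

# No tools: streaming stays the default.
params = llm._chat_params([HumanMessage("hi")])
assert params["stream"] is True
```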

libs/partners/ollama/langchain_ollama/llms.py

```diff
@@ -126,13 +126,20 @@ class OllamaLLM(BaseLLM):
     The async client to use for making requests.
     """
 
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling Ollama."""
-        return {
-            "model": self.model,
-            "format": self.format,
-            "options": {
+    def _generate_params(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            stop = self.stop
+
+        options_dict = kwargs.pop(
+            "options",
+            {
                 "mirostat": self.mirostat,
                 "mirostat_eta": self.mirostat_eta,
                 "mirostat_tau": self.mirostat_tau,
@@ -143,14 +150,25 @@ class OllamaLLM(BaseLLM):
                 "repeat_last_n": self.repeat_last_n,
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
-                "stop": self.stop,
+                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
             },
-            "keep_alive": self.keep_alive,
+        )
+
+        params = {
+            "prompt": prompt,
+            "stream": kwargs.pop("stream", True),
+            "model": kwargs.pop("model", self.model),
+            "format": kwargs.pop("format", self.format),
+            "options": Options(**options_dict),
+            "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
+            **kwargs,
         }
+
+        return params
 
     @property
     def _llm_type(self) -> str:
         """Return type of LLM."""
```
```diff
@@ -179,25 +197,8 @@ class OllamaLLM(BaseLLM):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[Union[Mapping[str, Any], str]]:
-        if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop is not None:
-            stop = self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
-
         async for part in await self._async_client.generate(
-            model=params["model"],
-            prompt=prompt,
-            stream=True,
-            options=Options(**params["options"]),
-            keep_alive=params["keep_alive"],
-            format=params["format"],
+            **self._generate_params(prompt, stop=stop, **kwargs)
         ):  # type: ignore
             yield part
```
```diff
@@ -207,25 +208,8 @@ class OllamaLLM(BaseLLM):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[Union[Mapping[str, Any], str]]:
-        if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop is not None:
-            stop = self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
-
         yield from self._client.generate(
-            model=params["model"],
-            prompt=prompt,
-            stream=True,
-            options=Options(**params["options"]),
-            keep_alive=params["keep_alive"],
-            format=params["format"],
+            **self._generate_params(prompt, stop=stop, **kwargs)
         )
 
     async def _astream_with_aggregation(
```
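
The guard against conflicting `stop` values also survives the refactor: setting `stop` both on the instance and per call still raises. A quick sketch, assuming the same illustrative model:

```python
llm = OllamaLLM(model="llama3.1", stop=["\n"])

try:
    llm._generate_params("hi", stop=["END"])
except ValueError:
    pass  # "`stop` found in both the input and default params."
```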