ollama: include kwargs in requests (#28299)

courtesy of @ryanmagnuson
This commit is contained in:
Erick Friis 2024-11-22 14:15:42 -08:00 committed by GitHub
parent 2ee37a1c7b
commit 7277794a59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 108 deletions

View File

@@ -341,13 +341,22 @@ class ChatOllama(BaseChatModel):
The async client to use for making requests.
"""
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Ollama."""
return {
"model": self.model,
"format": self.format,
"options": {
def _chat_params(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
ollama_messages = self._convert_messages_to_ollama_messages(messages)
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
options_dict = kwargs.pop(
"options",
{
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
@@ -359,14 +368,31 @@ class ChatOllama(BaseChatModel):
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"seed": self.seed,
"stop": self.stop,
"stop": self.stop if stop is None else stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
},
"keep_alive": self.keep_alive,
)
tools = kwargs.get("tools")
default_stream = not bool(tools)
params = {
"messages": ollama_messages,
"stream": kwargs.pop("stream", default_stream),
"model": kwargs.pop("model", self.model),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
**kwargs,
}
if tools:
params["tools"] = tools
return params
@model_validator(mode="after")
def _set_clients(self) -> Self:
"""Set clients to use for ollama."""
@@ -464,34 +490,9 @@ class ChatOllama(BaseChatModel):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
ollama_messages = self._convert_messages_to_ollama_messages(messages)
chat_params = self._chat_params(messages, stop, **kwargs)
stop = stop if stop is not None else self.stop
params = self._default_params
for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]
params["options"]["stop"] = stop
tools = kwargs.get("tools", None)
stream = tools is None or len(tools) == 0
chat_params = {
"model": params["model"],
"messages": ollama_messages,
"stream": stream,
"options": Options(**params["options"]),
"keep_alive": params["keep_alive"],
"format": params["format"],
}
if tools is not None:
chat_params["tools"] = tools
if stream:
if chat_params["stream"]:
async for part in await self._async_client.chat(**chat_params):
yield part
else:
@@ -503,34 +504,9 @@ class ChatOllama(BaseChatModel):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
ollama_messages = self._convert_messages_to_ollama_messages(messages)
chat_params = self._chat_params(messages, stop, **kwargs)
stop = stop if stop is not None else self.stop
params = self._default_params
for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]
params["options"]["stop"] = stop
tools = kwargs.get("tools", None)
stream = tools is None or len(tools) == 0
chat_params = {
"model": params["model"],
"messages": ollama_messages,
"stream": stream,
"options": Options(**params["options"]),
"keep_alive": params["keep_alive"],
"format": params["format"],
}
if tools is not None:
chat_params["tools"] = tools
if stream:
if chat_params["stream"]:
yield from self._client.chat(**chat_params)
else:
yield self._client.chat(**chat_params)

View File

@@ -126,13 +126,20 @@ class OllamaLLM(BaseLLM):
The async client to use for making requests.
"""
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Ollama."""
return {
"model": self.model,
"format": self.format,
"options": {
def _generate_params(
self,
prompt: str,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
options_dict = kwargs.pop(
"options",
{
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
@@ -143,14 +150,25 @@ class OllamaLLM(BaseLLM):
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"temperature": self.temperature,
"stop": self.stop,
"stop": self.stop if stop is None else stop,
"tfs_z": self.tfs_z,
"top_k": self.top_k,
"top_p": self.top_p,
},
"keep_alive": self.keep_alive,
)
params = {
"prompt": prompt,
"stream": kwargs.pop("stream", True),
"model": kwargs.pop("model", self.model),
"format": kwargs.pop("format", self.format),
"options": Options(**options_dict),
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
**kwargs,
}
return params
@property
def _llm_type(self) -> str:
"""Return type of LLM."""
@@ -179,25 +197,8 @@ class OllamaLLM(BaseLLM):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
params = self._default_params
for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]
params["options"]["stop"] = stop
async for part in await self._async_client.generate(
model=params["model"],
prompt=prompt,
stream=True,
options=Options(**params["options"]),
keep_alive=params["keep_alive"],
format=params["format"],
**self._generate_params(prompt, stop=stop, **kwargs)
): # type: ignore
yield part
@@ -207,25 +208,8 @@ class OllamaLLM(BaseLLM):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[Union[Mapping[str, Any], str]]:
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
stop = self.stop
params = self._default_params
for key in self._default_params:
if key in kwargs:
params[key] = kwargs[key]
params["options"]["stop"] = stop
yield from self._client.generate(
model=params["model"],
prompt=prompt,
stream=True,
options=Options(**params["options"]),
keep_alive=params["keep_alive"],
format=params["format"],
**self._generate_params(prompt, stop=stop, **kwargs)
)
async def _astream_with_aggregation(