mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
parent
2ee37a1c7b
commit
7277794a59
@ -341,13 +341,22 @@ class ChatOllama(BaseChatModel):
|
||||
The async client to use for making requests.
|
||||
"""
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Ollama."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"format": self.format,
|
||||
"options": {
|
||||
def _chat_params(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
ollama_messages = self._convert_messages_to_ollama_messages(messages)
|
||||
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
stop = self.stop
|
||||
|
||||
options_dict = kwargs.pop(
|
||||
"options",
|
||||
{
|
||||
"mirostat": self.mirostat,
|
||||
"mirostat_eta": self.mirostat_eta,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
@ -359,14 +368,31 @@ class ChatOllama(BaseChatModel):
|
||||
"repeat_penalty": self.repeat_penalty,
|
||||
"temperature": self.temperature,
|
||||
"seed": self.seed,
|
||||
"stop": self.stop,
|
||||
"stop": self.stop if stop is None else stop,
|
||||
"tfs_z": self.tfs_z,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
},
|
||||
"keep_alive": self.keep_alive,
|
||||
)
|
||||
|
||||
tools = kwargs.get("tools")
|
||||
default_stream = not bool(tools)
|
||||
|
||||
params = {
|
||||
"messages": ollama_messages,
|
||||
"stream": kwargs.pop("stream", default_stream),
|
||||
"model": kwargs.pop("model", self.model),
|
||||
"format": kwargs.pop("format", self.format),
|
||||
"options": Options(**options_dict),
|
||||
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
if tools:
|
||||
params["tools"] = tools
|
||||
|
||||
return params
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_clients(self) -> Self:
|
||||
"""Set clients to use for ollama."""
|
||||
@ -464,34 +490,9 @@ class ChatOllama(BaseChatModel):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
|
||||
ollama_messages = self._convert_messages_to_ollama_messages(messages)
|
||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||
|
||||
stop = stop if stop is not None else self.stop
|
||||
|
||||
params = self._default_params
|
||||
|
||||
for key in self._default_params:
|
||||
if key in kwargs:
|
||||
params[key] = kwargs[key]
|
||||
|
||||
params["options"]["stop"] = stop
|
||||
|
||||
tools = kwargs.get("tools", None)
|
||||
stream = tools is None or len(tools) == 0
|
||||
|
||||
chat_params = {
|
||||
"model": params["model"],
|
||||
"messages": ollama_messages,
|
||||
"stream": stream,
|
||||
"options": Options(**params["options"]),
|
||||
"keep_alive": params["keep_alive"],
|
||||
"format": params["format"],
|
||||
}
|
||||
|
||||
if tools is not None:
|
||||
chat_params["tools"] = tools
|
||||
|
||||
if stream:
|
||||
if chat_params["stream"]:
|
||||
async for part in await self._async_client.chat(**chat_params):
|
||||
yield part
|
||||
else:
|
||||
@ -503,34 +504,9 @@ class ChatOllama(BaseChatModel):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Union[Mapping[str, Any], str]]:
|
||||
ollama_messages = self._convert_messages_to_ollama_messages(messages)
|
||||
chat_params = self._chat_params(messages, stop, **kwargs)
|
||||
|
||||
stop = stop if stop is not None else self.stop
|
||||
|
||||
params = self._default_params
|
||||
|
||||
for key in self._default_params:
|
||||
if key in kwargs:
|
||||
params[key] = kwargs[key]
|
||||
|
||||
params["options"]["stop"] = stop
|
||||
|
||||
tools = kwargs.get("tools", None)
|
||||
stream = tools is None or len(tools) == 0
|
||||
|
||||
chat_params = {
|
||||
"model": params["model"],
|
||||
"messages": ollama_messages,
|
||||
"stream": stream,
|
||||
"options": Options(**params["options"]),
|
||||
"keep_alive": params["keep_alive"],
|
||||
"format": params["format"],
|
||||
}
|
||||
|
||||
if tools is not None:
|
||||
chat_params["tools"] = tools
|
||||
|
||||
if stream:
|
||||
if chat_params["stream"]:
|
||||
yield from self._client.chat(**chat_params)
|
||||
else:
|
||||
yield self._client.chat(**chat_params)
|
||||
|
@ -126,13 +126,20 @@ class OllamaLLM(BaseLLM):
|
||||
The async client to use for making requests.
|
||||
"""
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Ollama."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"format": self.format,
|
||||
"options": {
|
||||
def _generate_params(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
stop = self.stop
|
||||
|
||||
options_dict = kwargs.pop(
|
||||
"options",
|
||||
{
|
||||
"mirostat": self.mirostat,
|
||||
"mirostat_eta": self.mirostat_eta,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
@ -143,14 +150,25 @@ class OllamaLLM(BaseLLM):
|
||||
"repeat_last_n": self.repeat_last_n,
|
||||
"repeat_penalty": self.repeat_penalty,
|
||||
"temperature": self.temperature,
|
||||
"stop": self.stop,
|
||||
"stop": self.stop if stop is None else stop,
|
||||
"tfs_z": self.tfs_z,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
},
|
||||
"keep_alive": self.keep_alive,
|
||||
)
|
||||
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"stream": kwargs.pop("stream", True),
|
||||
"model": kwargs.pop("model", self.model),
|
||||
"format": kwargs.pop("format", self.format),
|
||||
"options": Options(**options_dict),
|
||||
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
return params
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of LLM."""
|
||||
@ -179,25 +197,8 @@ class OllamaLLM(BaseLLM):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
stop = self.stop
|
||||
|
||||
params = self._default_params
|
||||
|
||||
for key in self._default_params:
|
||||
if key in kwargs:
|
||||
params[key] = kwargs[key]
|
||||
|
||||
params["options"]["stop"] = stop
|
||||
async for part in await self._async_client.generate(
|
||||
model=params["model"],
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
options=Options(**params["options"]),
|
||||
keep_alive=params["keep_alive"],
|
||||
format=params["format"],
|
||||
**self._generate_params(prompt, stop=stop, **kwargs)
|
||||
): # type: ignore
|
||||
yield part
|
||||
|
||||
@ -207,25 +208,8 @@ class OllamaLLM(BaseLLM):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Union[Mapping[str, Any], str]]:
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
stop = self.stop
|
||||
|
||||
params = self._default_params
|
||||
|
||||
for key in self._default_params:
|
||||
if key in kwargs:
|
||||
params[key] = kwargs[key]
|
||||
|
||||
params["options"]["stop"] = stop
|
||||
yield from self._client.generate(
|
||||
model=params["model"],
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
options=Options(**params["options"]),
|
||||
keep_alive=params["keep_alive"],
|
||||
format=params["format"],
|
||||
**self._generate_params(prompt, stop=stop, **kwargs)
|
||||
)
|
||||
|
||||
async def _astream_with_aggregation(
|
||||
|
Loading…
Reference in New Issue
Block a user