Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-29 09:58:44 +00:00)
commit 7277794a59 (parent 2ee37a1c7b)
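Summary of the change, as read from the hunks below: the `_default_params` property on `ChatOllama` and `OllamaLLM` is replaced by per-request builders, `_chat_params()` and `_generate_params()`, which resolve `stop`, `options`, `tools`, and the streaming default in one place; the stream helpers then collapse to a single splatted client call, and passing tools turns streaming off by default.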
@@ -341,13 +341,22 @@ class ChatOllama(BaseChatModel):
     The async client to use for making requests.
     """
 
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling Ollama."""
-        return {
-            "model": self.model,
-            "format": self.format,
-            "options": {
+    def _chat_params(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        ollama_messages = self._convert_messages_to_ollama_messages(messages)
+
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            stop = self.stop
+
+        options_dict = kwargs.pop(
+            "options",
+            {
                 "mirostat": self.mirostat,
                 "mirostat_eta": self.mirostat_eta,
                 "mirostat_tau": self.mirostat_tau,
@@ -359,14 +368,31 @@ class ChatOllama(BaseChatModel):
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
                 "seed": self.seed,
-                "stop": self.stop,
+                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
             },
-            "keep_alive": self.keep_alive,
+        )
+
+        tools = kwargs.get("tools")
+        default_stream = not bool(tools)
+
+        params = {
+            "messages": ollama_messages,
+            "stream": kwargs.pop("stream", default_stream),
+            "model": kwargs.pop("model", self.model),
+            "format": kwargs.pop("format", self.format),
+            "options": Options(**options_dict),
+            "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
+            **kwargs,
         }
+
+        if tools:
+            params["tools"] = tools
+
+        return params
 
     @model_validator(mode="after")
     def _set_clients(self) -> Self:
         """Set clients to use for ollama."""
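The new builder centralizes the `stop` precedence rule: a `stop` passed at call time and a `stop` configured on the instance are mutually exclusive, and whichever one is set wins inside the options dict. A minimal standalone sketch of that rule (illustrative names, not the library code):

from typing import List, Optional


def merge_stop(
    instance_stop: Optional[List[str]], call_stop: Optional[List[str]]
) -> Optional[List[str]]:
    # Mirrors the check in _chat_params: both sources set at once is an error.
    if instance_stop is not None and call_stop is not None:
        raise ValueError("`stop` found in both the input and default params.")
    # Mirrors `"stop": self.stop if stop is None else stop` in the options dict.
    return instance_stop if call_stop is None else call_stop


assert merge_stop(["END"], None) == ["END"]  # instance default applies
assert merge_stop(None, ["\n"]) == ["\n"]    # call-time value applies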
@@ -464,34 +490,9 @@ class ChatOllama(BaseChatModel):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[Union[Mapping[str, Any], str]]:
-        ollama_messages = self._convert_messages_to_ollama_messages(messages)
+        chat_params = self._chat_params(messages, stop, **kwargs)
 
-        stop = stop if stop is not None else self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
-
-        tools = kwargs.get("tools", None)
-        stream = tools is None or len(tools) == 0
-
-        chat_params = {
-            "model": params["model"],
-            "messages": ollama_messages,
-            "stream": stream,
-            "options": Options(**params["options"]),
-            "keep_alive": params["keep_alive"],
-            "format": params["format"],
-        }
-
-        if tools is not None:
-            chat_params["tools"] = tools
-
-        if stream:
+        if chat_params["stream"]:
             async for part in await self._async_client.chat(**chat_params):
                 yield part
         else:
@@ -503,34 +504,9 @@ class ChatOllama(BaseChatModel):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[Union[Mapping[str, Any], str]]:
-        ollama_messages = self._convert_messages_to_ollama_messages(messages)
+        chat_params = self._chat_params(messages, stop, **kwargs)
 
-        stop = stop if stop is not None else self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
-
-        tools = kwargs.get("tools", None)
-        stream = tools is None or len(tools) == 0
-
-        chat_params = {
-            "model": params["model"],
-            "messages": ollama_messages,
-            "stream": stream,
-            "options": Options(**params["options"]),
-            "keep_alive": params["keep_alive"],
-            "format": params["format"],
-        }
-
-        if tools is not None:
-            chat_params["tools"] = tools
-
-        if stream:
+        if chat_params["stream"]:
             yield from self._client.chat(**chat_params)
         else:
             yield self._client.chat(**chat_params)
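Both stream helpers now delegate to `_chat_params()` and branch on the computed `stream` flag instead of re-deriving it. That flag's default comes from whether tools were passed. A hedged sketch of the interaction, with `build_params` as a hypothetical stand-in for `ChatOllama._chat_params`:

from typing import Any, Dict


def build_params(**kwargs: Any) -> Dict[str, Any]:
    tools = kwargs.get("tools")
    default_stream = not bool(tools)  # tool calls disable streaming by default
    return {
        "stream": kwargs.pop("stream", default_stream),
        **kwargs,
    }


params = build_params(tools=[{"type": "function"}])
assert params["stream"] is False  # tools present -> no streaming by default

params = build_params()
assert params["stream"] is True   # no tools -> stream by default

params = build_params(tools=[{"type": "function"}], stream=True)
assert params["stream"] is True   # an explicit stream kwarg still wins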
@@ -126,13 +126,20 @@ class OllamaLLM(BaseLLM):
     The async client to use for making requests.
     """
 
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling Ollama."""
-        return {
-            "model": self.model,
-            "format": self.format,
-            "options": {
+    def _generate_params(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            stop = self.stop
+
+        options_dict = kwargs.pop(
+            "options",
+            {
                 "mirostat": self.mirostat,
                 "mirostat_eta": self.mirostat_eta,
                 "mirostat_tau": self.mirostat_tau,
@@ -143,14 +150,25 @@ class OllamaLLM(BaseLLM):
                 "repeat_last_n": self.repeat_last_n,
                 "repeat_penalty": self.repeat_penalty,
                 "temperature": self.temperature,
-                "stop": self.stop,
+                "stop": self.stop if stop is None else stop,
                 "tfs_z": self.tfs_z,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
             },
-            "keep_alive": self.keep_alive,
+        )
+
+        params = {
+            "prompt": prompt,
+            "stream": kwargs.pop("stream", True),
+            "model": kwargs.pop("model", self.model),
+            "format": kwargs.pop("format", self.format),
+            "options": Options(**options_dict),
+            "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
+            **kwargs,
         }
+
+        return params
 
     @property
     def _llm_type(self) -> str:
         """Return type of LLM."""
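Note that `kwargs.pop("options", {...})` gives a caller-supplied options dict wholesale precedence: it replaces the built-in defaults rather than merging with them key by key. A small sketch of that behavior (illustrative function names, not library API):

from typing import Any, Dict


def resolve_options(defaults: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    # Same shape as the builders above: pop "options" or fall back to defaults.
    return kwargs.pop("options", defaults)


defaults = {"temperature": 0.8, "top_k": 40}
assert resolve_options(defaults) == {"temperature": 0.8, "top_k": 40}
# Passing options drops every default key that is not restated:
assert resolve_options(defaults, options={"temperature": 0.1}) == {
    "temperature": 0.1
}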
@@ -179,25 +197,8 @@ class OllamaLLM(BaseLLM):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[Union[Mapping[str, Any], str]]:
-        if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop is not None:
-            stop = self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
         async for part in await self._async_client.generate(
-            model=params["model"],
-            prompt=prompt,
-            stream=True,
-            options=Options(**params["options"]),
-            keep_alive=params["keep_alive"],
-            format=params["format"],
+            **self._generate_params(prompt, stop=stop, **kwargs)
         ): # type: ignore
             yield part
 
@@ -207,25 +208,8 @@ class OllamaLLM(BaseLLM):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[Union[Mapping[str, Any], str]]:
-        if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop is not None:
-            stop = self.stop
-
-        params = self._default_params
-
-        for key in self._default_params:
-            if key in kwargs:
-                params[key] = kwargs[key]
-
-        params["options"]["stop"] = stop
         yield from self._client.generate(
-            model=params["model"],
-            prompt=prompt,
-            stream=True,
-            options=Options(**params["options"]),
-            keep_alive=params["keep_alive"],
-            format=params["format"],
+            **self._generate_params(prompt, stop=stop, **kwargs)
        )
 
     async def _astream_with_aggregation(
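At the call sites, the per-field keyword arguments (`model=`, `prompt=`, `stream=`, ...) collapse into a single `**self._generate_params(...)` splat. A toy illustration of the pattern, with a stub `generate` standing in for the real ollama client:

from typing import Any, Dict, Iterator


def generate(**kwargs: Any) -> Iterator[Dict[str, Any]]:
    # Stub client: echo back exactly what it would receive.
    yield kwargs


params = {
    "prompt": "hello",
    "stream": True,
    "model": "llama3",          # illustrative values, not defaults from the diff
    "options": {"temperature": 0.5},
    "keep_alive": None,
    "format": "",
}

# Before: generate(model=..., prompt=..., stream=True, options=..., ...)
# After:  generate(**params) -- one source of truth for request parameters.
received = next(generate(**params))
assert received["prompt"] == "hello" and received["stream"] is True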