mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
@@ -126,13 +126,20 @@ class OllamaLLM(BaseLLM):
|
||||
The async client to use for making requests.
|
||||
"""
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Ollama."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"format": self.format,
|
||||
"options": {
|
||||
def _generate_params(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
stop = self.stop
|
||||
|
||||
options_dict = kwargs.pop(
|
||||
"options",
|
||||
{
|
||||
"mirostat": self.mirostat,
|
||||
"mirostat_eta": self.mirostat_eta,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
@@ -143,14 +150,25 @@ class OllamaLLM(BaseLLM):
|
||||
"repeat_last_n": self.repeat_last_n,
|
||||
"repeat_penalty": self.repeat_penalty,
|
||||
"temperature": self.temperature,
|
||||
"stop": self.stop,
|
||||
"stop": self.stop if stop is None else stop,
|
||||
"tfs_z": self.tfs_z,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
},
|
||||
"keep_alive": self.keep_alive,
|
||||
)
|
||||
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"stream": kwargs.pop("stream", True),
|
||||
"model": kwargs.pop("model", self.model),
|
||||
"format": kwargs.pop("format", self.format),
|
||||
"options": Options(**options_dict),
|
||||
"keep_alive": kwargs.pop("keep_alive", self.keep_alive),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
return params
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of LLM."""
|
||||
@@ -179,25 +197,8 @@ class OllamaLLM(BaseLLM):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Union[Mapping[str, Any], str]]:
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
stop = self.stop
|
||||
|
||||
params = self._default_params
|
||||
|
||||
for key in self._default_params:
|
||||
if key in kwargs:
|
||||
params[key] = kwargs[key]
|
||||
|
||||
params["options"]["stop"] = stop
|
||||
async for part in await self._async_client.generate(
|
||||
model=params["model"],
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
options=Options(**params["options"]),
|
||||
keep_alive=params["keep_alive"],
|
||||
format=params["format"],
|
||||
**self._generate_params(prompt, stop=stop, **kwargs)
|
||||
): # type: ignore
|
||||
yield part
|
||||
|
||||
@@ -207,25 +208,8 @@ class OllamaLLM(BaseLLM):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Union[Mapping[str, Any], str]]:
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
stop = self.stop
|
||||
|
||||
params = self._default_params
|
||||
|
||||
for key in self._default_params:
|
||||
if key in kwargs:
|
||||
params[key] = kwargs[key]
|
||||
|
||||
params["options"]["stop"] = stop
|
||||
yield from self._client.generate(
|
||||
model=params["model"],
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
options=Options(**params["options"]),
|
||||
keep_alive=params["keep_alive"],
|
||||
format=params["format"],
|
||||
**self._generate_params(prompt, stop=stop, **kwargs)
|
||||
)
|
||||
|
||||
async def _astream_with_aggregation(
|
||||
|
||||
Reference in New Issue
Block a user