mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
ollama: allow base_url, headers, and auth to be passed (#25078)
This commit is contained in:
@@ -12,14 +12,14 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import ollama
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.outputs import GenerationChunk, LLMResult
|
||||
from ollama import AsyncClient, Options
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from ollama import AsyncClient, Client, Options
|
||||
|
||||
|
||||
class OllamaLLM(BaseLLM):
|
||||
@@ -107,6 +107,24 @@ class OllamaLLM(BaseLLM):
|
||||
keep_alive: Optional[Union[int, str]] = None
|
||||
"""How long the model will stay loaded into memory."""
|
||||
|
||||
base_url: Optional[str] = None
|
||||
"""Base url the model is hosted under."""
|
||||
|
||||
client_kwargs: Optional[dict] = {}
|
||||
"""Additional kwargs to pass to the httpx Client.
|
||||
For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
|
||||
"""
|
||||
|
||||
_client: Client = Field(default=None)
|
||||
"""
|
||||
The client to use for making requests.
|
||||
"""
|
||||
|
||||
_async_client: AsyncClient = Field(default=None)
|
||||
"""
|
||||
The async client to use for making requests.
|
||||
"""
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Ollama."""
|
||||
@@ -137,6 +155,15 @@ class OllamaLLM(BaseLLM):
|
||||
"""Return type of LLM."""
|
||||
return "ollama-llm"
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def _set_clients(cls, values: dict) -> dict:
|
||||
"""Set clients to use for ollama."""
|
||||
values["_client"] = Client(host=values["base_url"], **values["client_kwargs"])
|
||||
values["_async_client"] = AsyncClient(
|
||||
host=values["base_url"], **values["client_kwargs"]
|
||||
)
|
||||
return values
|
||||
|
||||
async def _acreate_generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -155,7 +182,7 @@ class OllamaLLM(BaseLLM):
|
||||
params[key] = kwargs[key]
|
||||
|
||||
params["options"]["stop"] = stop
|
||||
async for part in await AsyncClient().generate(
|
||||
async for part in await self._async_client.generate(
|
||||
model=params["model"],
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
@@ -183,7 +210,7 @@ class OllamaLLM(BaseLLM):
|
||||
params[key] = kwargs[key]
|
||||
|
||||
params["options"]["stop"] = stop
|
||||
yield from ollama.generate(
|
||||
yield from self._client.generate(
|
||||
model=params["model"],
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
|
||||
Reference in New Issue
Block a user