ollama: allow base_url, headers, and auth to be passed (#25078)

Isaac Francisco
2024-08-05 15:39:36 -07:00
committed by GitHub
parent 4bcd2aad6c
commit 63ddf0afb4
3 changed files with 97 additions and 14 deletions
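
In practice, the new `client_kwargs` field is what carries headers and auth through to the underlying httpx client. A minimal usage sketch (the endpoint, model name, and token below are placeholders, not part of this commit):

    from langchain_ollama import ChatOllama

    llm = ChatOllama(
        model="llama3",
        base_url="https://ollama.example.com",  # hypothetical remote Ollama endpoint
        client_kwargs={
            # forwarded to ollama.Client/AsyncClient and on to httpx.Client,
            # so any httpx.Client kwarg (headers, auth, timeout, verify, ...) fits
            "headers": {"Authorization": "Bearer <token>"},
        },
    )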

View File

@@ -35,6 +35,7 @@ from langchain_core.messages import (
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import tool_call
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -322,6 +323,21 @@ class ChatOllama(BaseChatModel):
     base_url: Optional[str] = None
     """Base url the model is hosted under."""

+    client_kwargs: Optional[dict] = {}
+    """Additional kwargs to pass to the httpx Client.
+    For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
+    """
+
+    _client: Client = Field(default=None)
+    """
+    The client to use for making requests.
+    """
+
+    _async_client: AsyncClient = Field(default=None)
+    """
+    The async client to use for making requests.
+    """
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -348,6 +364,15 @@ class ChatOllama(BaseChatModel):
"keep_alive": self.keep_alive, "keep_alive": self.keep_alive,
} }
@root_validator(pre=False, skip_on_failure=True)
def _set_clients(cls, values: dict) -> dict:
"""Set clients to use for ollama."""
values["_client"] = Client(host=values["base_url"], **values["client_kwargs"])
values["_async_client"] = AsyncClient(
host=values["base_url"], **values["client_kwargs"]
)
return values
def _convert_messages_to_ollama_messages( def _convert_messages_to_ollama_messages(
self, messages: List[BaseMessage] self, messages: List[BaseMessage]
) -> Sequence[Message]: ) -> Sequence[Message]:
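
The validator above builds both clients once, when the model object is constructed, rather than per request. Conceptually it boils down to the direct calls below (a sketch; the host and kwargs are illustrative):

    from ollama import AsyncClient, Client

    client_kwargs = {"timeout": 60}  # example httpx.Client kwarg
    client = Client(host="http://localhost:11434", **client_kwargs)
    async_client = AsyncClient(host="http://localhost:11434", **client_kwargs)
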
@@ -449,7 +474,7 @@ class ChatOllama(BaseChatModel):
params["options"]["stop"] = stop params["options"]["stop"] = stop
if "tools" in kwargs: if "tools" in kwargs:
yield await AsyncClient(host=self.base_url).chat( yield await self._async_client.chat(
model=params["model"], model=params["model"],
messages=ollama_messages, messages=ollama_messages,
stream=False, stream=False,
@@ -459,7 +484,7 @@ class ChatOllama(BaseChatModel):
                 tools=kwargs["tools"],
             )  # type:ignore
         else:
-            async for part in await AsyncClient(host=self.base_url).chat(
+            async for part in await self._async_client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
@@ -487,7 +512,7 @@ class ChatOllama(BaseChatModel):
params["options"]["stop"] = stop params["options"]["stop"] = stop
if "tools" in kwargs: if "tools" in kwargs:
yield Client(host=self.base_url).chat( yield self._client.chat(
model=params["model"], model=params["model"],
messages=ollama_messages, messages=ollama_messages,
stream=False, stream=False,
@@ -497,7 +522,7 @@ class ChatOllama(BaseChatModel):
                 tools=kwargs["tools"],
             )
         else:
-            yield from Client(host=self.base_url).chat(
+            yield from self._client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,

View File

@@ -1,9 +1,11 @@
-from typing import List
+from typing import (
+    List,
+    Optional,
+)

-import ollama
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Extra
-from ollama import AsyncClient
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from ollama import AsyncClient, Client


 class OllamaEmbeddings(BaseModel, Embeddings):
@@ -21,14 +23,41 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     model: str
     """Model name to use."""

+    base_url: Optional[str] = None
+    """Base url the model is hosted under."""
+
+    client_kwargs: Optional[dict] = {}
+    """Additional kwargs to pass to the httpx Client.
+    For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
+    """
+
+    _client: Client = Field(default=None)
+    """
+    The client to use for making requests.
+    """
+
+    _async_client: AsyncClient = Field(default=None)
+    """
+    The async client to use for making requests.
+    """
+
     class Config:
         """Configuration for this pydantic object."""

         extra = Extra.forbid

+    @root_validator(pre=False, skip_on_failure=True)
+    def _set_clients(cls, values: dict) -> dict:
+        """Set clients to use for ollama."""
+        values["_client"] = Client(host=values["base_url"], **values["client_kwargs"])
+        values["_async_client"] = AsyncClient(
+            host=values["base_url"], **values["client_kwargs"]
+        )
+        return values
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
-        embedded_docs = ollama.embed(self.model, texts)["embeddings"]
+        embedded_docs = self._client.embed(self.model, texts)["embeddings"]
         return embedded_docs

     def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
@@ -37,7 +66,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
-        embedded_docs = (await AsyncClient().embed(self.model, texts))["embeddings"]
+        embedded_docs = (await self._async_client.embed(self.model, texts))[
+            "embeddings"
+        ]
         return embedded_docs

     async def aembed_query(self, text: str) -> List[float]:
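
The same configuration surface now applies to embeddings, which previously used the module-level `ollama` client with no way to set a host or headers. A hedged usage sketch (model name and header are illustrative):

    from langchain_ollama import OllamaEmbeddings

    embedder = OllamaEmbeddings(
        model="nomic-embed-text",  # placeholder embedding model
        base_url="http://localhost:11434",
        client_kwargs={"headers": {"X-Request-Source": "langchain"}},
    )
    vectors = embedder.embed_documents(["hello", "world"])
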

View File

@@ -12,14 +12,14 @@ from typing import (
     Union,
 )

-import ollama
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import BaseLLM
 from langchain_core.outputs import GenerationChunk, LLMResult
-from ollama import AsyncClient, Options
+from langchain_core.pydantic_v1 import Field, root_validator
+from ollama import AsyncClient, Client, Options


 class OllamaLLM(BaseLLM):
@@ -107,6 +107,24 @@ class OllamaLLM(BaseLLM):
     keep_alive: Optional[Union[int, str]] = None
     """How long the model will stay loaded into memory."""

+    base_url: Optional[str] = None
+    """Base url the model is hosted under."""
+
+    client_kwargs: Optional[dict] = {}
+    """Additional kwargs to pass to the httpx Client.
+    For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
+    """
+
+    _client: Client = Field(default=None)
+    """
+    The client to use for making requests.
+    """
+
+    _async_client: AsyncClient = Field(default=None)
+    """
+    The async client to use for making requests.
+    """
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -137,6 +155,15 @@ class OllamaLLM(BaseLLM):
"""Return type of LLM.""" """Return type of LLM."""
return "ollama-llm" return "ollama-llm"
@root_validator(pre=False, skip_on_failure=True)
def _set_clients(cls, values: dict) -> dict:
"""Set clients to use for ollama."""
values["_client"] = Client(host=values["base_url"], **values["client_kwargs"])
values["_async_client"] = AsyncClient(
host=values["base_url"], **values["client_kwargs"]
)
return values
async def _acreate_generate_stream( async def _acreate_generate_stream(
self, self,
prompt: str, prompt: str,
@@ -155,7 +182,7 @@ class OllamaLLM(BaseLLM):
                 params[key] = kwargs[key]
         params["options"]["stop"] = stop

-        async for part in await AsyncClient().generate(
+        async for part in await self._async_client.generate(
             model=params["model"],
             prompt=prompt,
             stream=True,
@@ -183,7 +210,7 @@ class OllamaLLM(BaseLLM):
                 params[key] = kwargs[key]
         params["options"]["stop"] = stop

-        yield from ollama.generate(
+        yield from self._client.generate(
             model=params["model"],
             prompt=prompt,
             stream=True,
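
And for the completion-style interface, auth can ride along the same way. A sketch (the model and credentials are placeholders):

    from langchain_ollama import OllamaLLM

    llm = OllamaLLM(
        model="llama3",
        base_url="http://localhost:11434",
        client_kwargs={"auth": ("user", "pass")},  # httpx basic auth
    )
    print(llm.invoke("Why is the sky blue?"))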