From 63ddf0afb46187529be156fdbcaa1a4bacb2a4b8 Mon Sep 17 00:00:00 2001
From: Isaac Francisco <78627776+isahers1@users.noreply.github.com>
Date: Mon, 5 Aug 2024 15:39:36 -0700
Subject: [PATCH] ollama: allow base_url, headers, and auth to be passed
 (#25078)

---
 .../ollama/langchain_ollama/chat_models.py    | 33 ++++++++++++--
 .../ollama/langchain_ollama/embeddings.py     | 43 ++++++++++++++++---
 libs/partners/ollama/langchain_ollama/llms.py | 35 +++++++++++++--
 3 files changed, 97 insertions(+), 14 deletions(-)

diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py
index 04567e2f8c7..143761cf5f0 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models.py
@@ -35,6 +35,7 @@ from langchain_core.messages import (
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import tool_call
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -322,6 +323,21 @@ class ChatOllama(BaseChatModel):
     base_url: Optional[str] = None
     """Base url the model is hosted under."""

+    client_kwargs: Optional[dict] = {}
+    """Additional kwargs to pass to the httpx Client.
+    For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
+    """
+
+    _client: Client = Field(default=None)
+    """
+    The client to use for making requests.
+    """
+
+    _async_client: AsyncClient = Field(default=None)
+    """
+    The async client to use for making requests.
+    """
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -348,6 +364,15 @@ class ChatOllama(BaseChatModel):
             "keep_alive": self.keep_alive,
         }

+    @root_validator(pre=False, skip_on_failure=True)
+    def _set_clients(cls, values: dict) -> dict:
+        """Set clients to use for ollama."""
+        values["_client"] = Client(host=values["base_url"], **values["client_kwargs"])
+        values["_async_client"] = AsyncClient(
+            host=values["base_url"], **values["client_kwargs"]
+        )
+        return values
+
     def _convert_messages_to_ollama_messages(
         self, messages: List[BaseMessage]
     ) -> Sequence[Message]:
@@ -449,7 +474,7 @@ class ChatOllama(BaseChatModel):

         params["options"]["stop"] = stop
         if "tools" in kwargs:
-            yield await AsyncClient(host=self.base_url).chat(
+            yield await self._async_client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -459,7 +484,7 @@ class ChatOllama(BaseChatModel):
                 tools=kwargs["tools"],
             )  # type:ignore
         else:
-            async for part in await AsyncClient(host=self.base_url).chat(
+            async for part in await self._async_client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
@@ -487,7 +512,7 @@ class ChatOllama(BaseChatModel):

         params["options"]["stop"] = stop
         if "tools" in kwargs:
-            yield Client(host=self.base_url).chat(
+            yield self._client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -497,7 +522,7 @@ class ChatOllama(BaseChatModel):
                 tools=kwargs["tools"],
             )
         else:
-            yield from Client(host=self.base_url).chat(
+            yield from self._client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
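
A minimal usage sketch of the new ChatOllama options, not part of the patch itself: the model name, host, and token below are placeholders, and the client_kwargs entries are ordinary httpx.Client parameters that ollama.Client/AsyncClient forward to httpx.

    from langchain_ollama import ChatOllama

    llm = ChatOllama(
        model="llama3",  # placeholder model name
        base_url="http://ollama.example.internal:11434",  # placeholder host
        client_kwargs={
            # forwarded to httpx via ollama.Client / ollama.AsyncClient
            "headers": {"Authorization": "Bearer <token>"},  # placeholder token
            # "auth": ("user", "password"),  # httpx basic auth, if preferred
        },
    )
    llm.invoke("Hello!")
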
diff --git a/libs/partners/ollama/langchain_ollama/embeddings.py b/libs/partners/ollama/langchain_ollama/embeddings.py
index 5528b134018..357878d2810 100644
--- a/libs/partners/ollama/langchain_ollama/embeddings.py
+++ b/libs/partners/ollama/langchain_ollama/embeddings.py
@@ -1,9 +1,11 @@
-from typing import List
+from typing import (
+    List,
+    Optional,
+)

-import ollama
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Extra
-from ollama import AsyncClient
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from ollama import AsyncClient, Client


 class OllamaEmbeddings(BaseModel, Embeddings):
@@ -21,14 +23,41 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     model: str
     """Model name to use."""

+    base_url: Optional[str] = None
+    """Base url the model is hosted under."""
+
+    client_kwargs: Optional[dict] = {}
+    """Additional kwargs to pass to the httpx Client.
+    For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
+    """
+
+    _client: Client = Field(default=None)
+    """
+    The client to use for making requests.
+    """
+
+    _async_client: AsyncClient = Field(default=None)
+    """
+    The async client to use for making requests.
+    """
+
     class Config:
         """Configuration for this pydantic object."""

         extra = Extra.forbid

+    @root_validator(pre=False, skip_on_failure=True)
+    def _set_clients(cls, values: dict) -> dict:
+        """Set clients to use for ollama."""
+        values["_client"] = Client(host=values["base_url"], **values["client_kwargs"])
+        values["_async_client"] = AsyncClient(
+            host=values["base_url"], **values["client_kwargs"]
+        )
+        return values
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
-        embedded_docs = ollama.embed(self.model, texts)["embeddings"]
+        embedded_docs = self._client.embed(self.model, texts)["embeddings"]
         return embedded_docs

     def embed_query(self, text: str) -> List[float]:
@@ -37,7 +66,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):

     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
-        embedded_docs = (await AsyncClient().embed(self.model, texts))["embeddings"]
+        embedded_docs = (await self._async_client.embed(self.model, texts))[
+            "embeddings"
+        ]
         return embedded_docs

     async def aembed_query(self, text: str) -> List[float]:
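
The same pattern applies to embeddings; a hedged sketch, with the model name and timeout value as placeholders rather than values taken from the patch:

    from langchain_ollama import OllamaEmbeddings

    embeddings = OllamaEmbeddings(
        model="nomic-embed-text",  # placeholder model name
        base_url="http://localhost:11434",
        client_kwargs={"timeout": 60},  # httpx.Client timeout, in seconds
    )
    vectors = embeddings.embed_documents(["hello", "world"])

Because the root validator builds both Client and AsyncClient once, the same base_url and client_kwargs cover embed_documents and aembed_documents alike.
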
+ """ + + _async_client: AsyncClient = Field(default=None) + """ + The async client to use for making requests. + """ + @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling Ollama.""" @@ -137,6 +155,15 @@ class OllamaLLM(BaseLLM): """Return type of LLM.""" return "ollama-llm" + @root_validator(pre=False, skip_on_failure=True) + def _set_clients(cls, values: dict) -> dict: + """Set clients to use for ollama.""" + values["_client"] = Client(host=values["base_url"], **values["client_kwargs"]) + values["_async_client"] = AsyncClient( + host=values["base_url"], **values["client_kwargs"] + ) + return values + async def _acreate_generate_stream( self, prompt: str, @@ -155,7 +182,7 @@ class OllamaLLM(BaseLLM): params[key] = kwargs[key] params["options"]["stop"] = stop - async for part in await AsyncClient().generate( + async for part in await self._async_client.generate( model=params["model"], prompt=prompt, stream=True, @@ -183,7 +210,7 @@ class OllamaLLM(BaseLLM): params[key] = kwargs[key] params["options"]["stop"] = stop - yield from ollama.generate( + yield from self._client.generate( model=params["model"], prompt=prompt, stream=True,