Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-03 03:59:42 +00:00)
ollama: allow base_url, headers, and auth to be passed (#25078)
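Before the diff, a hedged sketch of what this change enables: `base_url` selects the host the Ollama server runs on, and the new `client_kwargs` field is forwarded to the underlying `ollama` client (and from there to the httpx client it builds), which is how extra headers and auth can now be supplied. The model name, host, and credentials below are placeholders, not values taken from the diff.

from langchain_ollama import ChatOllama

llm = ChatOllama(
    model="llama3",  # placeholder model name
    base_url="http://my-ollama-host:11434",  # host the Ollama server is reachable on
    client_kwargs={
        # forwarded to ollama.Client / ollama.AsyncClient, which pass them to httpx
        "headers": {"X-Api-Key": "my-key"},
        "auth": ("user", "password"),
    },
)
print(llm.invoke("Hello!").content)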
@@ -35,6 +35,7 @@ from langchain_core.messages import (
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import tool_call
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -322,6 +323,21 @@ class ChatOllama(BaseChatModel):
     base_url: Optional[str] = None
     """Base url the model is hosted under."""
 
+    client_kwargs: Optional[dict] = {}
+    """Additional kwargs to pass to the httpx Client.
+    For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
+    """
+
+    _client: Client = Field(default=None)
+    """
+    The client to use for making requests.
+    """
+
+    _async_client: AsyncClient = Field(default=None)
+    """
+    The async client to use for making requests.
+    """
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -348,6 +364,15 @@ class ChatOllama(BaseChatModel):
             "keep_alive": self.keep_alive,
         }
 
+    @root_validator(pre=False, skip_on_failure=True)
+    def _set_clients(cls, values: dict) -> dict:
+        """Set clients to use for ollama."""
+        values["_client"] = Client(host=values["base_url"], **values["client_kwargs"])
+        values["_async_client"] = AsyncClient(
+            host=values["base_url"], **values["client_kwargs"]
+        )
+        return values
+
     def _convert_messages_to_ollama_messages(
         self, messages: List[BaseMessage]
     ) -> Sequence[Message]:
@@ -449,7 +474,7 @@ class ChatOllama(BaseChatModel):
 
         params["options"]["stop"] = stop
         if "tools" in kwargs:
-            yield await AsyncClient(host=self.base_url).chat(
+            yield await self._async_client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -459,7 +484,7 @@ class ChatOllama(BaseChatModel):
                 tools=kwargs["tools"],
             )  # type:ignore
         else:
-            async for part in await AsyncClient(host=self.base_url).chat(
+            async for part in await self._async_client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
@@ -487,7 +512,7 @@ class ChatOllama(BaseChatModel):
 
         params["options"]["stop"] = stop
         if "tools" in kwargs:
-            yield Client(host=self.base_url).chat(
+            yield self._client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=False,
@@ -497,7 +522,7 @@ class ChatOllama(BaseChatModel):
                 tools=kwargs["tools"],
             )
         else:
-            yield from Client(host=self.base_url).chat(
+            yield from self._client.chat(
                 model=params["model"],
                 messages=ollama_messages,
                 stream=True,
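The new `client_kwargs` field above is passed straight through to `ollama.Client` / `ollama.AsyncClient`, whose extra keyword arguments end up on the httpx client they construct, as the field's docstring states. For reference, these are ordinary `httpx.Client` keyword arguments; the sketch below (values are placeholders, and the request assumes an Ollama server listening on the default local port) shows the kind of options that can travel through `client_kwargs`.

import httpx

client = httpx.Client(
    headers={"Authorization": "Bearer my-token"},  # extra headers sent on every request
    auth=("user", "password"),                     # HTTP basic auth
    timeout=30.0,                                  # per-request timeout in seconds
)
# /api/tags lists locally available models; assumes a running Ollama server.
response = client.get("http://localhost:11434/api/tags")
print(response.status_code)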
@@ -1,9 +1,11 @@
-from typing import List
+from typing import (
+    List,
+    Optional,
+)
 
-import ollama
 from langchain_core.embeddings import Embeddings
-from langchain_core.pydantic_v1 import BaseModel, Extra
-from ollama import AsyncClient
+from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from ollama import AsyncClient, Client
 
 
 class OllamaEmbeddings(BaseModel, Embeddings):
@@ -21,14 +23,41 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     model: str
     """Model name to use."""
 
+    base_url: Optional[str] = None
+    """Base url the model is hosted under."""
+
+    client_kwargs: Optional[dict] = {}
+    """Additional kwargs to pass to the httpx Client.
+    For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
+    """
+
+    _client: Client = Field(default=None)
+    """
+    The client to use for making requests.
+    """
+
+    _async_client: AsyncClient = Field(default=None)
+    """
+    The async client to use for making requests.
+    """
+
     class Config:
         """Configuration for this pydantic object."""
 
         extra = Extra.forbid
 
+    @root_validator(pre=False, skip_on_failure=True)
+    def _set_clients(cls, values: dict) -> dict:
+        """Set clients to use for ollama."""
+        values["_client"] = Client(host=values["base_url"], **values["client_kwargs"])
+        values["_async_client"] = AsyncClient(
+            host=values["base_url"], **values["client_kwargs"]
+        )
+        return values
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
-        embedded_docs = ollama.embed(self.model, texts)["embeddings"]
+        embedded_docs = self._client.embed(self.model, texts)["embeddings"]
         return embedded_docs
 
     def embed_query(self, text: str) -> List[float]:
@@ -37,7 +66,9 @@ class OllamaEmbeddings(BaseModel, Embeddings):
 
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
-        embedded_docs = (await AsyncClient().embed(self.model, texts))["embeddings"]
+        embedded_docs = (await self._async_client.embed(self.model, texts))[
+            "embeddings"
+        ]
         return embedded_docs
 
     async def aembed_query(self, text: str) -> List[float]:
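With the fields above, `OllamaEmbeddings` gains the same `base_url` / `client_kwargs` handling as `ChatOllama`, and the embed calls now go through the shared `_client` / `_async_client` instead of the module-level `ollama` functions. A hedged usage sketch (model name, host, and header are placeholders; a reachable Ollama server with the model pulled is assumed):

from langchain_ollama import OllamaEmbeddings

embeddings = OllamaEmbeddings(
    model="nomic-embed-text",  # placeholder embedding model
    base_url="http://my-ollama-host:11434",
    client_kwargs={"headers": {"X-Api-Key": "my-key"}},
)

doc_vectors = embeddings.embed_documents(["hello world", "goodbye world"])
query_vector = embeddings.embed_query("hello")
print(len(doc_vectors), len(query_vector))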
@@ -12,14 +12,14 @@ from typing import (
     Union,
 )
 
-import ollama
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import BaseLLM
 from langchain_core.outputs import GenerationChunk, LLMResult
-from ollama import AsyncClient, Options
+from langchain_core.pydantic_v1 import Field, root_validator
+from ollama import AsyncClient, Client, Options
 
 
 class OllamaLLM(BaseLLM):
@@ -107,6 +107,24 @@ class OllamaLLM(BaseLLM):
     keep_alive: Optional[Union[int, str]] = None
     """How long the model will stay loaded into memory."""
 
+    base_url: Optional[str] = None
+    """Base url the model is hosted under."""
+
+    client_kwargs: Optional[dict] = {}
+    """Additional kwargs to pass to the httpx Client.
+    For a full list of the params, see [this link](https://pydoc.dev/httpx/latest/httpx.Client.html)
+    """
+
+    _client: Client = Field(default=None)
+    """
+    The client to use for making requests.
+    """
+
+    _async_client: AsyncClient = Field(default=None)
+    """
+    The async client to use for making requests.
+    """
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling Ollama."""
@@ -137,6 +155,15 @@ class OllamaLLM(BaseLLM):
         """Return type of LLM."""
         return "ollama-llm"
 
+    @root_validator(pre=False, skip_on_failure=True)
+    def _set_clients(cls, values: dict) -> dict:
+        """Set clients to use for ollama."""
+        values["_client"] = Client(host=values["base_url"], **values["client_kwargs"])
+        values["_async_client"] = AsyncClient(
+            host=values["base_url"], **values["client_kwargs"]
+        )
+        return values
+
     async def _acreate_generate_stream(
         self,
         prompt: str,
@@ -155,7 +182,7 @@ class OllamaLLM(BaseLLM):
                 params[key] = kwargs[key]
 
         params["options"]["stop"] = stop
-        async for part in await AsyncClient().generate(
+        async for part in await self._async_client.generate(
            model=params["model"],
            prompt=prompt,
            stream=True,
@@ -183,7 +210,7 @@ class OllamaLLM(BaseLLM):
                 params[key] = kwargs[key]
 
         params["options"]["stop"] = stop
-        yield from ollama.generate(
+        yield from self._client.generate(
            model=params["model"],
            prompt=prompt,
            stream=True,
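`OllamaLLM` gets the same treatment: the module-level `ollama.generate` call and the ad-hoc `AsyncClient()` instance are replaced by the `_client` / `_async_client` pair built once in `_set_clients`. A hedged sketch of async usage (model name, host, and credentials are placeholders; a running Ollama server is assumed):

import asyncio

from langchain_ollama import OllamaLLM


async def main() -> None:
    llm = OllamaLLM(
        model="llama3",  # placeholder model name
        base_url="http://my-ollama-host:11434",
        client_kwargs={"auth": ("user", "password")},  # forwarded to httpx
    )
    # Streaming goes through the async client created by the root validator.
    async for chunk in llm.astream("Tell me a short joke"):
        print(chunk, end="", flush=True)


asyncio.run(main())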