mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
langchain_ollama: Support keep_alive in embeddings (#30251)
- Description: Adds support for keep_alive in Ollama Embeddings; see https://github.com/ollama/ollama/issues/6401. Builds on top of https://github.com/langchain-ai/langchain/pull/29296. I have a use case where I want to keep the embeddings model loaded on CPU indefinitely. - Dependencies: no new dependencies are introduced. - Issue: no issue has been created yet.
This commit is contained in:
parent
65a8f30729
commit
ac22cde130
@ -164,6 +164,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
"""The number of GPUs to use. On macOS it defaults to 1 to
|
||||
enable metal support, 0 to disable."""
|
||||
|
||||
keep_alive: Optional[int] = None
"""Controls how long the model will stay loaded into memory
following the request (default: 5m).
"""
|
||||
|
||||
num_thread: Optional[int] = None
|
||||
"""Sets the number of threads to use during computation.
|
||||
By default, Ollama will detect this for optimal performance.
|
||||
@ -235,7 +240,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed search docs.

    Sends the texts to the Ollama client along with the configured
    model options and ``keep_alive`` setting, and returns one embedding
    vector per input text.
    """
    response = self._client.embed(
        self.model,
        texts,
        options=self._default_params,
        keep_alive=self.keep_alive,
    )
    return response["embeddings"]
|
||||
|
||||
@ -245,9 +250,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed search docs (async).

    Mirrors :meth:`embed_documents`: forwards the configured model
    options and ``keep_alive`` setting to the async Ollama client and
    returns one embedding vector per input text.
    """
    # Pass options=self._default_params so the async path honors the same
    # model options (num_gpu, num_thread, ...) as the sync embed_documents;
    # previously the async call silently dropped them.
    embedded_docs = (
        await self._async_client.embed(
            self.model,
            texts,
            options=self._default_params,
            keep_alive=self.keep_alive,
        )
    )["embeddings"]
    return embedded_docs
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
|
@ -5,4 +5,4 @@ from langchain_ollama.embeddings import OllamaEmbeddings
|
||||
|
||||
def test_initialization() -> None:
    """Ensure the embedding model can be constructed with keep_alive set."""
    _ = OllamaEmbeddings(model="llama3", keep_alive=1)
|
||||
|
Loading…
Reference in New Issue
Block a user