diff --git a/libs/partners/ollama/langchain_ollama/embeddings.py b/libs/partners/ollama/langchain_ollama/embeddings.py
index 6105a3ba0da..35b2e095ac6 100644
--- a/libs/partners/ollama/langchain_ollama/embeddings.py
+++ b/libs/partners/ollama/langchain_ollama/embeddings.py
@@ -164,6 +164,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     """The number of GPUs to use. On macOS it defaults to 1 to
     enable metal support, 0 to disable."""
 
+    keep_alive: Optional[int] = None
+    """Controls how long the model will stay loaded into memory
+    following the request (default: 5m).
+    """
+
     num_thread: Optional[int] = None
     """Sets the number of threads to use during computation.
     By default, Ollama will detect this for optimal performance.
@@ -235,7 +240,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
         embedded_docs = self._client.embed(
-            self.model, texts, options=self._default_params
+            self.model, texts, options=self._default_params, keep_alive=self.keep_alive
         )["embeddings"]
         return embedded_docs
 
@@ -245,9 +250,11 @@ class OllamaEmbeddings(BaseModel, Embeddings):
 
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
-        embedded_docs = (await self._async_client.embed(self.model, texts))[
-            "embeddings"
-        ]
+        embedded_docs = (
+            await self._async_client.embed(
+                self.model, texts, keep_alive=self.keep_alive
+            )
+        )["embeddings"]
         return embedded_docs
 
     async def aembed_query(self, text: str) -> List[float]:
diff --git a/libs/partners/ollama/tests/unit_tests/test_embeddings.py b/libs/partners/ollama/tests/unit_tests/test_embeddings.py
index b1db475ce2a..82b5679c35a 100644
--- a/libs/partners/ollama/tests/unit_tests/test_embeddings.py
+++ b/libs/partners/ollama/tests/unit_tests/test_embeddings.py
@@ -5,4 +5,4 @@ from langchain_ollama.embeddings import OllamaEmbeddings
 
 def test_initialization() -> None:
     """Test embedding model initialization."""
-    OllamaEmbeddings(model="llama3")
+    OllamaEmbeddings(model="llama3", keep_alive=1)
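
For reviewers, a minimal usage sketch of the new parameter. It assumes a running local Ollama server with the `llama3` model already pulled; per Ollama's documented semantics, an integer `keep_alive` is a duration in seconds, `0` unloads the model immediately after the request, and a negative value keeps it loaded indefinitely:

```python
from langchain_ollama import OllamaEmbeddings

# Keep the model resident for 60 seconds after each embedding request,
# instead of Ollama's default of 5 minutes.
embeddings = OllamaEmbeddings(model="llama3", keep_alive=60)

doc_vectors = embeddings.embed_documents(["hello world", "goodbye world"])
query_vector = embeddings.embed_query("hello world")
```

Note that the field is typed `Optional[int]` in this patch, so duration strings such as `"10m"` (which the Ollama API itself accepts) would fail Pydantic validation here; only integer values are supported.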