langchain/libs/community/langchain_community/embeddings/infinity_local.py
Eugene Yurtsev bf5193bb99
community[patch]: Upgrade pydantic extra (#25185)
Upgrade to using a string literal to specify `extra`, which is the
recommended approach in pydantic 2.

This also works correctly in pydantic v1.

```python
from pydantic.v1 import BaseModel

class Foo(BaseModel, extra="forbid"):
    x: int

Foo(x=5, y=1)
```

And the same with a nested `Config` class:


```python
from pydantic.v1 import BaseModel

class Foo(BaseModel):
    x: int

    class Config:
      extra = "forbid"

Foo(x=5, y=1)
```


## Enum -> literal using grit pattern:

```
engine marzano(0.1)
language python
or {
    `extra=Extra.allow` => `extra="allow"`,
    `extra=Extra.forbid` => `extra="forbid"`,
    `extra=Extra.ignore` => `extra="ignore"`
}
```

Re-sorted the attributes in `Config` and removed its docstring, in case we
need to go back and forth between pydantic v1 and v2 during
the 0.3 release. (This will reduce merge conflicts.)


## Sort attributes in Config:

```
engine marzano(0.1)
language python


function sort($values) js {
    return $values.text.split(',').sort().join("\n");
}


class_definition($name, $body) as $C where {
    $name <: `Config`,
    $body <: block($statements),
    $values = [],
    $statements <: some bubble($values) assignment() as $A where {
        $values += $A
    },
    $body => sort($values),
}

```
2024-08-08 17:20:39 +00:00

157 lines
5.0 KiB
Python

"""written under MIT Licence, Michael Feil 2023."""
import asyncio
from logging import getLogger
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
# Public API of this module: only the embeddings class is exported.
__all__ = ["InfinityEmbeddingsLocal"]
# Module-level logger, named after this module.
logger = getLogger(__name__)
class InfinityEmbeddingsLocal(BaseModel, Embeddings):
    """Optimized Infinity embedding models.

    https://github.com/michaelfeil/infinity
    This class deploys a local Infinity instance to embed text.

    The class requires async usage.

    Infinity is a class to interact with Embedding Models on https://github.com/michaelfeil/infinity


    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddingsLocal
            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
    """

    model: str
    "Underlying model id from huggingface, e.g. BAAI/bge-small-en-v1.5"

    revision: Optional[str] = None
    "Model version, the commit hash from huggingface"

    batch_size: int = 32
    "Internal batch size for inference, e.g. 32"

    device: str = "auto"
    "Device to use for inference, e.g. 'cpu' or 'cuda', or 'mps'"

    backend: str = "torch"
    "Backend for inference, e.g. 'torch' (recommended for ROCm/Nvidia)"
    " or 'optimum' for onnx/tensorrt"

    model_warmup: bool = True
    "Warmup the model with the max batch size."

    engine: Any = None  #: :meta private:
    """Infinity's AsyncEmbeddingEngine."""

    class Config:
        # Reject constructor keyword arguments that are not declared fields.
        extra = "forbid"

    # skip_on_failure=True: if a field failed validation (e.g. ``model`` is
    # missing), ``values`` lacks that key; skipping lets pydantic report the
    # field error instead of this validator raising a KeyError.
    @root_validator(allow_reuse=True, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that ``infinity_emb`` is installed and build the engine.

        Args:
            values: The validated field values of the model.

        Returns:
            ``values`` with ``"engine"`` set to a new ``AsyncEmbeddingEngine``
            configured from the other fields.

        Raises:
            ImportError: If the ``infinity_emb`` package cannot be imported.
        """
        try:
            from infinity_emb import AsyncEmbeddingEngine  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install the "
                "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` "
                "package to use the InfinityEmbeddingsLocal."
            ) from e
        logger.debug("Using InfinityEmbeddingsLocal with kwargs %s", values)

        values["engine"] = AsyncEmbeddingEngine(
            model_name_or_path=values["model"],
            device=values["device"],
            revision=values["revision"],
            model_warmup=values["model_warmup"],
            batch_size=values["batch_size"],
            engine=values["backend"],
        )
        return values

    async def __aenter__(self) -> "InfinityEmbeddingsLocal":
        """Start the background worker and return this instance.

        Recommended usage is with the ``async with`` statement:

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])

        Returns:
            This instance, so that ``async with ... as embedder`` binds a
            usable object.
        """
        await self.engine.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Stop the background worker.

        Required to free references to the pytorch model.

        Args:
            *args: Exception details forwarded unchanged to the engine's
                ``__aexit__``.
        """
        await self.engine.__aexit__(*args)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        If the engine is not running, it is started for the duration of this
        call and stopped again afterwards; a warning is logged in that case.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        if self.engine.running:
            embeddings, _ = await self.engine.embed(texts)
            return embeddings

        logger.warning(
            "Starting Infinity engine on the fly. This is not recommended. "
            "Please start the engine before using it."
        )
        async with self:
            # spawning threadpool for multithreaded encode, tokenization
            embeddings, _ = await self.engine.embed(texts)
        # stopping threadpool on exit
        logger.warning("Stopped infinity engine after usage.")
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Synchronous wrapper around ``aembed_documents``.

        This class is async only; this method logs a warning and runs the
        async version via ``asyncio.run``, so it cannot be called while an
        event loop is already running in the current thread.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        logger.warning(
            "This method is async only. "
            "Please use the async version `await aembed_documents`."
        )
        return asyncio.run(self.aembed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """Synchronous wrapper around ``aembed_query``.

        This class is async only; this method logs a warning and runs the
        async version via ``asyncio.run``, so it cannot be called while an
        event loop is already running in the current thread.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        logger.warning(
            "This method is async only."
            " Please use the async version `await aembed_query`."
        )
        return asyncio.run(self.aembed_query(text))