Compare commits

...

1 Commits

Author SHA1 Message Date
Eugene Yurtsev
804670282e twelve labs partial 2024-08-13 16:47:18 -04:00
2 changed files with 145 additions and 2 deletions

View File

@@ -1,9 +1,13 @@
"""**Embeddings** interface."""
import abc
from abc import ABC, abstractmethod
from typing import List
from typing import List, Union, Literal, Optional, Any
from typing import TypedDict
from langchain_core.runnables.config import run_in_executor
from langchain_core.documents.base import BaseMedia, Document, Blob
from langchain_core.runnables import RunnableSerializable
from langchain_core.runnables.config import run_in_executor, RunnableConfig
class Embeddings(ABC):
@@ -77,3 +81,96 @@ class Embeddings(ABC):
Embedding.
"""
return await run_in_executor(None, self.embed_query, text)
# An input into the embedding model.
# The input can be a document, a media object, or a string.
# Whether it's supported depends on the specific embedding model.
EmbeddingInput = Union[BaseMedia, Document, str]
class BaseEmbedding(TypedDict):
"""Base embedding."""
vector: List[float]
scope: Literal["entire", "slice"]
class TextEmbedding(BaseEmbedding):
"""Text embedding."""
type: Literal["text"]
start: int
limit: int
class ImageEmbedding(TypedDict):
"""Designed for embedding 2-D images."""
type: Literal["image"]
class AudioEmbedding(TypedDict):
"""Audio embedding."""
type: Literal["audio"]
start: float
limit: float
class VideoEmbedding(TypedDict):
"""Video embedding"""
type: Literal["video"]
start: float
limit: float
Embedding = Union[TextEmbedding, ImageEmbedding, AudioEmbedding, VideoEmbedding]
class EmbeddingOutput(TypedDict):
"""The response of an embedding model."""
embeddings: List[Embedding]
def _standardize_embedding_input(input_: EmbeddingInput) -> Blob:
"""Convert an embedding input into a standardized Blob."""
if isinstance(input_, Blob):
return input_
elif isinstance(input_, Document):
return Blob(
id=input_.id,
metadata=input_.metadata,
data=input_.page_content,
mimetype_type="text/plain",
)
elif isinstance(input_, str): # This is a string of text to embed
return Blob(
metadata={},
data=input_,
mimetype_type="text/plain",
)
else:
raise NotImplementedError()
class EmbeddingModel(RunnableSerializable[EmbeddingInput, EmbeddingOutput]):
"""An embedding model."""
@abc.abstractmethod
def _embed(
self, input_: Blob, config: Optional[RunnableConfig], **kwargs: Any
) -> EmbeddingOutput:
"""Embed input."""
def invoke(
self,
input_: EmbeddingInput,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> EmbeddingOutput:
"""Embed input."""
blob = _standardize_embedding_input(input_)
return self._call_with_config(self._embed, blob, config=config, **kwargs)

View File

@@ -0,0 +1,46 @@
import requests
from typing import Any, Optional
from langchain_core.documents.base import Blob
from langchain_core.embeddings.embeddings import EmbeddingModel, EmbeddingOutput
from langchain_core.pydantic_v1 import SecretStr, Field, root_validator
from langchain_core.runnables import RunnableConfig
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from twelvelabs import TwelveLabs
class TwelveLabEmbedding(EmbeddingModel):
"""Twelve Lab embedding model."""
api_key: SecretStr
engine_name: str
sleep_ms: int
"""Number of milliseconds to sleep between requests."""
client: TwelveLabs = Field(default=None, exclude=True) #: :meta private:
@root_validator(pre=False, skip_on_failure=True)
def post_init(self, values):
from twelvelabs import TwelveLabs
values["client"] = TwelveLabs(
api_key=values["api_key"].get_secret_value(),
)
return values
def _embed(
self, input_: Blob, config: Optional[RunnableConfig], **kwargs: Any
) -> EmbeddingOutput:
"""Embed input."""
if input_.mimetype == "text/plain":
text = input_.as_string()
# "Marengo-retrieval-2.6"
text = input_.as_string()
response = self.client.embed.create(
engine_name=self.engine_name,
text=text,
)
return EmbeddingOutput(embeddings=[])