mistralai[patch]: 16k token batching logic embed (#17136)

This commit is contained in:
Erick Friis
2024-02-06 15:59:08 -08:00
committed by GitHub
parent 863f96b2e0
commit f881a3330c
4 changed files with 472 additions and 213 deletions

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import Dict, List, Optional
from typing import Dict, Iterable, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
@@ -16,9 +17,12 @@ from mistralai.constants import (
ENDPOINT as DEFAULT_MISTRAL_ENDPOINT,
)
from mistralai.exceptions import MistralException
from tokenizers import Tokenizer # type: ignore
logger = logging.getLogger(__name__)
MAX_TOKENS = 16_000
class MistralAIEmbeddings(BaseModel, Embeddings):
"""MistralAI embedding models.
@@ -43,6 +47,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
max_retries: int = 5
timeout: int = 120
max_concurrent_requests: int = 64
tokenizer: Tokenizer = Field(default=None)
model: str = "mistral-embed"
@@ -72,8 +77,33 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
timeout=values["timeout"],
max_concurrent_requests=values["max_concurrent_requests"],
)
if values["tokenizer"] is None:
values["tokenizer"] = Tokenizer.from_pretrained(
"mistralai/Mixtral-8x7B-v0.1"
)
return values
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
"""Split a list of texts into batches of less than 16k tokens
for Mistral API."""
batch: List[str] = []
batch_tokens = 0
text_token_lengths = [
len(encoded) for encoded in self.tokenizer.encode_batch(texts)
]
for text, text_tokens in zip(texts, text_token_lengths):
if batch_tokens + text_tokens > MAX_TOKENS:
yield batch
batch = [text]
batch_tokens = text_tokens
else:
batch.append(text)
batch_tokens += text_tokens
if batch:
yield batch
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of document texts.
@@ -84,13 +114,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
List of embeddings, one for each text.
"""
try:
embeddings_batch_response = self.client.embeddings(
model=self.model,
input=texts,
batch_responses = (
self.client.embeddings(
model=self.model,
input=batch,
)
for batch in self._get_batches(texts)
)
return [
list(map(float, embedding_obj.embedding))
for embedding_obj in embeddings_batch_response.data
for response in batch_responses
for embedding_obj in response.data
]
except MistralException as e:
logger.error(f"An error occurred with MistralAI: {e}")
@@ -106,13 +140,19 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
List of embeddings, one for each text.
"""
try:
embeddings_batch_response = await self.async_client.embeddings(
model=self.model,
input=texts,
batch_responses = await asyncio.gather(
*[
self.async_client.embeddings(
model=self.model,
input=batch,
)
for batch in self._get_batches(texts)
]
)
return [
list(map(float, embedding_obj.embedding))
for embedding_obj in embeddings_batch_response.data
for response in batch_responses
for embedding_obj in response.data
]
except MistralException as e:
logger.error(f"An error occurred with MistralAI: {e}")