mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
mistralai[patch]: 16k token batching logic embed (#17136)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import (
|
||||
@@ -16,9 +17,12 @@ from mistralai.constants import (
|
||||
ENDPOINT as DEFAULT_MISTRAL_ENDPOINT,
|
||||
)
|
||||
from mistralai.exceptions import MistralException
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_TOKENS = 16_000
|
||||
|
||||
|
||||
class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
"""MistralAI embedding models.
|
||||
@@ -43,6 +47,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
max_retries: int = 5
|
||||
timeout: int = 120
|
||||
max_concurrent_requests: int = 64
|
||||
tokenizer: Tokenizer = Field(default=None)
|
||||
|
||||
model: str = "mistral-embed"
|
||||
|
||||
@@ -72,8 +77,33 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
timeout=values["timeout"],
|
||||
max_concurrent_requests=values["max_concurrent_requests"],
|
||||
)
|
||||
if values["tokenizer"] is None:
|
||||
values["tokenizer"] = Tokenizer.from_pretrained(
|
||||
"mistralai/Mixtral-8x7B-v0.1"
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
|
||||
"""Split a list of texts into batches of less than 16k tokens
|
||||
for Mistral API."""
|
||||
batch: List[str] = []
|
||||
batch_tokens = 0
|
||||
|
||||
text_token_lengths = [
|
||||
len(encoded) for encoded in self.tokenizer.encode_batch(texts)
|
||||
]
|
||||
|
||||
for text, text_tokens in zip(texts, text_token_lengths):
|
||||
if batch_tokens + text_tokens > MAX_TOKENS:
|
||||
yield batch
|
||||
batch = [text]
|
||||
batch_tokens = text_tokens
|
||||
else:
|
||||
batch.append(text)
|
||||
batch_tokens += text_tokens
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of document texts.
|
||||
|
||||
@@ -84,13 +114,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
try:
|
||||
embeddings_batch_response = self.client.embeddings(
|
||||
model=self.model,
|
||||
input=texts,
|
||||
batch_responses = (
|
||||
self.client.embeddings(
|
||||
model=self.model,
|
||||
input=batch,
|
||||
)
|
||||
for batch in self._get_batches(texts)
|
||||
)
|
||||
return [
|
||||
list(map(float, embedding_obj.embedding))
|
||||
for embedding_obj in embeddings_batch_response.data
|
||||
for response in batch_responses
|
||||
for embedding_obj in response.data
|
||||
]
|
||||
except MistralException as e:
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
@@ -106,13 +140,19 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
try:
|
||||
embeddings_batch_response = await self.async_client.embeddings(
|
||||
model=self.model,
|
||||
input=texts,
|
||||
batch_responses = await asyncio.gather(
|
||||
*[
|
||||
self.async_client.embeddings(
|
||||
model=self.model,
|
||||
input=batch,
|
||||
)
|
||||
for batch in self._get_batches(texts)
|
||||
]
|
||||
)
|
||||
return [
|
||||
list(map(float, embedding_obj.embedding))
|
||||
for embedding_obj in embeddings_batch_response.data
|
||||
for response in batch_responses
|
||||
for embedding_obj in response.data
|
||||
]
|
||||
except MistralException as e:
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
|
||||
Reference in New Issue
Block a user