mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
mistral: catch GatedRepoError, release 0.1.3 (#20802)
https://github.com/langchain-ai/langchain/issues/20618 --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
import httpx
|
||||
@@ -19,6 +20,13 @@ logger = logging.getLogger(__name__)
|
||||
MAX_TOKENS = 16_000
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
|
||||
|
||||
def encode_batch(self, texts: List[str]) -> List[List[str]]:
|
||||
return [list(text) for text in texts]
|
||||
|
||||
|
||||
class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
"""MistralAI embedding models.
|
||||
|
||||
@@ -83,9 +91,18 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
timeout=values["timeout"],
|
||||
)
|
||||
if values["tokenizer"] is None:
|
||||
values["tokenizer"] = Tokenizer.from_pretrained(
|
||||
"mistralai/Mixtral-8x7B-v0.1"
|
||||
)
|
||||
try:
|
||||
values["tokenizer"] = Tokenizer.from_pretrained(
|
||||
"mistralai/Mixtral-8x7B-v0.1"
|
||||
)
|
||||
except IOError: # huggingface_hub GatedRepoError
|
||||
warnings.warn(
|
||||
"Could not download mistral tokenizer from Huggingface for "
|
||||
"calculating batch sizes. Set a Huggingface token via the "
|
||||
"HF_TOKEN environment variable to download the real tokenizer. "
|
||||
"Falling back to a dummy tokenizer that uses `len()`."
|
||||
)
|
||||
values["tokenizer"] = DummyTokenizer()
|
||||
return values
|
||||
|
||||
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
|
||||
@@ -100,7 +117,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
for text, text_tokens in zip(texts, text_token_lengths):
|
||||
if batch_tokens + text_tokens > MAX_TOKENS:
|
||||
yield batch
|
||||
if len(batch) > 0:
|
||||
# edge case where first batch exceeds max tokens
|
||||
# should not yield an empty batch.
|
||||
yield batch
|
||||
batch = [text]
|
||||
batch_tokens = text_tokens
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user