mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 12:31:49 +00:00
Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com> Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com> Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: ccurme <chester.curme@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com> Co-authored-by: ZhangShenao <15201440436@163.com> Co-authored-by: Friso H. Kingma <fhkingma@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Morgante Pell <morgantep@google.com>
96 lines
3.4 KiB
Python
96 lines
3.4 KiB
Python
from typing import Any, Dict, List
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.utils.pydantic import get_fields
|
|
from pydantic import BaseModel, ConfigDict, model_validator
|
|
|
|
|
|
class OpenCLIPEmbeddings(BaseModel, Embeddings):
|
|
"""OpenCLIP Embeddings model."""
|
|
|
|
model: Any
|
|
preprocess: Any
|
|
tokenizer: Any
|
|
# Select model: https://github.com/mlfoundations/open_clip
|
|
model_name: str = "ViT-H-14"
|
|
checkpoint: str = "laion2b_s32b_b79k"
|
|
|
|
model_config = ConfigDict(protected_namespaces=())
|
|
|
|
@model_validator(mode="before")
|
|
@classmethod
|
|
def validate_environment(cls, values: Dict) -> Any:
|
|
"""Validate that open_clip and torch libraries are installed."""
|
|
try:
|
|
import open_clip
|
|
|
|
# Fall back to class defaults if not provided
|
|
model_name = values.get("model_name", get_fields(cls)["model_name"].default)
|
|
checkpoint = values.get("checkpoint", get_fields(cls)["checkpoint"].default)
|
|
|
|
# Load model
|
|
model, _, preprocess = open_clip.create_model_and_transforms(
|
|
model_name=model_name, pretrained=checkpoint
|
|
)
|
|
tokenizer = open_clip.get_tokenizer(model_name)
|
|
values["model"] = model
|
|
values["preprocess"] = preprocess
|
|
values["tokenizer"] = tokenizer
|
|
|
|
except ImportError:
|
|
raise ImportError(
|
|
"Please ensure both open_clip and torch libraries are installed. "
|
|
"pip install open_clip_torch torch"
|
|
)
|
|
return values
|
|
|
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
text_features = []
|
|
for text in texts:
|
|
# Tokenize the text
|
|
tokenized_text = self.tokenizer(text)
|
|
|
|
# Encode the text to get the embeddings
|
|
embeddings_tensor = self.model.encode_text(tokenized_text)
|
|
|
|
# Normalize the embeddings
|
|
norm = embeddings_tensor.norm(p=2, dim=1, keepdim=True)
|
|
normalized_embeddings_tensor = embeddings_tensor.div(norm)
|
|
|
|
# Convert normalized tensor to list and add to the text_features list
|
|
embeddings_list = normalized_embeddings_tensor.squeeze(0).tolist()
|
|
text_features.append(embeddings_list)
|
|
|
|
return text_features
|
|
|
|
def embed_query(self, text: str) -> List[float]:
|
|
return self.embed_documents([text])[0]
|
|
|
|
def embed_image(self, uris: List[str]) -> List[List[float]]:
|
|
try:
|
|
from PIL import Image as _PILImage
|
|
except ImportError:
|
|
raise ImportError("Please install the PIL library: pip install pillow")
|
|
|
|
# Open images directly as PIL images
|
|
pil_images = [_PILImage.open(uri) for uri in uris]
|
|
|
|
image_features = []
|
|
for pil_image in pil_images:
|
|
# Preprocess the image for the model
|
|
preprocessed_image = self.preprocess(pil_image).unsqueeze(0)
|
|
|
|
# Encode the image to get the embeddings
|
|
embeddings_tensor = self.model.encode_image(preprocessed_image)
|
|
|
|
# Normalize the embeddings tensor
|
|
norm = embeddings_tensor.norm(p=2, dim=1, keepdim=True)
|
|
normalized_embeddings_tensor = embeddings_tensor.div(norm)
|
|
|
|
# Convert tensor to list and add to the image_features list
|
|
embeddings_list = normalized_embeddings_tensor.squeeze(0).tolist()
|
|
|
|
image_features.append(embeddings_list)
|
|
|
|
return image_features
|