mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 21:47:12 +00:00
johnsnowlabs embeddings support (#11271)
- **Description:** Introducing the [JohnSnowLabsEmbeddings](https://www.johnsnowlabs.com/)
- **Dependencies:** johnsnowlabs
- **Tag maintainer:** @C-K-Loan
- **Twitter handle:** https://twitter.com/JohnSnowLabs https://twitter.com/ChristianKasimL

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
committed by
GitHub
parent
c08b622b2d
commit
a35445c65f
@@ -43,6 +43,7 @@ from langchain.embeddings.huggingface import (
|
||||
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
|
||||
from langchain.embeddings.javelin_ai_gateway import JavelinAIGatewayEmbeddings
|
||||
from langchain.embeddings.jina import JinaEmbeddings
|
||||
from langchain.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
|
||||
from langchain.embeddings.llamacpp import LlamaCppEmbeddings
|
||||
from langchain.embeddings.localai import LocalAIEmbeddings
|
||||
from langchain.embeddings.minimax import MiniMaxEmbeddings
|
||||
@@ -113,6 +114,7 @@ __all__ = [
|
||||
"JavelinAIGatewayEmbeddings",
|
||||
"OllamaEmbeddings",
|
||||
"QianfanEmbeddingsEndpoint",
|
||||
"JohnSnowLabsEmbeddings",
|
||||
]
|
||||
|
||||
|
||||
|
92
libs/langchain/langchain/embeddings/johnsnowlabs.py
Normal file
92
libs/langchain/langchain/embeddings/johnsnowlabs.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, Extra
|
||||
|
||||
|
||||
class JohnSnowLabsEmbeddings(BaseModel, Embeddings):
    """JohnSnowLabs embedding models.

    To use, you should have the ``johnsnowlabs`` python package installed.

    Example:
        .. code-block:: python

            from langchain.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings

            embedding = JohnSnowLabsEmbeddings(model='embed_sentence.bert')
            output = embedding.embed_query("foo bar")
    """

    # Defaults to a model reference string; ``__init__`` overwrites it with the
    # object that ``predict`` is later called on.
    model: Any = "embed_sentence.bert"

    def __init__(
        self,
        model: Any = "embed_sentence.bert",
        hardware_target: str = "cpu",
        **kwargs: Any,
    ):
        """Initialize the johnsnowlabs model.

        Args:
            model: A string handed to ``nlp.load``, an ``NLUPipeline`` instance
                (stored as-is), or any other object, which is converted with
                ``nlp.to_nlu_pipe``.
            hardware_target: Forwarded to ``nlp.start``.
            **kwargs: Forwarded to the base-class initializer.

        Raises:
            ImportError: If the ``johnsnowlabs`` package cannot be imported.
            Exception: If starting the Spark session or loading the model fails;
                the underlying error is chained as the cause.
        """
        super().__init__(**kwargs)

        # 1) Check imports
        try:
            from johnsnowlabs import nlp
            from nlu.pipe.pipeline import NLUPipeline
        except ImportError as exc:
            raise ImportError(
                "Could not import johnsnowlabs python package. "
                "Please install it with `pip install johnsnowlabs`."
            ) from exc

        # 2) Start a Spark Session
        try:
            # Point both PySpark interpreter variables at the running Python.
            os.environ["PYSPARK_PYTHON"] = sys.executable
            os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
            nlp.start(hardware_target=hardware_target)
        except Exception as exc:
            raise Exception("Failure starting Spark Session") from exc

        # 3) Load the model
        try:
            if isinstance(model, str):
                self.model = nlp.load(model)
            elif isinstance(model, NLUPipeline):
                self.model = model
            else:
                self.model = nlp.to_nlu_pipe(model)
        except Exception as exc:
            raise Exception("Failure loading model") from exc

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a JohnSnowLabs transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. An empty input yields an
            empty list without invoking the model.

        Raises:
            ValueError: If the prediction output has no column whose name
                contains ``"embedding"``.
        """
        if not texts:
            # Nothing to embed, so skip the pipeline call entirely.
            return []

        df = self.model.predict(texts, output_level="document")
        embedding_columns = [c for c in df.columns if "embedding" in c]
        if not embedding_columns:
            # Without this guard the lookup below would be ``df[None]`` and
            # fail with an uninformative ``KeyError: None``.
            raise ValueError(
                "Model output does not contain an embedding column; "
                f"got columns: {list(df.columns)}"
            )
        # When several columns match, the last one is used.
        emb_col = embedding_columns[-1]
        return [vec.tolist() for vec in df[emb_col].tolist()]

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a JohnSnowLabs transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
|
@@ -0,0 +1,20 @@
|
||||
"""Test johnsnowlabs embeddings."""
|
||||
|
||||
from langchain.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
|
||||
|
||||
|
||||
def test_johnsnowlabs_embed_document() -> None:
    """Embed two documents with the default model and check the output shape."""
    texts = ["foo bar", "bar foo"]
    embedder = JohnSnowLabsEmbeddings()
    vectors = embedder.embed_documents(texts)
    # One vector per input document.
    assert len(vectors) == 2
    # Expected dimensionality of each vector.
    assert len(vectors[0]) == 128
|
||||
|
||||
|
||||
def test_johnsnowlabs_embed_query() -> None:
    """Embed a single query with the default model and check its dimension."""
    query = "foo bar"
    embedder = JohnSnowLabsEmbeddings()
    vector = embedder.embed_query(query)
    # Expected dimensionality of the query vector.
    assert len(vector) == 128
|
Reference in New Issue
Block a user