Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-17 15:35:14 +00:00
community[minor]: llamafile embeddings support (#17976)
* **Description:** Adds a `LlamafileEmbeddings` class implementation for generating embeddings using [llamafile](https://github.com/Mozilla-Ocho/llamafile)-based models. Includes related unit tests and a notebook showing example usage.
* **Issue:** N/A
* **Dependencies:** N/A
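For orientation before the diff, a minimal usage sketch (assuming a llamafile server has already been started locally on the default port, per the class docstring below):

```python
from langchain_community.embeddings import LlamafileEmbeddings

# Assumes a server was launched separately, e.g.:
#   ./path/to/model.llamafile --server --nobrowser --embedding
embedder = LlamafileEmbeddings()  # defaults to base_url="http://localhost:8080"

vector = embedder.embed_query("Hello, world!")
print(len(vector))  # dimensionality depends on the underlying model
```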
libs/community/langchain_community/embeddings/__init__.py
@@ -57,6 +57,7 @@ from langchain_community.embeddings.jina import JinaEmbeddings
 from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
 from langchain_community.embeddings.laser import LaserEmbeddings
 from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
+from langchain_community.embeddings.llamafile import LlamafileEmbeddings
 from langchain_community.embeddings.llm_rails import LLMRailsEmbeddings
 from langchain_community.embeddings.localai import LocalAIEmbeddings
 from langchain_community.embeddings.minimax import MiniMaxEmbeddings
@@ -112,6 +113,7 @@ __all__ = [
     "JinaEmbeddings",
     "LaserEmbeddings",
     "LlamaCppEmbeddings",
+    "LlamafileEmbeddings",
     "LLMRailsEmbeddings",
     "HuggingFaceHubEmbeddings",
     "MlflowEmbeddings",
libs/community/langchain_community/embeddings/llamafile.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import logging
from typing import List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel

logger = logging.getLogger(__name__)


class LlamafileEmbeddings(BaseModel, Embeddings):
    """Llamafile lets you distribute and run large language models with a
    single file.

    To get started, see: https://github.com/Mozilla-Ocho/llamafile

    To use this class, you will need to first:

    1. Download a llamafile.
    2. Make the downloaded file executable: `chmod +x path/to/model.llamafile`
    3. Start the llamafile in server mode with embeddings enabled:

        `./path/to/model.llamafile --server --nobrowser --embedding`

    Example:
        .. code-block:: python

            from langchain_community.embeddings import LlamafileEmbeddings

            embedder = LlamafileEmbeddings()
            doc_embeddings = embedder.embed_documents(
                [
                    "Alpha is the first letter of the Greek alphabet",
                    "Beta is the second letter of the Greek alphabet",
                ]
            )
            query_embedding = embedder.embed_query(
                "What is the second letter of the Greek alphabet"
            )
    """

    base_url: str = "http://localhost:8080"
    """Base url where the llamafile server is listening."""

    request_timeout: Optional[int] = None
    """Timeout for server requests"""

    def _embed(self, text: str) -> List[float]:
        try:
            response = requests.post(
                url=f"{self.base_url}/embedding",
                headers={
                    "Content-Type": "application/json",
                },
                json={
                    "content": text,
                },
                timeout=self.request_timeout,
            )
        except requests.exceptions.ConnectionError:
            raise requests.exceptions.ConnectionError(
                f"Could not connect to Llamafile server. Please make sure "
                f"that a server is running at {self.base_url}."
            )

        # Raise exception if we got a bad (non-200) response status code
        response.raise_for_status()

        contents = response.json()
        if "embedding" not in contents:
            raise KeyError(
                "Unexpected output from /embedding endpoint, output dict "
                "missing 'embedding' key."
            )

        embedding = contents["embedding"]

        # Sanity check the embedding vector:
        # Prior to llamafile v0.6.2, if the server was not started with the
        # `--embedding` option, the embedding endpoint would always return a
        # 0-vector. See issue:
        # https://github.com/Mozilla-Ocho/llamafile/issues/243
        # So here we raise an exception if the vector sums to exactly 0.
        if sum(embedding) == 0.0:
            raise ValueError(
                "Embedding sums to 0, did you start the llamafile server with "
                "the `--embedding` option enabled?"
            )

        return embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using a llamafile server running at `self.base_url`.
        llamafile server should be started in a separate process before invoking
        this method.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        doc_embeddings = []
        for text in texts:
            doc_embeddings.append(self._embed(text))
        return doc_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using a llamafile server running at `self.base_url`.
        llamafile server should be started in a separate process before invoking
        this method.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embed(text)
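Under the hood, `_embed` is a thin wrapper around the llamafile server's `/embedding` endpoint. A minimal sketch of the equivalent raw request, without the error handling above (assuming the default `http://localhost:8080`):

```python
import requests

# Mirrors the request LlamafileEmbeddings._embed sends:
response = requests.post(
    url="http://localhost:8080/embedding",
    headers={"Content-Type": "application/json"},
    json={"content": "Alpha is the first letter of the Greek alphabet"},
)
response.raise_for_status()
embedding = response.json()["embedding"]  # list of floats
```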
libs/community/langchain_community/llms/__init__.py
@@ -295,6 +295,12 @@ def _import_llamacpp() -> Type[BaseLLM]:
     return LlamaCpp


+def _import_llamafile() -> Type[BaseLLM]:
+    from langchain_community.llms.llamafile import Llamafile
+
+    return Llamafile
+
+
 def _import_manifest() -> Type[BaseLLM]:
     from langchain_community.llms.manifest import ManifestWrapper

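`_import_llamafile` follows the lazy-import convention used by the other `_import_*` helpers: the actual import only runs when the class is first requested. The dispatch code that calls these helpers is not part of this diff, so the following is only a sketch of the typical module-level `__getattr__` wiring, with the mapping assumed:

```python
from typing import Any

def __getattr__(name: str) -> Any:
    # Hypothetical dispatch, mirroring the _import_* convention above;
    # the real lookup table in llms/__init__.py is not shown in this diff.
    if name == "Llamafile":
        return _import_llamafile()
    raise AttributeError(f"Could not find: {name}")
```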
libs/community/tests/unit_tests/embeddings/test_imports.py
@@ -17,6 +17,7 @@ EXPECTED_ALL = [
     "JinaEmbeddings",
     "LaserEmbeddings",
     "LlamaCppEmbeddings",
+    "LlamafileEmbeddings",
     "LLMRailsEmbeddings",
     "HuggingFaceHubEmbeddings",
     "MlflowAIGatewayEmbeddings",
libs/community/tests/unit_tests/embeddings/test_llamafile.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import json

import numpy as np
import requests
from pytest import MonkeyPatch

from langchain_community.embeddings import LlamafileEmbeddings


def mock_response() -> requests.Response:
    contents = json.dumps({"embedding": np.random.randn(512).tolist()})
    response = requests.Response()
    response.status_code = 200
    response._content = str.encode(contents)
    return response


def test_embed_documents(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of the `embed_documents` method
    """
    embedder = LlamafileEmbeddings(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/embedding"
        assert headers == {
            "Content-Type": "application/json",
        }
        # request body should contain only the text under the 'content' key
        assert json == {"content": "Test text"}
        # no request_timeout was configured, so timeout should be None
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = embedder.embed_documents(["Test text", "Test text"])
    assert isinstance(out, list)
    assert len(out) == 2
    for vec in out:
        assert len(vec) == 512


def test_embed_query(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of the `embed_query` method
    """
    embedder = LlamafileEmbeddings(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/embedding"
        assert headers == {
            "Content-Type": "application/json",
        }
        # request body should contain only the text under the 'content' key
        assert json == {"content": "Test text"}
        # no request_timeout was configured, so timeout should be None
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = embedder.embed_query("Test text")
    assert isinstance(out, list)
    assert len(out) == 512
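Note that both tests cover only the happy path. A sketch of an additional, hypothetical test (not part of this commit) that would exercise the 0-vector sanity check in `_embed`:

```python
import pytest

# Reuses json, requests, MonkeyPatch and LlamafileEmbeddings
# from the imports at the top of test_llamafile.py.

def test_embed_query_zero_vector(monkeypatch: MonkeyPatch) -> None:
    """Hypothetical test: a 0-vector response should raise a ValueError."""
    embedder = LlamafileEmbeddings()
    contents = json.dumps({"embedding": [0.0] * 512})

    def mock_post(url, headers, json, timeout):  # type: ignore[no-untyped-def]
        response = requests.Response()
        response.status_code = 200
        response._content = str.encode(contents)
        return response

    monkeypatch.setattr(requests, "post", mock_post)
    with pytest.raises(ValueError, match="Embedding sums to 0"):
        embedder.embed_query("Test text")
```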