mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 22:56:05 +00:00
add embeddings using instructor large endpoint
This commit is contained in:
82
langchain/embeddings/octoai_embeddings.py
Normal file
82
langchain/embeddings/octoai_embeddings.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Module providing a wrapper around OctoAI Compute Service embedding models."""
|
||||
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from octoai import client
|
||||
|
||||
# Default value of ``OctoAIEmbeddings.embed_instruction``: the instruction
# sent to the endpoint alongside each text when embedding documents.
DEFAULT_EMBED_INSTRUCTION = "Represent this input: "
|
||||
# Default value of ``OctoAIEmbeddings.query_instruction``: the instruction
# to use for embedding a query.
DEFAULT_QUERY_INSTRUCTION = "Represent the question for retrieving similar documents: "
|
||||
|
||||
|
||||
class OctoAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around OctoAI Compute Service embedding models.

    The environment variable ``OCTOAI_API_TOKEN`` should be set with your API
    token, or it can be passed as a named parameter to the constructor. The
    endpoint URL is resolved the same way, from the ``endpoint_url`` argument
    or the ``ENDPOINT_URL`` environment variable.
    """

    endpoint_url: Optional[str] = Field(None, description="Endpoint URL to use.")
    model_kwargs: Optional[dict] = Field(
        None, description="Keyword arguments to pass to the model."
    )
    octoai_api_token: Optional[str] = Field(None, description="OCTOAI API Token")
    embed_instruction: str = Field(
        DEFAULT_EMBED_INSTRUCTION,
        description="Instruction to use for embedding documents.",
    )
    query_instruction: str = Field(
        DEFAULT_QUERY_INSTRUCTION,
        description="Instruction to use for embedding query.",
    )

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API token and endpoint URL from arguments or environment.

        Args:
            values: Field values supplied to the constructor.

        Returns:
            The same mapping with ``octoai_api_token`` and ``endpoint_url``
            filled in from ``OCTOAI_API_TOKEN`` / ``ENDPOINT_URL`` when they
            were not passed explicitly.
        """
        values["octoai_api_token"] = get_from_dict_or_env(
            values, "octoai_api_token", "OCTOAI_API_TOKEN"
        )
        values["endpoint_url"] = get_from_dict_or_env(
            values, "endpoint_url", "ENDPOINT_URL"
        )
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the identifying parameters."""
        return {
            "endpoint_url": self.endpoint_url,
            "model_kwargs": self.model_kwargs or {},
        }

    def _compute_embeddings(
        self, texts: List[str], instruction: str
    ) -> List[List[float]]:
        """Embed each text with the given instruction via the OctoAI endpoint.

        One inference request is issued per text.

        Args:
            texts: Texts to embed.
            instruction: Instruction sent with every text.

        Returns:
            One entry per input text, in input order, taken from the
            ``"embeddings"`` field of each response.

        Raises:
            ValueError: If the request fails or the response has no
                ``"embeddings"`` field.
        """
        embeddings = []
        octoai_client = client.Client(token=self.octoai_api_token)

        for text in texts:
            # Both fields are sent as the string form of a one-element list
            # (e.g. "['foo bar']").
            # NOTE(review): presumably the format the endpoint expects -- confirm.
            parameter_payload = {
                "sentence": str([text]),
                "instruction": str([instruction]),
                "parameters": self.model_kwargs or {},
            }

            try:
                resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
                embedding = resp_json["embeddings"]
            except Exception as e:
                raise ValueError(
                    f"Error raised by the inference endpoint: {e}"
                ) from e

            embeddings.append(embedding)

        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute document embeddings using an OctoAI instruct model.

        Newlines in each text are replaced with spaces before embedding.

        Args:
            texts: Documents to embed.

        Returns:
            One embedding per document, in input order.
        """
        texts = [text.replace("\n", " ") for text in texts]
        return self._compute_embeddings(texts, self.embed_instruction)

    def embed_query(self, text: str) -> List[float]:
        """Compute query embedding using an OctoAI instruct model.

        Newlines in the text are replaced with spaces before embedding.

        Args:
            text: Query text to embed.

        Returns:
            The embedding of the query as a single vector.
        """
        text = text.replace("\n", " ")
        # Queries use the query-specific instruction, and the single-item
        # batch is unwrapped so the result matches the List[float] contract.
        return self._compute_embeddings([text], self.query_instruction)[0]
|
||||
|
||||
@@ -21,14 +21,20 @@ class OctoAIEndpoint(LLM):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import OctoAIEndpoint
|
||||
endpoint_url = (
|
||||
"https://endpoint_name-account_id.octoai.cloud"
|
||||
)
|
||||
endpoint = OctoAIEndpoint(
|
||||
endpoint_url=endpoint_url,
|
||||
octoai_api_token="octoai-api-key"
|
||||
from langchain.llms.octoai_endpoint import OctoAIEndpoint
|
||||
OctoAIEndpoint(
|
||||
octoai_api_token="octoai-api-key",
|
||||
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
|
||||
model_kwargs={
|
||||
"max_new_tokens": 200,
|
||||
"temperature": 0.75,
|
||||
"top_p": 0.95,
|
||||
"repetition_penalty": 1,
|
||||
"seed": None,
|
||||
"stop": [],
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
endpoint_url: Optional[str] = None
|
||||
@@ -45,7 +51,7 @@ class OctoAIEndpoint(LLM):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
@root_validator(allow_reuse=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
octoai_api_token = get_from_dict_or_env(
|
||||
@@ -90,26 +96,23 @@ class OctoAIEndpoint(LLM):
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
# payload json
|
||||
# Prepare the payload JSON
|
||||
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
|
||||
|
||||
# HTTP headers for authorization
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.octoai_api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# send request using octaoai sdk
|
||||
try:
|
||||
# Initialize the OctoAI client
|
||||
octoai_client = client.Client(token=self.octoai_api_token)
|
||||
|
||||
# Send the request using the OctoAI client
|
||||
resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
|
||||
text = resp_json["generated_text"]
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}") from e
|
||||
# Handle any errors raised by the inference endpoint
|
||||
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|
||||
|
||||
if stop is not None:
|
||||
# stop tokens when making calls to octoai.
|
||||
# Apply stop tokens when making calls to OctoAI
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return text
|
||||
|
||||
23
tests/integration_tests/embeddings/test_octoai_embeddings.py
Normal file
23
tests/integration_tests/embeddings/test_octoai_embeddings.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Test octoai embeddings."""
|
||||
|
||||
from langchain.embeddings.octoai_embeddings import (
|
||||
OctoAIEmbeddings,
|
||||
)
|
||||
|
||||
|
||||
def test_octoai_embedding_documents() -> None:
    """Embed one document and check the shape of the returned batch."""
    texts = ["foo bar"]
    embedder = OctoAIEmbeddings()
    vectors = embedder.embed_documents(texts)
    # One vector per input document, each with 768 components.
    assert len(vectors) == 1
    assert len(vectors[0]) == 768
|
||||
|
||||
|
||||
def test_octoai_embedding_query() -> None:
    """Test octoai query embedding.

    ``embed_query`` is declared to return a single vector (``List[float]``),
    so the 768 components are expected on the result itself rather than on a
    nested one-element list.
    """
    document = "foo bar"
    embedding = OctoAIEmbeddings()
    output = embedding.embed_query(document)
    assert len(output) == 768
|
||||
@@ -10,12 +10,13 @@ from langchain.llms.octoai_endpoint import OctoAIEndpoint
|
||||
|
||||
from tests.integration_tests.llms.utils import assert_llm_equality
|
||||
|
||||
|
||||
def test_octoai_endpoint_text_generation() -> None:
|
||||
"""Test valid call to OctoAI text generation model."""
|
||||
llm = OctoAIEndpoint(
|
||||
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
|
||||
model_kwargs={
|
||||
"max_new_tokens": 512,
|
||||
"max_new_tokens": 200,
|
||||
"temperature": 0.75,
|
||||
"top_p": 0.95,
|
||||
"repetition_penalty": 1,
|
||||
@@ -32,8 +33,9 @@ def test_octoai_endpoint_text_generation() -> None:
|
||||
def test_octoai_endpoint_call_error() -> None:
|
||||
"""Test valid call to OctoAI that errors."""
|
||||
llm = OctoAIEndpoint(
|
||||
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
|
||||
model_kwargs={"max_new_tokens": -1})
|
||||
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
|
||||
model_kwargs={"max_new_tokens": -1},
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
llm("Which state is Los Angeles in?")
|
||||
|
||||
@@ -43,7 +45,7 @@ def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
|
||||
llm = OctoAIEndpoint(
|
||||
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
|
||||
model_kwargs={
|
||||
"max_new_tokens": 512,
|
||||
"max_new_tokens": 200,
|
||||
"temperature": 0.75,
|
||||
"top_p": 0.95,
|
||||
"repetition_penalty": 1,
|
||||
|
||||
Reference in New Issue
Block a user