add embeddings using instructor large endpoint

This commit is contained in:
OctoML-Bassem
2023-06-11 22:58:34 -07:00
parent 1e774d303f
commit 31eee4e6ff
4 changed files with 132 additions and 22 deletions

View File

@@ -0,0 +1,82 @@
"""Module providing a wrapper around OctoAI Compute Service embedding models."""
from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from octoai import client
DEFAULT_EMBED_INSTRUCTION = "Represent this input: "
DEFAULT_QUERY_INSTRUCTION = "Represent the question for retrieving similar documents: "
class OctoAIEmbeddings(BaseModel, Embeddings):
"""
Wrapper around OctoAI Compute Service embedding models.
The environment variable ``OCTOAI_API_TOKEN`` should be set with your API token, or it can be passed
as a named parameter to the constructor.
"""
endpoint_url: Optional[str] = Field(
None, description="Endpoint URL to use.")
model_kwargs: Optional[dict] = Field(
None, description="Keyword arguments to pass to the model.")
octoai_api_token: Optional[str] = Field(
None, description="OCTOAI API Token")
embed_instruction: str = Field(
DEFAULT_EMBED_INSTRUCTION, description="Instruction to use for embedding documents.")
query_instruction: str = Field(
DEFAULT_QUERY_INSTRUCTION, description="Instruction to use for embedding query.")
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Ensure that the API key and python package exist in environment."""
values["octoai_api_token"] = get_from_dict_or_env(
values, "octoai_api_token", "OCTOAI_API_TOKEN")
values["endpoint_url"] = get_from_dict_or_env(
values, "endpoint_url", "ENDPOINT_URL")
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Return the identifying parameters."""
return {"endpoint_url": self.endpoint_url, "model_kwargs": self.model_kwargs or {}}
def _compute_embeddings(self, texts: List[str], instruction: str) -> List[List[float]]:
"""Common functionality for compute embeddings using a OctoAI instruct model."""
embeddings = []
octoai_client = client.Client(token=self.octoai_api_token)
for text in texts:
parameter_payload = {
"sentence": str([text]),# for item in text]),
"instruction": str([instruction]),# for item in text]),
"parameters": self.model_kwargs or {}
}
try:
resp_json = octoai_client.infer(
self.endpoint_url, parameter_payload)
embedding = resp_json["embeddings"]
except Exception as e:
raise ValueError(
f"Error raised by the inference endpoint: {e}") from e
embeddings.append(embedding)
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute document embeddings using an OctoAI instruct model."""
texts = list(map(lambda x: x.replace("\n", " "), texts))
return self._compute_embeddings(texts, self.embed_instruction)
def embed_query(self, text: str) -> List[float]:
"""Compute query embedding using an OctoAI instruct model."""
text = text.replace("\n", " ")
return self._compute_embeddings([text], self.embed_instruction)
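
A minimal usage sketch for the wrapper above, assuming the endpoint serves an instructor-large embedding model; the API token and endpoint URL below are placeholders, not real values:

from langchain.embeddings.octoai_embeddings import OctoAIEmbeddings

# Placeholder token and hypothetical endpoint URL; alternatively, set the
# OCTOAI_API_TOKEN and ENDPOINT_URL environment variables and omit both arguments.
embeddings = OctoAIEmbeddings(
    octoai_api_token="octoai-api-token",
    endpoint_url="https://instructor-large-demo.octoai.cloud/predict",
)

doc_vectors = embeddings.embed_documents(["OctoAI makes model serving easy."])
query_vector = embeddings.embed_query("What does OctoAI do?")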

View File

@@ -21,14 +21,20 @@ class OctoAIEndpoint(LLM):
Example:
.. code-block:: python
from langchain.llms import OctoAIEndpoint
endpoint_url = (
"https://endpoint_name-account_id.octoai.cloud"
)
endpoint = OctoAIEndpoint(
endpoint_url=endpoint_url,
octoai_api_token="octoai-api-key"
from langchain.llms.octoai_endpoint import OctoAIEndpoint
OctoAIEndpoint(
octoai_api_token="octoai-api-key",
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
"seed": None,
"stop": [],
},
)
"""
endpoint_url: Optional[str] = None
@@ -45,7 +51,7 @@ class OctoAIEndpoint(LLM):
extra = Extra.forbid
@root_validator()
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
octoai_api_token = get_from_dict_or_env(
@@ -90,26 +96,23 @@ class OctoAIEndpoint(LLM):
"""
_model_kwargs = self.model_kwargs or {}
# payload json
# Prepare the payload JSON
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
# HTTP headers for authorization
headers = {
"Authorization": f"Bearer {self.octoai_api_token}",
"Content-Type": "application/json",
}
# send request using the octoai sdk
try:
# Initialize the OctoAI client
octoai_client = client.Client(token=self.octoai_api_token)
# Send the request using the OctoAI client
resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
text = resp_json["generated_text"]
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}") from e
# Handle any errors raised by the inference endpoint
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
if stop is not None:
# stop tokens when making calls to octoai.
# Apply stop tokens when making calls to OctoAI
text = enforce_stop_tokens(text, stop)
return text
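
For context, a short usage sketch of the endpoint wrapper; the endpoint URL and model_kwargs keys come from the docstring example above, while the API token, prompt, and stop list are illustrative placeholders:

from langchain.llms.octoai_endpoint import OctoAIEndpoint

# Placeholder token; endpoint URL taken from the docstring example above.
llm = OctoAIEndpoint(
    octoai_api_token="octoai-api-token",
    endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
    model_kwargs={"max_new_tokens": 200, "temperature": 0.75, "top_p": 0.95},
)

# Stop sequences passed here are applied to the generated text via enforce_stop_tokens.
completion = llm("Name three U.S. states.\n", stop=["\n\n"])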

View File

@@ -0,0 +1,23 @@
"""Test octoai embeddings."""
from langchain.embeddings.octoai_embeddings import (
OctoAIEmbeddings,
)
def test_octoai_embedding_documents() -> None:
"""Test octoai embeddings."""
documents = ["foo bar"]
embedding = OctoAIEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768
def test_octoai_embedding_query() -> None:
"""Test octoai embeddings."""
document = "foo bar"
embedding = OctoAIEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 1
assert len(output[0]) == 768
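
These tests construct OctoAIEmbeddings() with no arguments, so the configuration comes from the environment; a sketch of the setup they assume, with placeholder values:

import os

# Placeholders; the root validator reads OCTOAI_API_TOKEN and ENDPOINT_URL
# through get_from_dict_or_env when no constructor arguments are supplied.
os.environ["OCTOAI_API_TOKEN"] = "octoai-api-token"
os.environ["ENDPOINT_URL"] = "https://instructor-large-demo.octoai.cloud/predict"

The 768-dimension assertions are consistent with an instructor-large embedding model, matching the commit title.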

View File

@@ -10,12 +10,13 @@ from langchain.llms.octoai_endpoint import OctoAIEndpoint
from tests.integration_tests.llms.utils import assert_llm_equality
def test_octoai_endpoint_text_generation() -> None:
"""Test valid call to OctoAI text generation model."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={
"max_new_tokens": 512,
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
@@ -32,8 +33,9 @@ def test_octoai_endpoint_text_generation() -> None:
def test_octoai_endpoint_call_error() -> None:
"""Test valid call to OctoAI that errors."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={"max_new_tokens": -1})
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={"max_new_tokens": -1},
)
with pytest.raises(ValueError):
llm("Which state is Los Angeles in?")
@@ -43,7 +45,7 @@ def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={
"max_new_tokens": 512,
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,