From 31eee4e6ffa79356a62fa66bd18fc4a4a5119793 Mon Sep 17 00:00:00 2001 From: OctoML-Bassem Date: Sun, 11 Jun 2023 22:58:34 -0700 Subject: [PATCH] add embeddings using instructor large endpoint --- langchain/embeddings/octoai_embeddings.py | 82 +++++++++++++++++++ langchain/llms/octoai_endpoint.py | 39 +++++---- .../embeddings/test_octoai_embeddings.py | 23 ++++++ .../llms/test_octoai_endpoint.py | 10 ++- 4 files changed, 132 insertions(+), 22 deletions(-) create mode 100644 langchain/embeddings/octoai_embeddings.py create mode 100644 tests/integration_tests/embeddings/test_octoai_embeddings.py diff --git a/langchain/embeddings/octoai_embeddings.py b/langchain/embeddings/octoai_embeddings.py new file mode 100644 index 00000000000..ced648eb205 --- /dev/null +++ b/langchain/embeddings/octoai_embeddings.py @@ -0,0 +1,82 @@ +"""Module providing a wrapper around OctoAI Compute Service embedding models.""" + +from typing import Any, Dict, List, Mapping, Optional +from pydantic import BaseModel, Extra, Field, root_validator +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env +from octoai import client + +DEFAULT_EMBED_INSTRUCTION = "Represent this input: " +DEFAULT_QUERY_INSTRUCTION = "Represent the question for retrieving similar documents: " + + +class OctoAIEmbeddings(BaseModel, Embeddings): + """ + Wrapper around OctoAI Compute Service embedding models. + + The environment variable ``OCTOAI_API_TOKEN`` should be set with your API token, or it can be passed + as a named parameter to the constructor. 
+ """ + endpoint_url: Optional[str] = Field( + None, description="Endpoint URL to use.") + model_kwargs: Optional[dict] = Field( + None, description="Keyword arguments to pass to the model.") + octoai_api_token: Optional[str] = Field( + None, description="OCTOAI API Token") + embed_instruction: str = Field( + DEFAULT_EMBED_INSTRUCTION, description="Instruction to use for embedding documents.") + query_instruction: str = Field( + DEFAULT_QUERY_INSTRUCTION, description="Instruction to use for embedding query.") + + class Config: + """Configuration for this pydantic object.""" + extra = Extra.forbid + + @root_validator(allow_reuse=True) + def validate_environment(cls, values: Dict) -> Dict: + """Ensure that the API key and python package exist in environment.""" + values["octoai_api_token"] = get_from_dict_or_env( + values, "octoai_api_token", "OCTOAI_API_TOKEN") + values["endpoint_url"] = get_from_dict_or_env( + values, "endpoint_url", "ENDPOINT_URL") + return values + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Return the identifying parameters.""" + return {"endpoint_url": self.endpoint_url, "model_kwargs": self.model_kwargs or {}} + + def _compute_embeddings(self, texts: List[str], instruction: str) -> List[List[float]]: + """Common functionality for compute embeddings using a OctoAI instruct model.""" + embeddings = [] + octoai_client = client.Client(token=self.octoai_api_token) + + for text in texts: + parameter_payload = { + "sentence": str([text]),# for item in text]), + "instruction": str([instruction]),# for item in text]), + "parameters": self.model_kwargs or {} + } + + try: + resp_json = octoai_client.infer( + self.endpoint_url, parameter_payload) + embedding = resp_json["embeddings"] + except Exception as e: + raise ValueError( + f"Error raised by the inference endpoint: {e}") from e + + embeddings.append(embedding) + + return embeddings + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute document 
embeddings using an OctoAI instruct model.""" + texts = list(map(lambda x: x.replace("\n", " "), texts)) + return self._compute_embeddings(texts, self.embed_instruction) + + def embed_query(self, text: str) -> List[float]: + """Compute query embedding using an OctoAI instruct model.""" + text = text.replace("\n", " ") + return self._compute_embeddings([text], self.query_instruction) + diff --git a/langchain/llms/octoai_endpoint.py b/langchain/llms/octoai_endpoint.py index 4da10598d65..fad9b5bb393 100644 --- a/langchain/llms/octoai_endpoint.py +++ b/langchain/llms/octoai_endpoint.py @@ -21,14 +21,20 @@ class OctoAIEndpoint(LLM): Example: .. code-block:: python - from langchain.llms import OctoAIEndpoint - endpoint_url = ( - "https://endpoint_name-account_id.octoai.cloud" - ) - endpoint = OctoAIEndpoint( - endpoint_url=endpoint_url, - octoai_api_token="octoai-api-key" + from langchain.llms.octoai_endpoint import OctoAIEndpoint + OctoAIEndpoint( + octoai_api_token="octoai-api-key", + endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate", + model_kwargs={ + "max_new_tokens": 200, + "temperature": 0.75, + "top_p": 0.95, + "repetition_penalty": 1, + "seed": None, + "stop": [], + }, ) + """ endpoint_url: Optional[str] = None @@ -45,7 +51,7 @@ class OctoAIEndpoint(LLM): extra = Extra.forbid - @root_validator() + @root_validator(allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" octoai_api_token = get_from_dict_or_env( @@ -90,26 +96,23 @@ class OctoAIEndpoint(LLM): """ _model_kwargs = self.model_kwargs or {} - # payload json + # Prepare the payload JSON parameter_payload = {"inputs": prompt, "parameters": _model_kwargs} - # HTTP headers for authorization - headers = { - "Authorization": f"Bearer {self.octoai_api_token}", - "Content-Type": "application/json", - } - - # send request using octaoai sdk try: + # Initialize the OctoAI client octoai_client = 
client.Client(token=self.octoai_api_token) + + # Send the request using the OctoAI client resp_json = octoai_client.infer(self.endpoint_url, parameter_payload) text = resp_json["generated_text"] except Exception as e: - raise ValueError(f"Error raised by inference endpoint: {e}") from e + # Handle any errors raised by the inference endpoint + raise ValueError(f"Error raised by the inference endpoint: {e}") from e if stop is not None: - # stop tokens when making calls to octoai. + # Apply stop tokens when making calls to OctoAI text = enforce_stop_tokens(text, stop) return text diff --git a/tests/integration_tests/embeddings/test_octoai_embeddings.py b/tests/integration_tests/embeddings/test_octoai_embeddings.py new file mode 100644 index 00000000000..8d843d2958a --- /dev/null +++ b/tests/integration_tests/embeddings/test_octoai_embeddings.py @@ -0,0 +1,23 @@ +"""Test octoai embeddings.""" + +from langchain.embeddings.octoai_embeddings import ( + OctoAIEmbeddings, +) + + +def test_octoai_embedding_documents() -> None: + """Test octoai embeddings.""" + documents = ["foo bar"] + embedding = OctoAIEmbeddings() + output = embedding.embed_documents(documents) + assert len(output) == 1 + assert len(output[0]) == 768 + + +def test_octoai_embedding_query() -> None: + """Test octoai embeddings.""" + document = "foo bar" + embedding = OctoAIEmbeddings() + output = embedding.embed_query(document) + assert len(output) == 1 + assert len(output[0]) == 768 \ No newline at end of file diff --git a/tests/integration_tests/llms/test_octoai_endpoint.py b/tests/integration_tests/llms/test_octoai_endpoint.py index 35385e5f4b9..ee6b803c4b3 100644 --- a/tests/integration_tests/llms/test_octoai_endpoint.py +++ b/tests/integration_tests/llms/test_octoai_endpoint.py @@ -10,12 +10,13 @@ from langchain.llms.octoai_endpoint import OctoAIEndpoint from tests.integration_tests.llms.utils import assert_llm_equality + def test_octoai_endpoint_text_generation() -> None: """Test valid call to OctoAI 
text generation model.""" llm = OctoAIEndpoint( endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate", model_kwargs={ - "max_new_tokens": 512, + "max_new_tokens": 200, "temperature": 0.75, "top_p": 0.95, "repetition_penalty": 1, @@ -32,8 +33,9 @@ def test_octoai_endpoint_text_generation() -> None: def test_octoai_endpoint_call_error() -> None: """Test valid call to OctoAI that errors.""" llm = OctoAIEndpoint( - endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate", - model_kwargs={"max_new_tokens": -1}) + endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate", + model_kwargs={"max_new_tokens": -1}, + ) with pytest.raises(ValueError): llm("Which state is Los Angeles in?") @@ -43,7 +45,7 @@ def test_saving_loading_endpoint_llm(tmp_path: Path) -> None: llm = OctoAIEndpoint( endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate", model_kwargs={ - "max_new_tokens": 512, + "max_new_tokens": 200, "temperature": 0.75, "top_p": 0.95, "repetition_penalty": 1,