community[patch]: update OctoAI endpoint to subclass BaseOpenAI (#19757)

This PR updates the OctoAIEndpoint LLM to subclass BaseOpenAI, since OctoAI is an
OpenAI-compatible service. The documentation and tests have been updated accordingly.
Author: Sevin F. Varoglu
Date: 2024-04-17 03:32:20 +03:00 (committed by GitHub)
parent 0c95ddbcd8
commit 54d388d898
4 changed files with 107 additions and 214 deletions
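
In miniature, the change looks like this: the old constructor packed a chat-style payload into `model_kwargs` alongside an `endpoint_url`, while the new subclass takes flat, OpenAI-style keyword arguments. The sketch below is distilled from the diffs that follow, not new API surface:

```python
from langchain_community.llms.octoai_endpoint import OctoAIEndpoint

# Before this PR: endpoint URL plus a hand-assembled chat payload.
# llm = OctoAIEndpoint(
#     endpoint_url="https://text.octoai.run/v1/chat/completions",
#     model_kwargs={"model": "llama-2-13b-chat-fp16", "max_tokens": 128},
# )

# After this PR: flat OpenAI-compatible parameters; the base URL defaults to
# https://text.octoai.run/v1/ and the token comes from OCTOAI_API_TOKEN.
llm = OctoAIEndpoint(
    model="llama-2-13b-chat-fp16",
    max_tokens=200,
    temperature=0.1,
)
```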

File: docs/docs/integrations/llms/octoai.ipynb

@@ -18,7 +18,7 @@
     " \n",
     "2. Paste your API key in the code cell below.\n",
     "\n",
-    "Note: If you want to use a different LLM model, you can containerize the model and make a custom OctoAI endpoint yourself, by following [Build a Container from Python](https://octo.ai/docs/bring-your-own-model/advanced-build-a-container-from-scratch-in-python) and [Create a Custom Endpoint from a Container](https://octo.ai/docs/bring-your-own-model/create-custom-endpoints-from-a-container/create-custom-endpoints-from-a-container) and then update your Endpoint URL in the code cell below.\n"
+    "Note: If you want to use a different LLM model, you can containerize the model and make a custom OctoAI endpoint yourself, by following [Build a Container from Python](https://octo.ai/docs/bring-your-own-model/advanced-build-a-container-from-scratch-in-python) and [Create a Custom Endpoint from a Container](https://octo.ai/docs/bring-your-own-model/create-custom-endpoints-from-a-container/create-custom-endpoints-from-a-container) and then updating your `OCTOAI_API_BASE` environment variable.\n"
    ]
   },
   {
@@ -29,8 +29,7 @@
    "source": [
     "import os\n",
     "\n",
-    "os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n",
-    "os.environ[\"ENDPOINT_URL\"] = \"https://text.octoai.run/v1/chat/completions\""
+    "os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\""
    ]
   },
   {
@@ -68,44 +67,33 @@
    "outputs": [],
    "source": [
     "llm = OctoAIEndpoint(\n",
-    "    model_kwargs={\n",
-    "        \"model\": \"llama-2-13b-chat-fp16\",\n",
-    "        \"max_tokens\": 128,\n",
-    "        \"presence_penalty\": 0,\n",
-    "        \"temperature\": 0.1,\n",
-    "        \"top_p\": 0.9,\n",
-    "        \"messages\": [\n",
-    "            {\n",
-    "                \"role\": \"system\",\n",
-    "                \"content\": \"You are a helpful assistant. Keep your responses limited to one short paragraph if possible.\",\n",
-    "            },\n",
-    "        ],\n",
-    "    },\n",
+    "    model=\"llama-2-13b-chat-fp16\",\n",
+    "    max_tokens=200,\n",
+    "    presence_penalty=0,\n",
+    "    temperature=0.1,\n",
+    "    top_p=0.9,\n",
     ")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      " Sure thing! Here's my response:\n",
-      "\n",
-      "Leonardo da Vinci was a true Renaissance man - an Italian polymath who excelled in various fields, including painting, sculpture, engineering, mathematics, anatomy, and geology. He is widely considered one of the greatest painters of all time, and his inventive and innovative works continue to inspire and influence artists and thinkers to this day. Some of his most famous works include the Mona Lisa, The Last Supper, and Vitruvian Man. \n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "question = \"Who was leonardo davinci?\"\n",
+    "question = \"Who was Leonardo da Vinci?\"\n",
     "\n",
     "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
     "\n",
     "print(llm_chain.run(question))"
    ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Leonardo da Vinci was a true Renaissance man. He was born in 1452 in Vinci, Italy and was known for his work in various fields, including art, science, engineering, and mathematics. He is considered one of the greatest painters of all time, and his most famous works include the Mona Lisa and The Last Supper. In addition to his art, da Vinci made significant contributions to engineering and anatomy, and his designs for machines and inventions were centuries ahead of his time. He is also known for his extensive journals and drawings, which provide valuable insights into his thoughts and ideas. Da Vinci's legacy continues to inspire and influence artists, scientists, and thinkers around the world today."
-   ]
   }
  ],
 "metadata": {

File: libs/community/langchain_community/llms/__init__.py

@@ -1003,6 +1003,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "oci_model_deployment_tgi_endpoint": _import_oci_md_tgi,
         "oci_model_deployment_vllm_endpoint": _import_oci_md_vllm,
         "oci_generative_ai": _import_oci_gen_ai,
+        "octoai_endpoint": _import_octoai_endpoint,
         "ollama": _import_ollama,
         "openai": _import_openai,
         "openlm": _import_openlm,

File: libs/community/langchain_community/llms/octoai_endpoint.py

@@ -1,166 +1,117 @@
-from typing import Any, Dict, List, Mapping, Optional
-
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.llms import LLM
-from langchain_core.pydantic_v1 import Extra, root_validator
-from langchain_core.utils import get_from_dict_or_env
-
-from langchain_community.llms.utils import enforce_stop_tokens
+from typing import Any, Dict
+
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.llms.openai import BaseOpenAI
+from langchain_community.utils.openai import is_openai_v1
+
+DEFAULT_BASE_URL = "https://text.octoai.run/v1/"
+DEFAULT_MODEL = "codellama-7b-instruct"
 
 
-class OctoAIEndpoint(LLM):
-    """OctoAI LLM Endpoints.
+class OctoAIEndpoint(BaseOpenAI):
+    """OctoAI LLM Endpoints - OpenAI compatible.
 
-    OctoAIEndpoint is a class to interact with OctoAI
-    Compute Service large language model endpoints.
+    OctoAIEndpoint is a class to interact with OctoAI Compute Service large
+    language model endpoints.
 
-    To use, you should have the ``octoai`` python package installed, and the
-    environment variable ``OCTOAI_API_TOKEN`` set with your API token, or pass
-    it as a named parameter to the constructor.
+    To use, you should have the environment variable ``OCTOAI_API_TOKEN`` set
+    with your API token, or pass it as a named parameter to the constructor.
 
     Example:
         .. code-block:: python
 
             from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
-            OctoAIEndpoint(
-                octoai_api_token="octoai-api-key",
-                endpoint_url="https://text.octoai.run/v1/chat/completions",
-                model_kwargs={
-                    "model": "llama-2-13b-chat-fp16",
-                    "messages": [
-                        {
-                            "role": "system",
-                            "content": "Below is an instruction that describes a task.
-                            Write a response that completes the request."
-                        }
-                    ],
-                    "stream": False,
-                    "max_tokens": 256,
-                    "presence_penalty": 0,
-                    "temperature": 0.1,
-                    "top_p": 0.9
-                }
-            )
+
+            llm = OctoAIEndpoint(
+                model="llama-2-13b-chat-fp16",
+                max_tokens=200,
+                presence_penalty=0,
+                temperature=0.1,
+                top_p=0.9,
+            )
     """
 
-    endpoint_url: Optional[str] = None
-    """Endpoint URL to use."""
-
-    model_kwargs: Optional[dict] = None
-    """Keyword arguments to pass to the model."""
-
-    octoai_api_token: Optional[str] = None
-    """OCTOAI API Token"""
-
-    streaming: bool = False
-    """Whether to generate a stream of tokens asynchronously"""
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
-    @root_validator(allow_reuse=True)
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment."""
-        octoai_api_token = get_from_dict_or_env(
-            values, "octoai_api_token", "OCTOAI_API_TOKEN"
-        )
-        values["endpoint_url"] = get_from_dict_or_env(
-            values, "endpoint_url", "ENDPOINT_URL"
-        )
-        values["octoai_api_token"] = octoai_api_token
-        return values
+    """Key word arguments to pass to the model."""
+    octoai_api_base: str = Field(default=DEFAULT_BASE_URL)
+    octoai_api_token: SecretStr = Field(default=None)
+    model_name: str = Field(default=DEFAULT_MODEL)
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return False
 
     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        """Get the identifying parameters."""
-        _model_kwargs = self.model_kwargs or {}
-        return {
-            **{"endpoint_url": self.endpoint_url},
-            **{"model_kwargs": _model_kwargs},
-        }
+    def _invocation_params(self) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model."""
+        params: Dict[str, Any] = {
+            "model": self.model_name,
+            **self._default_params,
+        }
+        if not is_openai_v1():
+            params.update(
+                {
+                    "api_key": self.octoai_api_token.get_secret_value(),
+                    "api_base": self.octoai_api_base,
+                }
+            )
+        return {**params, **super()._invocation_params}
 
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "octoai_endpoint"
 
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        """Call out to OctoAI's inference endpoint.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-
-        Returns:
-            The string generated by the model.
-        """
-        _model_kwargs = self.model_kwargs or {}
-
-        try:
-            from octoai import client
-
-            # Initialize the OctoAI client
-            octoai_client = client.Client(token=self.octoai_api_token)
-
-            if "model" in _model_kwargs:
-                parameter_payload = _model_kwargs
-
-                sys_msg = None
-                if "messages" in parameter_payload:
-                    msgs = parameter_payload.get("messages", [])
-                    for msg in msgs:
-                        if msg.get("role") == "system":
-                            sys_msg = msg.get("content")
-                # Reset messages list
-                parameter_payload["messages"] = []
-
-                # Append system message if exists
-                if sys_msg:
-                    parameter_payload["messages"].append(
-                        {"role": "system", "content": sys_msg}
-                    )
-
-                # Append user message
-                parameter_payload["messages"].append(
-                    {"role": "user", "content": prompt}
-                )
-
-                # Send the request using the OctoAI client
-                try:
-                    output = octoai_client.infer(self.endpoint_url, parameter_payload)
-                    if output and "choices" in output and len(output["choices"]) > 0:
-                        text = output["choices"][0].get("message", {}).get("content")
-                    else:
-                        text = "Error: Invalid response format or empty choices."
-                except Exception as e:
-                    text = f"Error during API call: {str(e)}"
-            else:
-                # Prepare the payload JSON
-                parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
-
-                # Send the request using the OctoAI client
-                resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
-                text = resp_json["generated_text"]
-
-        except Exception as e:
-            # Handle any errors raised by the inference endpoint
-            raise ValueError(f"Error raised by the inference endpoint: {e}") from e
-
-        if stop is not None:
-            # Apply stop tokens when making calls to OctoAI
-            text = enforce_stop_tokens(text, stop)
-
-        return text
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        values["octoai_api_base"] = get_from_dict_or_env(
+            values,
+            "octoai_api_base",
+            "OCTOAI_API_BASE",
+            default=DEFAULT_BASE_URL,
+        )
+        values["octoai_api_token"] = convert_to_secret_str(
+            get_from_dict_or_env(values, "octoai_api_token", "OCTOAI_API_TOKEN")
+        )
+        values["model_name"] = get_from_dict_or_env(
+            values,
+            "model_name",
+            "MODEL_NAME",
+            default=DEFAULT_MODEL,
+        )
+
+        try:
+            import openai
+
+            if is_openai_v1():
+                client_params = {
+                    "api_key": values["octoai_api_token"].get_secret_value(),
+                    "base_url": values["octoai_api_base"],
+                }
+                if not values.get("client"):
+                    values["client"] = openai.OpenAI(**client_params).completions
+                if not values.get("async_client"):
+                    values["async_client"] = openai.AsyncOpenAI(
+                        **client_params
+                    ).completions
+            else:
+                values["openai_api_base"] = values["octoai_api_base"]
+                values["openai_api_key"] = values["octoai_api_token"].get_secret_value()
+                values["client"] = openai.Completion
+        except ImportError:
+            raise ImportError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+
+        if "endpoint_url" in values["model_kwargs"]:
+            raise ValueError(
+                "`endpoint_url` was deprecated, please use `octoai_api_base`."
+            )
+
+        return values
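
Taken together, `validate_environment` wires an OpenAI client up against the OctoAI base URL, and `_invocation_params` merges model parameters at call time. A minimal usage sketch under those rules (the token value is a placeholder):

```python
import os

from langchain_community.llms.octoai_endpoint import OctoAIEndpoint

os.environ["OCTOAI_API_TOKEN"] = "<your-octoai-token>"  # placeholder

# The validator resolves octoai_api_base and model_name from the environment
# or falls back to DEFAULT_BASE_URL / DEFAULT_MODEL, then builds the client.
llm = OctoAIEndpoint(model="llama-2-13b-chat-fp16", max_tokens=200)
print(llm.invoke("Who was Leonardo da Vinci?"))
```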

File: libs/community/tests/integration_tests/llms/test_octoai_endpoint.py

@ -1,58 +1,11 @@
"""Test OctoAI API wrapper.""" """Test OctoAI API wrapper."""
from pathlib import Path
import pytest
from langchain_community.llms.loading import load_llm
from langchain_community.llms.octoai_endpoint import OctoAIEndpoint from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
from tests.integration_tests.llms.utils import assert_llm_equality
def test_octoai_endpoint_text_generation() -> None: def test_octoai_endpoint_call() -> None:
"""Test valid call to OctoAI text generation model.""" """Test valid call to OctoAI endpoint."""
llm = OctoAIEndpoint( llm = OctoAIEndpoint()
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
octoai_api_token="<octoai_api_token>",
model_kwargs={
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
"seed": None,
"stop": [],
},
)
output = llm("Which state is Los Angeles in?") output = llm("Which state is Los Angeles in?")
print(output) # noqa: T201 print(output) # noqa: T201
assert isinstance(output, str) assert isinstance(output, str)
def test_octoai_endpoint_call_error() -> None:
"""Test valid call to OctoAI that errors."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
model_kwargs={"max_new_tokens": -1},
)
with pytest.raises(ValueError):
llm("Which state is Los Angeles in?")
def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
"""Test saving/loading an OctoAIHub LLM."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
octoai_api_token="<octoai_api_token>",
model_kwargs={
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
"seed": None,
"stop": [],
},
)
llm.save(file_path=tmp_path / "octoai.yaml")
loaded_llm = load_llm(tmp_path / "octoai.yaml")
assert_llm_equality(llm, loaded_llm)
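
Note that the rewritten test constructs `OctoAIEndpoint()` with no arguments, so it passes or fails purely on environment configuration. A sketch of the setup it presumes (token value is a placeholder):

```python
import os

# Required: the no-argument constructor raises if no token is found.
os.environ["OCTOAI_API_TOKEN"] = "<your-octoai-token>"  # placeholder

# Optional overrides; otherwise DEFAULT_BASE_URL / DEFAULT_MODEL apply.
# os.environ["OCTOAI_API_BASE"] = "https://text.octoai.run/v1/"
# os.environ["MODEL_NAME"] = "llama-2-13b-chat-fp16"

from langchain_community.llms.octoai_endpoint import OctoAIEndpoint

llm = OctoAIEndpoint()
output = llm("Which state is Los Angeles in?")
assert isinstance(output, str)
```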