mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 05:43:55 +00:00
community[patch]: update OctoAI endpoint to subclass BaseOpenAI (#19757)
This PR updates OctoAIEndpoint LLM to subclass BaseOpenAI as OctoAI is an OpenAI-compatible service. The documentation and tests have also been updated.
This commit is contained in:
parent
0c95ddbcd8
commit
54d388d898
@ -18,7 +18,7 @@
|
||||
" \n",
|
||||
"2. Paste your API key in in the code cell below.\n",
|
||||
"\n",
|
||||
"Note: If you want to use a different LLM model, you can containerize the model and make a custom OctoAI endpoint yourself, by following [Build a Container from Python](https://octo.ai/docs/bring-your-own-model/advanced-build-a-container-from-scratch-in-python) and [Create a Custom Endpoint from a Container](https://octo.ai/docs/bring-your-own-model/create-custom-endpoints-from-a-container/create-custom-endpoints-from-a-container) and then update your Endpoint URL in the code cell below.\n"
|
||||
"Note: If you want to use a different LLM model, you can containerize the model and make a custom OctoAI endpoint yourself, by following [Build a Container from Python](https://octo.ai/docs/bring-your-own-model/advanced-build-a-container-from-scratch-in-python) and [Create a Custom Endpoint from a Container](https://octo.ai/docs/bring-your-own-model/create-custom-endpoints-from-a-container/create-custom-endpoints-from-a-container) and then updating your `OCTOAI_API_BASE` environment variable.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -29,8 +29,7 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n",
|
||||
"os.environ[\"ENDPOINT_URL\"] = \"https://text.octoai.run/v1/chat/completions\""
|
||||
"os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -68,44 +67,33 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OctoAIEndpoint(\n",
|
||||
" model_kwargs={\n",
|
||||
" \"model\": \"llama-2-13b-chat-fp16\",\n",
|
||||
" \"max_tokens\": 128,\n",
|
||||
" \"presence_penalty\": 0,\n",
|
||||
" \"temperature\": 0.1,\n",
|
||||
" \"top_p\": 0.9,\n",
|
||||
" \"messages\": [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" \"content\": \"You are a helpful assistant. Keep your responses limited to one short paragraph if possible.\",\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
" model=\"llama-2-13b-chat-fp16\",\n",
|
||||
" max_tokens=200,\n",
|
||||
" presence_penalty=0,\n",
|
||||
" temperature=0.1,\n",
|
||||
" top_p=0.9,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" Sure thing! Here's my response:\n",
|
||||
"\n",
|
||||
"Leonardo da Vinci was a true Renaissance man - an Italian polymath who excelled in various fields, including painting, sculpture, engineering, mathematics, anatomy, and geology. He is widely considered one of the greatest painters of all time, and his inventive and innovative works continue to inspire and influence artists and thinkers to this day. Some of his most famous works include the Mona Lisa, The Last Supper, and Vitruvian Man. \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"Who was leonardo davinci?\"\n",
|
||||
"question = \"Who was Leonardo da Vinci?\"\n",
|
||||
"\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Leonardo da Vinci was a true Renaissance man. He was born in 1452 in Vinci, Italy and was known for his work in various fields, including art, science, engineering, and mathematics. He is considered one of the greatest painters of all time, and his most famous works include the Mona Lisa and The Last Supper. In addition to his art, da Vinci made significant contributions to engineering and anatomy, and his designs for machines and inventions were centuries ahead of his time. He is also known for his extensive journals and drawings, which provide valuable insights into his thoughts and ideas. Da Vinci's legacy continues to inspire and influence artists, scientists, and thinkers around the world today."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -1003,6 +1003,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
||||
"oci_model_deployment_tgi_endpoint": _import_oci_md_tgi,
|
||||
"oci_model_deployment_vllm_endpoint": _import_oci_md_vllm,
|
||||
"oci_generative_ai": _import_oci_gen_ai,
|
||||
"octoai_endpoint": _import_octoai_endpoint,
|
||||
"ollama": _import_ollama,
|
||||
"openai": _import_openai,
|
||||
"openlm": _import_openlm,
|
||||
|
@ -1,166 +1,117 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
from langchain_community.llms.openai import BaseOpenAI
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
|
||||
DEFAULT_BASE_URL = "https://text.octoai.run/v1/"
|
||||
DEFAULT_MODEL = "codellama-7b-instruct"
|
||||
|
||||
|
||||
class OctoAIEndpoint(LLM):
|
||||
"""OctoAI LLM Endpoints.
|
||||
class OctoAIEndpoint(BaseOpenAI):
|
||||
"""OctoAI LLM Endpoints - OpenAI compatible.
|
||||
|
||||
OctoAIEndpoint is a class to interact with OctoAI
|
||||
Compute Service large language model endpoints.
|
||||
OctoAIEndpoint is a class to interact with OctoAI Compute Service large
|
||||
language model endpoints.
|
||||
|
||||
To use, you should have the ``octoai`` python package installed, and the
|
||||
environment variable ``OCTOAI_API_TOKEN`` set with your API token, or pass
|
||||
it as a named parameter to the constructor.
|
||||
To use, you should have the environment variable ``OCTOAI_API_TOKEN`` set
|
||||
with your API token, or pass it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
|
||||
OctoAIEndpoint(
|
||||
octoai_api_token="octoai-api-key",
|
||||
endpoint_url="https://text.octoai.run/v1/chat/completions",
|
||||
model_kwargs={
|
||||
"model": "llama-2-13b-chat-fp16",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Below is an instruction that describes a task.
|
||||
Write a response that completes the request."
|
||||
}
|
||||
],
|
||||
"stream": False,
|
||||
"max_tokens": 256,
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.1,
|
||||
"top_p": 0.9
|
||||
}
|
||||
|
||||
llm = OctoAIEndpoint(
|
||||
model="llama-2-13b-chat-fp16",
|
||||
max_tokens=200,
|
||||
presence_penalty=0,
|
||||
temperature=0.1,
|
||||
top_p=0.9,
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
endpoint_url: Optional[str] = None
|
||||
"""Endpoint URL to use."""
|
||||
"""Key word arguments to pass to the model."""
|
||||
octoai_api_base: str = Field(default=DEFAULT_BASE_URL)
|
||||
octoai_api_token: SecretStr = Field(default=None)
|
||||
model_name: str = Field(default=DEFAULT_MODEL)
|
||||
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
octoai_api_token: Optional[str] = None
|
||||
"""OCTOAI API Token"""
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to generate a stream of tokens asynchronously"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
octoai_api_token = get_from_dict_or_env(
|
||||
values, "octoai_api_token", "OCTOAI_API_TOKEN"
|
||||
)
|
||||
values["endpoint_url"] = get_from_dict_or_env(
|
||||
values, "endpoint_url", "ENDPOINT_URL"
|
||||
)
|
||||
|
||||
values["octoai_api_token"] = octoai_api_token
|
||||
return values
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{"endpoint_url": self.endpoint_url},
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"model": self.model_name,
|
||||
**self._default_params,
|
||||
}
|
||||
if not is_openai_v1():
|
||||
params.update(
|
||||
{
|
||||
"api_key": self.octoai_api_token.get_secret_value(),
|
||||
"api_base": self.octoai_api_base,
|
||||
}
|
||||
)
|
||||
|
||||
return {**params, **super()._invocation_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "octoai_endpoint"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to OctoAI's inference endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["octoai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"octoai_api_base",
|
||||
"OCTOAI_API_BASE",
|
||||
default=DEFAULT_BASE_URL,
|
||||
)
|
||||
values["octoai_api_token"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "octoai_api_token", "OCTOAI_API_TOKEN")
|
||||
)
|
||||
values["model_name"] = get_from_dict_or_env(
|
||||
values,
|
||||
"model_name",
|
||||
"MODEL_NAME",
|
||||
default=DEFAULT_MODEL,
|
||||
)
|
||||
|
||||
try:
|
||||
from octoai import client
|
||||
|
||||
# Initialize the OctoAI client
|
||||
octoai_client = client.Client(token=self.octoai_api_token)
|
||||
|
||||
if "model" in _model_kwargs:
|
||||
parameter_payload = _model_kwargs
|
||||
|
||||
sys_msg = None
|
||||
if "messages" in parameter_payload:
|
||||
msgs = parameter_payload.get("messages", [])
|
||||
for msg in msgs:
|
||||
if msg.get("role") == "system":
|
||||
sys_msg = msg.get("content")
|
||||
|
||||
# Reset messages list
|
||||
parameter_payload["messages"] = []
|
||||
|
||||
# Append system message if exists
|
||||
if sys_msg:
|
||||
parameter_payload["messages"].append(
|
||||
{"role": "system", "content": sys_msg}
|
||||
)
|
||||
|
||||
# Append user message
|
||||
parameter_payload["messages"].append(
|
||||
{"role": "user", "content": prompt}
|
||||
)
|
||||
|
||||
# Send the request using the OctoAI client
|
||||
try:
|
||||
output = octoai_client.infer(self.endpoint_url, parameter_payload)
|
||||
if output and "choices" in output and len(output["choices"]) > 0:
|
||||
text = output["choices"][0].get("message", {}).get("content")
|
||||
else:
|
||||
text = "Error: Invalid response format or empty choices."
|
||||
except Exception as e:
|
||||
text = f"Error during API call: {str(e)}"
|
||||
import openai
|
||||
|
||||
if is_openai_v1():
|
||||
client_params = {
|
||||
"api_key": values["octoai_api_token"].get_secret_value(),
|
||||
"base_url": values["octoai_api_base"],
|
||||
}
|
||||
if not values.get("client"):
|
||||
values["client"] = openai.OpenAI(**client_params).completions
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = openai.AsyncOpenAI(
|
||||
**client_params
|
||||
).completions
|
||||
else:
|
||||
# Prepare the payload JSON
|
||||
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
|
||||
values["openai_api_base"] = values["octoai_api_base"]
|
||||
values["openai_api_key"] = values["octoai_api_token"].get_secret_value()
|
||||
values["client"] = openai.Completion
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
|
||||
# Send the request using the OctoAI client
|
||||
resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
|
||||
text = resp_json["generated_text"]
|
||||
if "endpoint_url" in values["model_kwargs"]:
|
||||
raise ValueError(
|
||||
"`endpoint_url` was deprecated, please use `octoai_api_base`."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Handle any errors raised by the inference endpoint
|
||||
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
|
||||
|
||||
if stop is not None:
|
||||
# Apply stop tokens when making calls to OctoAI
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return text
|
||||
return values
|
||||
|
@ -1,58 +1,11 @@
|
||||
"""Test OctoAI API wrapper."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.llms.loading import load_llm
|
||||
from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
|
||||
from tests.integration_tests.llms.utils import assert_llm_equality
|
||||
|
||||
|
||||
def test_octoai_endpoint_text_generation() -> None:
|
||||
"""Test valid call to OctoAI text generation model."""
|
||||
llm = OctoAIEndpoint(
|
||||
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
|
||||
octoai_api_token="<octoai_api_token>",
|
||||
model_kwargs={
|
||||
"max_new_tokens": 200,
|
||||
"temperature": 0.75,
|
||||
"top_p": 0.95,
|
||||
"repetition_penalty": 1,
|
||||
"seed": None,
|
||||
"stop": [],
|
||||
},
|
||||
)
|
||||
|
||||
def test_octoai_endpoint_call() -> None:
|
||||
"""Test valid call to OctoAI endpoint."""
|
||||
llm = OctoAIEndpoint()
|
||||
output = llm("Which state is Los Angeles in?")
|
||||
print(output) # noqa: T201
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_octoai_endpoint_call_error() -> None:
|
||||
"""Test valid call to OctoAI that errors."""
|
||||
llm = OctoAIEndpoint(
|
||||
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
|
||||
model_kwargs={"max_new_tokens": -1},
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
llm("Which state is Los Angeles in?")
|
||||
|
||||
|
||||
def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
|
||||
"""Test saving/loading an OctoAIHub LLM."""
|
||||
llm = OctoAIEndpoint(
|
||||
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
|
||||
octoai_api_token="<octoai_api_token>",
|
||||
model_kwargs={
|
||||
"max_new_tokens": 200,
|
||||
"temperature": 0.75,
|
||||
"top_p": 0.95,
|
||||
"repetition_penalty": 1,
|
||||
"seed": None,
|
||||
"stop": [],
|
||||
},
|
||||
)
|
||||
llm.save(file_path=tmp_path / "octoai.yaml")
|
||||
loaded_llm = load_llm(tmp_path / "octoai.yaml")
|
||||
assert_llm_equality(llm, loaded_llm)
|
||||
|
Loading…
Reference in New Issue
Block a user