Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-18 10:43:36 +00:00)
Adds support for llama2 and fixes MPT-7b url (#11465)
- **Description:** This is an update to the OctoAI LLM provider that adds support for llama2 endpoints hosted on OctoAI and updates the MPT-7b URL to the current one. @baskaryan Thanks!

Co-authored-by: ML Wiz <bassemgeorgi@gmail.com>
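For quick orientation, the two endpoint styles this change documents can be sketched as follows. This is a minimal sketch assembled from the updated notebook and docstring examples in the diff below; the final `print(...)` call is illustrative and not part of the change:

```python
import os

from langchain.llms.octoai_endpoint import OctoAIEndpoint

# Placeholder token, as in the notebook example.
os.environ["OCTOAI_API_TOKEN"] = "OCTOAI_API_TOKEN"

# Text-generation style (MPT-7b), using the updated *.octoai.run URL:
mpt_llm = OctoAIEndpoint(
    endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
    model_kwargs={"max_new_tokens": 200, "temperature": 0.75},
)

# Chat-completions style (llama2), new in this change:
llama_llm = OctoAIEndpoint(
    endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
    model_kwargs={
        "model": "llama-2-7b-chat",
        "messages": [
            {
                "role": "system",
                "content": "Below is an instruction that describes a task. "
                "Write a response that completes the request.",
            }
        ],
        "stream": False,
        "max_tokens": 256,
    },
)

print(llama_llm("Who was Leonardo da Vinci?"))  # illustrative prompt
```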
This commit is contained in: commit 5451b724fc (parent 0bff399af1)
@@ -33,7 +33,7 @@
     "import os\n",
     "\n",
     "os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n",
-    "os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate\""
+    "os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate\""
    ]
   },
   {
@@ -23,7 +23,7 @@ class OctoAIEndpoint(LLM):
             from langchain.llms.octoai_endpoint import OctoAIEndpoint
             OctoAIEndpoint(
                 octoai_api_token="octoai-api-key",
-                endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+                endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
                 model_kwargs={
                     "max_new_tokens": 200,
                     "temperature": 0.75,
@@ -34,6 +34,24 @@ class OctoAIEndpoint(LLM):
                 },
             )

+            from langchain.llms.octoai_endpoint import OctoAIEndpoint
+            OctoAIEndpoint(
+                octoai_api_token="octoai-api-key",
+                endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
+                model_kwargs={
+                    "model": "llama-2-7b-chat",
+                    "messages": [
+                        {
+                            "role": "system",
+                            "content": "Below is an instruction that describes a task.
+                            Write a response that completes the request."
+                        }
+                    ],
+                    "stream": False,
+                    "max_tokens": 256
+                }
+            )
+
     """

     endpoint_url: Optional[str] = None
@@ -45,6 +63,9 @@ class OctoAIEndpoint(LLM):
     octoai_api_token: Optional[str] = None
     """OCTOAI API Token"""

+    streaming: bool = False
+    """Whether to generate a stream of tokens asynchronously"""
+
     class Config:
         """Configuration for this pydantic object."""

@@ -96,18 +117,27 @@ class OctoAIEndpoint(LLM):
         """
         _model_kwargs = self.model_kwargs or {}

-        # Prepare the payload JSON
-        parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
-
         try:
             # Initialize the OctoAI client
             from octoai import client

             octoai_client = client.Client(token=self.octoai_api_token)

-            # Send the request using the OctoAI client
-            resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
-            text = resp_json["generated_text"]
+            if "model" in _model_kwargs and "llama-2" in _model_kwargs["model"]:
+                parameter_payload = _model_kwargs
+                parameter_payload["messages"].append(
+                    {"role": "user", "content": prompt}
+                )
+                # Send the request using the OctoAI client
+                output = octoai_client.infer(self.endpoint_url, parameter_payload)
+                text = output.get("choices")[0].get("message").get("content")
+            else:
+                # Prepare the payload JSON
+                parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
+
+                # Send the request using the OctoAI client
+                resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
+                text = resp_json["generated_text"]

         except Exception as e:
             # Handle any errors raised by the inference endpoint
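Distilled from the hunk above, the new `_call` logic chooses a payload shape by endpoint type: llama2 chat endpoints receive the `model_kwargs` dict itself with the prompt appended as a user message, while text-generation endpoints keep the old `{"inputs": ..., "parameters": ...}` wrapper. The reply is then read from `output["choices"][0]["message"]["content"]` in the chat case and from `resp_json["generated_text"]` otherwise. A standalone restatement of the branching (the function name is mine; this is a sketch, not the library code verbatim):

```python
from typing import Any, Dict


def build_payload(prompt: str, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror the branch added to OctoAIEndpoint._call."""
    if "model" in model_kwargs and "llama-2" in model_kwargs["model"]:
        # Chat-completions endpoint: reuse model_kwargs as the payload
        # and append the prompt as a user message.
        payload = model_kwargs
        payload["messages"].append({"role": "user", "content": prompt})
        return payload
    # Text-generation endpoint: wrap the prompt and generation parameters.
    return {"inputs": prompt, "parameters": model_kwargs}
```

One design note: the chat branch assigns `parameter_payload = _model_kwargs` and appends in place, so repeated calls keep growing `model_kwargs["messages"]`; copying the dict first would avoid that, but the sketch mirrors the diff as written.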
@@ -12,7 +12,7 @@ from tests.integration_tests.llms.utils import assert_llm_equality
 def test_octoai_endpoint_text_generation() -> None:
     """Test valid call to OctoAI text generation model."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         octoai_api_token="<octoai_api_token>",
         model_kwargs={
             "max_new_tokens": 200,
@@ -32,7 +32,7 @@ def test_octoai_endpoint_text_generation() -> None:
 def test_octoai_endpoint_call_error() -> None:
     """Test valid call to OctoAI that errors."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         model_kwargs={"max_new_tokens": -1},
     )
     with pytest.raises(ValueError):
@@ -42,7 +42,7 @@ def test_octoai_endpoint_call_error() -> None:
 def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
     """Test saving/loading an OctoAIHub LLM."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         octoai_api_token="<octoai_api_token>",
         model_kwargs={
             "max_new_tokens": 200,
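The updated integration tests still exercise only the MPT text-generation endpoint. A hypothetical companion test for the new llama2 branch, modeled on `test_octoai_endpoint_text_generation` (the test name and final assertion are mine; the endpoint URL and `model_kwargs` come from the new docstring example):

```python
from langchain.llms.octoai_endpoint import OctoAIEndpoint


def test_octoai_endpoint_chat_generation() -> None:
    """Hypothetical test: valid call to an OctoAI llama2 chat endpoint."""
    llm = OctoAIEndpoint(
        endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
        octoai_api_token="<octoai_api_token>",
        model_kwargs={
            "model": "llama-2-7b-chat",
            "messages": [
                {
                    "role": "system",
                    "content": "Below is an instruction that describes a task. "
                    "Write a response that completes the request.",
                }
            ],
            "stream": False,
            "max_tokens": 256,
        },
    )
    output = llm("Say foo:")
    assert isinstance(output, str)
```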