Adds support for llama2 and fixes MPT-7b url (#11465)
- **Description:** This is an update to the OctoAI LLM provider that adds support for llama2 endpoints hosted on OctoAI and updates the MPT-7b URL to the current one. @baskaryan Thanks! --------- Co-authored-by: ML Wiz <bassemgeorgi@gmail.com>
parent 0bff399af1
commit 5451b724fc
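For context, a minimal usage sketch of the new llama2 support, assembled from the docstring example added in this commit; the endpoint URL and API token are placeholders from that example, and the final call assumes the standard LangChain LLM __call__ interface:

    from langchain.llms.octoai_endpoint import OctoAIEndpoint

    # Placeholder credentials; both values come from your OctoAI account.
    llm = OctoAIEndpoint(
        octoai_api_token="octoai-api-key",
        endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
        model_kwargs={
            "model": "llama-2-7b-chat",
            "messages": [
                {
                    "role": "system",
                    "content": "Below is an instruction that describes a task. "
                    "Write a response that completes the request.",
                }
            ],
            "stream": False,
            "max_tokens": 256,
        },
    )

    # The user prompt is appended to `messages` inside _call (see the diff below).
    print(llm("Who was Leonardo da Vinci?"))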
@@ -33,7 +33,7 @@
     "import os\n",
     "\n",
     "os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n",
-    "os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate\""
+    "os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate\""
    ]
   },
   {
@@ -23,7 +23,7 @@ class OctoAIEndpoint(LLM):
             from langchain.llms.octoai_endpoint import OctoAIEndpoint
             OctoAIEndpoint(
                 octoai_api_token="octoai-api-key",
-                endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+                endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
                 model_kwargs={
                     "max_new_tokens": 200,
                     "temperature": 0.75,
@@ -34,6 +34,24 @@ class OctoAIEndpoint(LLM):
                 },
             )
 
+            from langchain.llms.octoai_endpoint import OctoAIEndpoint
+            OctoAIEndpoint(
+                octoai_api_token="octoai-api-key",
+                endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
+                model_kwargs={
+                    "model": "llama-2-7b-chat",
+                    "messages": [
+                        {
+                            "role": "system",
+                            "content": "Below is an instruction that describes a task.
+    Write a response that completes the request."
+                        }
+                    ],
+                    "stream": False,
+                    "max_tokens": 256
+                }
+            )
+
     """
 
     endpoint_url: Optional[str] = None
@@ -45,6 +63,9 @@ class OctoAIEndpoint(LLM):
     octoai_api_token: Optional[str] = None
     """OCTOAI API Token"""
 
+    streaming: bool = False
+    """Whether to generate a stream of tokens asynchronously"""
+
     class Config:
         """Configuration for this pydantic object."""
 
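The hunk above adds the streaming flag as a plain pydantic field; the code path that consumes it is not part of this diff. A minimal sketch of opting in at construction time, assuming the field is honored downstream:

    from langchain.llms.octoai_endpoint import OctoAIEndpoint

    llm = OctoAIEndpoint(
        octoai_api_token="octoai-api-key",  # placeholder
        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
        streaming=True,  # new field added in this commit; default is False
    )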
@@ -96,18 +117,27 @@ class OctoAIEndpoint(LLM):
         """
         _model_kwargs = self.model_kwargs or {}
 
-        # Prepare the payload JSON
-        parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
-
         try:
             # Initialize the OctoAI client
             from octoai import client
 
             octoai_client = client.Client(token=self.octoai_api_token)
 
-            # Send the request using the OctoAI client
-            resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
-            text = resp_json["generated_text"]
+            if "model" in _model_kwargs and "llama-2" in _model_kwargs["model"]:
+                parameter_payload = _model_kwargs
+                parameter_payload["messages"].append(
+                    {"role": "user", "content": prompt}
+                )
+                # Send the request using the OctoAI client
+                output = octoai_client.infer(self.endpoint_url, parameter_payload)
+                text = output.get("choices")[0].get("message").get("content")
+            else:
+                # Prepare the payload JSON
+                parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
+
+                # Send the request using the OctoAI client
+                resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
+                text = resp_json["generated_text"]
 
         except Exception as e:
             # Handle any errors raised by the inference endpoint
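To make the branch above concrete, here is a sketch of the two payload shapes the updated _call sends, inferred from the diff; the field values are illustrative:

    # Chat-style payload for llama-2 endpoints (OpenAI-compatible schema);
    # _call appends the user's prompt to the configured message list.
    chat_payload = {
        "model": "llama-2-7b-chat",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},  # appended from `prompt`
        ],
        "stream": False,
        "max_tokens": 256,
    }
    # The reply text is read from output["choices"][0]["message"]["content"].

    # Text-generation payload for completion endpoints such as MPT-7b.
    text_payload = {
        "inputs": "Hello!",                     # the prompt
        "parameters": {"max_new_tokens": 200},  # the remaining model kwargs
    }
    # The reply text is read from resp_json["generated_text"].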
@@ -12,7 +12,7 @@ from tests.integration_tests.llms.utils import assert_llm_equality
 def test_octoai_endpoint_text_generation() -> None:
     """Test valid call to OctoAI text generation model."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         octoai_api_token="<octoai_api_token>",
         model_kwargs={
             "max_new_tokens": 200,
@@ -32,7 +32,7 @@ def test_octoai_endpoint_text_generation() -> None:
 def test_octoai_endpoint_call_error() -> None:
     """Test valid call to OctoAI that errors."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         model_kwargs={"max_new_tokens": -1},
     )
     with pytest.raises(ValueError):
@@ -42,7 +42,7 @@ def test_octoai_endpoint_call_error() -> None:
 def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
     """Test saving/loading an OctoAIHub LLM."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         octoai_api_token="<octoai_api_token>",
         model_kwargs={
             "max_new_tokens": 200,