community[minor]: Update OctoAI LLM, Embedding and documentation (#16710)

This PR includes updates for OctoAI integrations:
- The LLM class was updated to fix a bug that occurred with multiple
sequential calls.
- The Embedding class was updated to support the GTE-Large endpoint
recently released on OctoAI (a usage sketch follows below).
- The documentation Jupyter notebook was updated to reflect usage of the
new LLM SDK.
Thank you!
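As a usage sketch for the embedding change: the snippet below shows how the updated embedding class might be pointed at the new GTE-Large endpoint. The endpoint URL and constructor fields here are assumptions for illustration and are not taken from this diff.

# Hypothetical usage sketch; endpoint URL and field names are assumed, not from this PR.
from langchain_community.embeddings.octoai_embeddings import OctoAIEmbeddings

embeddings = OctoAIEmbeddings(
    octoai_api_token="octoai-api-key",
    endpoint_url="https://text.octoai.run/v1/embeddings",  # assumed GTE-Large endpoint
)
doc_vectors = embeddings.embed_documents(["OctoAI serves open-source models."])
query_vector = embeddings.embed_query("Which models does OctoAI serve?")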
Author: Bassem Yacoube
Date: 2024-01-29 13:57:17 -08:00
Committed by: GitHub
Parent commit: 6d6226d96d
Commit: 85e93e05ed
3 changed files with 75 additions and 46 deletions


@@ -24,23 +24,9 @@ class OctoAIEndpoint(LLM):
 from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
 OctoAIEndpoint(
     octoai_api_token="octoai-api-key",
-    endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
+    endpoint_url="https://text.octoai.run/v1/chat/completions",
     model_kwargs={
-        "max_new_tokens": 200,
-        "temperature": 0.75,
-        "top_p": 0.95,
-        "repetition_penalty": 1,
-        "seed": None,
-        "stop": [],
-    },
-)
-from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
-OctoAIEndpoint(
-    octoai_api_token="octoai-api-key",
-    endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
-    model_kwargs={
-        "model": "llama-2-7b-chat",
+        "model": "llama-2-13b-chat-fp16",
         "messages": [
             {
                 "role": "system",
@@ -49,7 +35,10 @@ class OctoAIEndpoint(LLM):
             }
         ],
         "stream": False,
-        "max_tokens": 256
+        "max_tokens": 256,
+        "presence_penalty": 0,
+        "temperature": 0.1,
+        "top_p": 0.9
     }
 )
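Reassembled from the two hunks above, the updated docstring example reads roughly as follows; the system message content is elided because it is not visible in this view.

from langchain_community.llms.octoai_endpoint import OctoAIEndpoint

OctoAIEndpoint(
    octoai_api_token="octoai-api-key",
    endpoint_url="https://text.octoai.run/v1/chat/completions",
    model_kwargs={
        "model": "llama-2-13b-chat-fp16",
        "messages": [
            {
                "role": "system",
                "content": "...",  # elided; not visible in this hunk
            }
        ],
        "stream": False,
        "max_tokens": 256,
        "presence_penalty": 0,
        "temperature": 0.1,
        "top_p": 0.9,
    },
)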
@@ -119,19 +108,45 @@ class OctoAIEndpoint(LLM):
 _model_kwargs = self.model_kwargs or {}
 try:
-    # Initialize the OctoAI client
     from octoai import client

+    # Initialize the OctoAI client
     octoai_client = client.Client(token=self.octoai_api_token)

     if "model" in _model_kwargs:
         parameter_payload = _model_kwargs
+        sys_msg = None
+        if "messages" in parameter_payload:
+            msgs = parameter_payload.get("messages", [])
+            for msg in msgs:
+                if msg.get("role") == "system":
+                    sys_msg = msg.get("content")
+            # Reset messages list
+            parameter_payload["messages"] = []
+        # Append system message if exists
+        if sys_msg:
+            parameter_payload["messages"].append(
+                {"role": "system", "content": sys_msg}
+            )
+        # Append user message
         parameter_payload["messages"].append(
             {"role": "user", "content": prompt}
         )
         # Send the request using the OctoAI client
-        output = octoai_client.infer(self.endpoint_url, parameter_payload)
-        text = output.get("choices")[0].get("message").get("content")
+        try:
+            output = octoai_client.infer(self.endpoint_url, parameter_payload)
+            if output and "choices" in output and len(output["choices"]) > 0:
+                text = output["choices"][0].get("message", {}).get("content")
+            else:
+                text = "Error: Invalid response format or empty choices."
+        except Exception as e:
+            text = f"Error during API call: {str(e)}"
     else:
         # Prepare the payload JSON
         parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
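In effect, the new branch pulls the system message (if any) out of model_kwargs["messages"], clears the list, and rebuilds it as the system message followed by the current prompt before calling infer. Below is a standalone sketch of that rebuild, using a hypothetical helper name that is not part of the diff.

# Hypothetical helper illustrating the message rebuild performed in the new code path.
def rebuild_messages(model_kwargs: dict, prompt: str) -> list:
    sys_msg = None
    for msg in model_kwargs.get("messages", []):
        if msg.get("role") == "system":
            sys_msg = msg.get("content")
    messages = []
    if sys_msg:
        messages.append({"role": "system", "content": sys_msg})
    messages.append({"role": "user", "content": prompt})
    return messages

# rebuild_messages({"messages": [{"role": "system", "content": "Be brief."}]}, "Hi")
# -> [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]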