mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 02:50:47 +00:00
community[minor]: Update OctoAI LLM, Embedding and documentation (#16710)
This PR includes updates for OctoAI integrations: - The LLM class was updated to fix a bug that occurs with multiple sequential calls - The Embedding class was updated to support the new GTE-Large endpoint released on OctoAI lately - The documentation jupyter notebook was updated to reflect using the new LLM sdk Thank you!
This commit is contained in:
@@ -24,23 +24,9 @@ class OctoAIEndpoint(LLM):
|
||||
from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
|
||||
OctoAIEndpoint(
|
||||
octoai_api_token="octoai-api-key",
|
||||
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
|
||||
endpoint_url="https://text.octoai.run/v1/chat/completions",
|
||||
model_kwargs={
|
||||
"max_new_tokens": 200,
|
||||
"temperature": 0.75,
|
||||
"top_p": 0.95,
|
||||
"repetition_penalty": 1,
|
||||
"seed": None,
|
||||
"stop": [],
|
||||
},
|
||||
)
|
||||
|
||||
from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
|
||||
OctoAIEndpoint(
|
||||
octoai_api_token="octoai-api-key",
|
||||
endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
|
||||
model_kwargs={
|
||||
"model": "llama-2-7b-chat",
|
||||
"model": "llama-2-13b-chat-fp16",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
@@ -49,7 +35,10 @@ class OctoAIEndpoint(LLM):
|
||||
}
|
||||
],
|
||||
"stream": False,
|
||||
"max_tokens": 256
|
||||
"max_tokens": 256,
|
||||
"presence_penalty": 0,
|
||||
"temperature": 0.1,
|
||||
"top_p": 0.9
|
||||
}
|
||||
)
|
||||
|
||||
@@ -119,19 +108,45 @@ class OctoAIEndpoint(LLM):
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
try:
|
||||
# Initialize the OctoAI client
|
||||
from octoai import client
|
||||
|
||||
# Initialize the OctoAI client
|
||||
octoai_client = client.Client(token=self.octoai_api_token)
|
||||
|
||||
if "model" in _model_kwargs:
|
||||
parameter_payload = _model_kwargs
|
||||
|
||||
sys_msg = None
|
||||
if "messages" in parameter_payload:
|
||||
msgs = parameter_payload.get("messages", [])
|
||||
for msg in msgs:
|
||||
if msg.get("role") == "system":
|
||||
sys_msg = msg.get("content")
|
||||
|
||||
# Reset messages list
|
||||
parameter_payload["messages"] = []
|
||||
|
||||
# Append system message if exists
|
||||
if sys_msg:
|
||||
parameter_payload["messages"].append(
|
||||
{"role": "system", "content": sys_msg}
|
||||
)
|
||||
|
||||
# Append user message
|
||||
parameter_payload["messages"].append(
|
||||
{"role": "user", "content": prompt}
|
||||
)
|
||||
|
||||
# Send the request using the OctoAI client
|
||||
output = octoai_client.infer(self.endpoint_url, parameter_payload)
|
||||
text = output.get("choices")[0].get("message").get("content")
|
||||
try:
|
||||
output = octoai_client.infer(self.endpoint_url, parameter_payload)
|
||||
if output and "choices" in output and len(output["choices"]) > 0:
|
||||
text = output["choices"][0].get("message", {}).get("content")
|
||||
else:
|
||||
text = "Error: Invalid response format or empty choices."
|
||||
except Exception as e:
|
||||
text = f"Error during API call: {str(e)}"
|
||||
|
||||
else:
|
||||
# Prepare the payload JSON
|
||||
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
|
||||
|
Reference in New Issue
Block a user