Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-15 17:33:53 +00:00
together[minor]: Update endpoint to non deprecated version (#19649)
- **Updating Together.ai Endpoint**: "langchain_together: Updated deprecated endpoint for partner package"
- Description: Together's inference API is deprecated, so it was replaced with the v1 completions endpoint and the corresponding changes were made.
- Twitter handle: @dev_yashmathur

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent 5ab6b39098
commit c42ec58578
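For context, a minimal usage sketch of the updated wrapper, assuming the langchain-together partner package is installed and TOGETHER_API_KEY is exported; the model name comes from the class docstring below, while the prompt and token budget are illustrative:

```python
# Minimal sketch: the v1 completions endpoint requires max_tokens,
# so set it explicitly rather than relying on the warned 200 default.
from langchain_together import Together

llm = Together(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    max_tokens=256,
)
print(llm.invoke("Write a haiku about deprecation."))
```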
@@ -1,5 +1,6 @@
 """Wrapper around Together AI's Completion API."""
 import logging
+import warnings
 from typing import Any, Dict, List, Optional

 import requests
@@ -34,13 +35,14 @@ class Together(LLM):
         model = Together(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
     """

-    base_url: str = "https://api.together.xyz/inference"
-    """Base inference API URL."""
+    base_url: str = "https://api.together.xyz/v1/completions"
+    """Base completions API URL."""
     together_api_key: SecretStr
     """Together AI API key. Get it here: https://api.together.xyz/settings/api-keys"""
     model: str
     """Model name. Available models listed here:
     https://docs.together.ai/docs/inference-models
     Base Models: https://docs.together.ai/docs/inference-models#language-models
     Chat Models: https://docs.together.ai/docs/inference-models#chat-models
     """
     temperature: Optional[float] = None
     """Model temperature."""
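To make the URL change concrete, here is a hedged sketch of a raw call against the new endpoint. Only the URL and the top-level choices[0].text response shape are taken from this diff; the payload keys (model, prompt, max_tokens) mirror the wrapper's fields but are otherwise an assumption:

```python
import os
import requests

# Hypothetical raw request against the new v1 completions URL.
response = requests.post(
    "https://api.together.xyz/v1/completions",
    headers={"Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}"},
    json={
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "prompt": "Say hello.",
        "max_tokens": 200,  # required by the v1 endpoint
    },
    timeout=30,
)
response.raise_for_status()
print(response.json()["choices"][0]["text"])  # new top-level response shape
```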
@@ -82,13 +84,28 @@ class Together(LLM):
             )
         return values

+    @root_validator()
+    def validate_max_tokens(cls, values: Dict) -> Dict:
+        """
+        The v1 completions endpoint has max_tokens as a required parameter.
+        Set a default value and warn if the parameter is missing.
+        """
+        if values.get("max_tokens") is None:
+            warnings.warn(
+                "The completions endpoint has 'max_tokens' as a required argument. "
+                "The default value is being set to 200. "
+                "Consider setting this value when initializing the LLM."
+            )
+            values["max_tokens"] = 200  # Default value
+        return values
+
     @property
     def _llm_type(self) -> str:
         """Return type of model."""
         return "together"

     def _format_output(self, output: dict) -> str:
-        return output["output"]["choices"][0]["text"]
+        return output["choices"][0]["text"]

     @staticmethod
     def get_user_agent() -> str:
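A short sketch of what the new validator does in practice, assuming TOGETHER_API_KEY is set so the environment validation passes; the max_tokens attribute and the 200 fallback come straight from the diff above:

```python
import warnings

from langchain_together import Together

# Constructing without max_tokens should trigger the new warning
# and fall back to the default of 200.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    llm = Together(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    print(llm.max_tokens)     # 200, the warned default
    print(caught[0].message)  # the UserWarning emitted by validate_max_tokens
```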
@@ -148,9 +165,6 @@ class Together(LLM):
             )

         data = response.json()
-        if data.get("status") != "finished":
-            err_msg = data.get("error", "Undefined Error")
-            raise Exception(err_msg)

         output = self._format_output(data)

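With the "status" field gone, the wrapper no longer polls for a "finished" marker; presumably failures from the v1 endpoint surface as HTTP error codes instead. A hedged sketch of that style of handling, not taken from this commit:

```python
import requests

def check_response(response: requests.Response) -> dict:
    # Assumption: the v1 completions endpoint signals errors via HTTP
    # status codes, so raise on non-2xx instead of checking data["status"].
    response.raise_for_status()
    return response.json()
```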
@@ -203,9 +217,5 @@ class Together(LLM):

             response_json = await response.json()

-            if response_json.get("status") != "finished":
-                err_msg = response_json.get("error", "Undefined Error")
-                raise Exception(err_msg)
-
         output = self._format_output(response_json)
         return output
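The removed status checks and the one-level-shallower _format_output both follow from the same response shape change; an illustrative comparison, with both shapes inferred from this diff rather than from API docs:

```python
# Legacy inference response (inferred): results nested under "output",
# with a "status" field the old code polled for "finished".
legacy = {"status": "finished", "output": {"choices": [{"text": "Hello!"}]}}

# v1 completions response (inferred): "choices" at the top level.
v1 = {"choices": [{"text": "Hello!"}]}

assert legacy["output"]["choices"][0]["text"] == v1["choices"][0]["text"]
```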