mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
Fixes to the Nebula LLM Integration (#8918)
This addresses some issues with introducing the Nebula LLM to LangChain in this PR: https://github.com/langchain-ai/langchain/pull/8876 This fixes the following: - Removes `SYMBLAI` from variable names - Fixes bug with `Bearer` for the API KEY Thanks again in advance for your help! cc: @hwchase17, @baskaryan --------- Co-authored-by: dvonthenen <david.vonthenen@gmail.com>
This commit is contained in:
parent
d1e305028f
commit
bf4a112aa6
@ -34,9 +34,9 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"SYMBLAI_NEBULA_SERVICE_URL\"] = SYMBLAI_NEBULA_SERVICE_URL\n",
|
||||
"os.environ[\"SYMBLAI_NEBULA_SERVICE_PATH\"] = SYMBLAI_NEBULA_SERVICE_PATH\n",
|
||||
"os.environ[\"SYMBLAI_NEBULA_SERVICE_TOKEN\"] = SYMBLAI_NEBULA_SERVICE_TOKEN"
|
||||
"os.environ[\"NEBULA_SERVICE_URL\"] = NEBULA_SERVICE_URL\n",
|
||||
"os.environ[\"NEBULA_SERVICE_PATH\"] = NEBULA_SERVICE_PATH\n",
|
||||
"os.environ[\"NEBULA_SERVICE_API_KEY\"] = NEBULA_SERVICE_API_KEY"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -9,8 +9,8 @@ from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
DEFAULT_SYMBLAI_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai"
|
||||
DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH = "/v1/model/generate"
|
||||
DEFAULT_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai"
|
||||
DEFAULT_NEBULA_SERVICE_PATH = "/v1/model/generate"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -18,8 +18,8 @@ logger = logging.getLogger(__name__)
|
||||
class Nebula(LLM):
|
||||
"""Nebula Service models.
|
||||
|
||||
To use, you should have the environment variable ``SYMBLAI_NEBULA_SERVICE_URL``,
|
||||
``SYMBLAI_NEBULA_SERVICE_PATH`` and ``SYMBLAI_NEBULA_SERVICE_TOKEN`` set with your Nebula
|
||||
To use, you should have the environment variable ``NEBULA_SERVICE_URL``,
|
||||
``NEBULA_SERVICE_PATH`` and ``NEBULA_SERVICE_API_KEY`` set with your Nebula
|
||||
Service, or pass it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
@ -30,21 +30,8 @@ class Nebula(LLM):
|
||||
nebula = Nebula(
|
||||
nebula_service_url="SERVICE_URL",
|
||||
nebula_service_path="SERVICE_ROUTE",
|
||||
nebula_service_token="SERVICE_TOKEN",
|
||||
nebula_api_key="SERVICE_TOKEN",
|
||||
)
|
||||
|
||||
# Use Ray for distributed processing
|
||||
import ray
|
||||
|
||||
prompt_list=[]
|
||||
|
||||
@ray.remote
|
||||
def send_query(llm, prompt):
|
||||
resp = llm(prompt)
|
||||
return resp
|
||||
|
||||
futures = [send_query.remote(nebula, prompt) for prompt in prompt_list]
|
||||
results = ray.get(futures)
|
||||
""" # noqa: E501
|
||||
|
||||
"""Key/value arguments to pass to the model. Reserved for future use"""
|
||||
@ -53,7 +40,7 @@ class Nebula(LLM):
|
||||
"""Optional"""
|
||||
nebula_service_url: Optional[str] = None
|
||||
nebula_service_path: Optional[str] = None
|
||||
nebula_service_token: Optional[str] = None
|
||||
nebula_api_key: Optional[str] = None
|
||||
conversation: str = ""
|
||||
return_scores: Optional[str] = "false"
|
||||
max_new_tokens: Optional[int] = 2048
|
||||
@ -69,20 +56,21 @@ class Nebula(LLM):
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
nebula_service_url = get_from_dict_or_env(
|
||||
values, "nebula_service_url", "SYMBLAI_NEBULA_SERVICE_URL"
|
||||
values,
|
||||
"nebula_service_url",
|
||||
"NEBULA_SERVICE_URL",
|
||||
DEFAULT_NEBULA_SERVICE_URL,
|
||||
)
|
||||
nebula_service_path = get_from_dict_or_env(
|
||||
values, "nebula_service_path", "SYMBLAI_NEBULA_SERVICE_PATH"
|
||||
values,
|
||||
"nebula_service_path",
|
||||
"NEBULA_SERVICE_PATH",
|
||||
DEFAULT_NEBULA_SERVICE_PATH,
|
||||
)
|
||||
nebula_service_token = get_from_dict_or_env(
|
||||
values, "nebula_service_token", "SYMBLAI_NEBULA_SERVICE_TOKEN"
|
||||
nebula_api_key = get_from_dict_or_env(
|
||||
values, "nebula_api_key", "NEBULA_SERVICE_API_KEY", ""
|
||||
)
|
||||
|
||||
if len(nebula_service_url) == 0:
|
||||
nebula_service_url = DEFAULT_SYMBLAI_NEBULA_SERVICE_URL
|
||||
if len(nebula_service_path) == 0:
|
||||
nebula_service_path = DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH
|
||||
|
||||
if nebula_service_url.endswith("/"):
|
||||
nebula_service_url = nebula_service_url[:-1]
|
||||
if not nebula_service_path.startswith("/"):
|
||||
@ -94,7 +82,7 @@ class Nebula(LLM):
|
||||
nebula_service_endpoint = f"{nebula_service_url}{nebula_service_path}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"ApiKey": f"Bearer {nebula_service_token}",
|
||||
"ApiKey": "{nebula_api_key}",
|
||||
}
|
||||
requests.get(nebula_service_endpoint, headers=headers)
|
||||
except requests.exceptions.RequestException as e:
|
||||
@ -103,7 +91,7 @@ class Nebula(LLM):
|
||||
|
||||
values["nebula_service_url"] = nebula_service_url
|
||||
values["nebula_service_path"] = nebula_service_path
|
||||
values["nebula_service_token"] = nebula_service_token
|
||||
values["nebula_api_key"] = nebula_api_key
|
||||
|
||||
return values
|
||||
|
||||
@ -147,7 +135,7 @@ class Nebula(LLM):
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"ApiKey": f"Bearer {self.nebula_service_token}",
|
||||
"ApiKey": f"{self.nebula_api_key}",
|
||||
}
|
||||
|
||||
body = {
|
||||
|
Loading…
Reference in New Issue
Block a user