diff --git a/docs/extras/integrations/llms/symblai_nebula.ipynb b/docs/extras/integrations/llms/symblai_nebula.ipynb index 587dd4b7d93..1ca58697e0e 100644 --- a/docs/extras/integrations/llms/symblai_nebula.ipynb +++ b/docs/extras/integrations/llms/symblai_nebula.ipynb @@ -34,9 +34,9 @@ "source": [ "import os\n", "\n", - "os.environ[\"SYMBLAI_NEBULA_SERVICE_URL\"] = SYMBLAI_NEBULA_SERVICE_URL\n", - "os.environ[\"SYMBLAI_NEBULA_SERVICE_PATH\"] = SYMBLAI_NEBULA_SERVICE_PATH\n", - "os.environ[\"SYMBLAI_NEBULA_SERVICE_TOKEN\"] = SYMBLAI_NEBULA_SERVICE_TOKEN" + "os.environ[\"NEBULA_SERVICE_URL\"] = NEBULA_SERVICE_URL\n", + "os.environ[\"NEBULA_SERVICE_PATH\"] = NEBULA_SERVICE_PATH\n", + "os.environ[\"NEBULA_SERVICE_API_KEY\"] = NEBULA_SERVICE_API_KEY" ] }, { diff --git a/libs/langchain/langchain/llms/symblai_nebula.py b/libs/langchain/langchain/llms/symblai_nebula.py index aedeec31790..cec64b55168 100644 --- a/libs/langchain/langchain/llms/symblai_nebula.py +++ b/libs/langchain/langchain/llms/symblai_nebula.py @@ -9,8 +9,8 @@ from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env -DEFAULT_SYMBLAI_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai" -DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH = "/v1/model/generate" +DEFAULT_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai" +DEFAULT_NEBULA_SERVICE_PATH = "/v1/model/generate" logger = logging.getLogger(__name__) @@ -18,8 +18,8 @@ logger = logging.getLogger(__name__) class Nebula(LLM): """Nebula Service models. - To use, you should have the environment variable ``SYMBLAI_NEBULA_SERVICE_URL``, - ``SYMBLAI_NEBULA_SERVICE_PATH`` and ``SYMBLAI_NEBULA_SERVICE_TOKEN`` set with your Nebula + To use, you should have the environment variable ``NEBULA_SERVICE_URL``, + ``NEBULA_SERVICE_PATH`` and ``NEBULA_SERVICE_API_KEY`` set with your Nebula Service, or pass it as a named parameter to the constructor. Example: @@ -30,21 +30,8 @@ class Nebula(LLM): nebula = Nebula( nebula_service_url="SERVICE_URL", nebula_service_path="SERVICE_ROUTE", - nebula_service_token="SERVICE_TOKEN", + nebula_api_key="SERVICE_TOKEN", ) - - # Use Ray for distributed processing - import ray - - prompt_list=[] - - @ray.remote - def send_query(llm, prompt): - resp = llm(prompt) - return resp - - futures = [send_query.remote(nebula, prompt) for prompt in prompt_list] - results = ray.get(futures) """ # noqa: E501 """Key/value arguments to pass to the model. Reserved for future use""" @@ -53,7 +40,7 @@ class Nebula(LLM): """Optional""" nebula_service_url: Optional[str] = None nebula_service_path: Optional[str] = None - nebula_service_token: Optional[str] = None + nebula_api_key: Optional[str] = None conversation: str = "" return_scores: Optional[str] = "false" max_new_tokens: Optional[int] = 2048 @@ -69,20 +56,21 @@ class Nebula(LLM): def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" nebula_service_url = get_from_dict_or_env( - values, "nebula_service_url", "SYMBLAI_NEBULA_SERVICE_URL" + values, + "nebula_service_url", + "NEBULA_SERVICE_URL", + DEFAULT_NEBULA_SERVICE_URL, ) nebula_service_path = get_from_dict_or_env( - values, "nebula_service_path", "SYMBLAI_NEBULA_SERVICE_PATH" + values, + "nebula_service_path", + "NEBULA_SERVICE_PATH", + DEFAULT_NEBULA_SERVICE_PATH, ) - nebula_service_token = get_from_dict_or_env( - values, "nebula_service_token", "SYMBLAI_NEBULA_SERVICE_TOKEN" + nebula_api_key = get_from_dict_or_env( + values, "nebula_api_key", "NEBULA_SERVICE_API_KEY", "" ) - if len(nebula_service_url) == 0: - nebula_service_url = DEFAULT_SYMBLAI_NEBULA_SERVICE_URL - if len(nebula_service_path) == 0: - nebula_service_path = DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH - if nebula_service_url.endswith("/"): nebula_service_url = nebula_service_url[:-1] if not nebula_service_path.startswith("/"): @@ -94,7 +82,7 @@ class Nebula(LLM): nebula_service_endpoint = f"{nebula_service_url}{nebula_service_path}" headers = { "Content-Type": "application/json", - "ApiKey": f"Bearer {nebula_service_token}", + "ApiKey": "{nebula_api_key}", } requests.get(nebula_service_endpoint, headers=headers) except requests.exceptions.RequestException as e: @@ -103,7 +91,7 @@ class Nebula(LLM): values["nebula_service_url"] = nebula_service_url values["nebula_service_path"] = nebula_service_path - values["nebula_service_token"] = nebula_service_token + values["nebula_api_key"] = nebula_api_key return values @@ -147,7 +135,7 @@ class Nebula(LLM): headers = { "Content-Type": "application/json", - "ApiKey": f"Bearer {self.nebula_service_token}", + "ApiKey": f"{self.nebula_api_key}", } body = {