mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
Fixes to the Nebula LLM Integration (#8918)
This addresses some issues with introducing the Nebula LLM to LangChain in this PR: https://github.com/langchain-ai/langchain/pull/8876 This fixes the following: - Removes `SYMBLAI` from variable names - Fixes bug with `Bearer` for the API KEY Thanks again in advance for your help! cc: @hwchase17, @baskaryan --------- Co-authored-by: dvonthenen <david.vonthenen@gmail.com>
This commit is contained in:
parent
d1e305028f
commit
bf4a112aa6
@ -34,9 +34,9 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"\n",
|
"\n",
|
||||||
"os.environ[\"SYMBLAI_NEBULA_SERVICE_URL\"] = SYMBLAI_NEBULA_SERVICE_URL\n",
|
"os.environ[\"NEBULA_SERVICE_URL\"] = NEBULA_SERVICE_URL\n",
|
||||||
"os.environ[\"SYMBLAI_NEBULA_SERVICE_PATH\"] = SYMBLAI_NEBULA_SERVICE_PATH\n",
|
"os.environ[\"NEBULA_SERVICE_PATH\"] = NEBULA_SERVICE_PATH\n",
|
||||||
"os.environ[\"SYMBLAI_NEBULA_SERVICE_TOKEN\"] = SYMBLAI_NEBULA_SERVICE_TOKEN"
|
"os.environ[\"NEBULA_SERVICE_API_KEY\"] = NEBULA_SERVICE_API_KEY"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -9,8 +9,8 @@ from langchain.llms.base import LLM
|
|||||||
from langchain.llms.utils import enforce_stop_tokens
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
DEFAULT_SYMBLAI_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai"
|
DEFAULT_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai"
|
||||||
DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH = "/v1/model/generate"
|
DEFAULT_NEBULA_SERVICE_PATH = "/v1/model/generate"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -18,8 +18,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class Nebula(LLM):
|
class Nebula(LLM):
|
||||||
"""Nebula Service models.
|
"""Nebula Service models.
|
||||||
|
|
||||||
To use, you should have the environment variable ``SYMBLAI_NEBULA_SERVICE_URL``,
|
To use, you should have the environment variable ``NEBULA_SERVICE_URL``,
|
||||||
``SYMBLAI_NEBULA_SERVICE_PATH`` and ``SYMBLAI_NEBULA_SERVICE_TOKEN`` set with your Nebula
|
``NEBULA_SERVICE_PATH`` and ``NEBULA_SERVICE_API_KEY`` set with your Nebula
|
||||||
Service, or pass it as a named parameter to the constructor.
|
Service, or pass it as a named parameter to the constructor.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -30,21 +30,8 @@ class Nebula(LLM):
|
|||||||
nebula = Nebula(
|
nebula = Nebula(
|
||||||
nebula_service_url="SERVICE_URL",
|
nebula_service_url="SERVICE_URL",
|
||||||
nebula_service_path="SERVICE_ROUTE",
|
nebula_service_path="SERVICE_ROUTE",
|
||||||
nebula_service_token="SERVICE_TOKEN",
|
nebula_api_key="SERVICE_TOKEN",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use Ray for distributed processing
|
|
||||||
import ray
|
|
||||||
|
|
||||||
prompt_list=[]
|
|
||||||
|
|
||||||
@ray.remote
|
|
||||||
def send_query(llm, prompt):
|
|
||||||
resp = llm(prompt)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
futures = [send_query.remote(nebula, prompt) for prompt in prompt_list]
|
|
||||||
results = ray.get(futures)
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
"""Key/value arguments to pass to the model. Reserved for future use"""
|
"""Key/value arguments to pass to the model. Reserved for future use"""
|
||||||
@ -53,7 +40,7 @@ class Nebula(LLM):
|
|||||||
"""Optional"""
|
"""Optional"""
|
||||||
nebula_service_url: Optional[str] = None
|
nebula_service_url: Optional[str] = None
|
||||||
nebula_service_path: Optional[str] = None
|
nebula_service_path: Optional[str] = None
|
||||||
nebula_service_token: Optional[str] = None
|
nebula_api_key: Optional[str] = None
|
||||||
conversation: str = ""
|
conversation: str = ""
|
||||||
return_scores: Optional[str] = "false"
|
return_scores: Optional[str] = "false"
|
||||||
max_new_tokens: Optional[int] = 2048
|
max_new_tokens: Optional[int] = 2048
|
||||||
@ -69,20 +56,21 @@ class Nebula(LLM):
|
|||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
nebula_service_url = get_from_dict_or_env(
|
nebula_service_url = get_from_dict_or_env(
|
||||||
values, "nebula_service_url", "SYMBLAI_NEBULA_SERVICE_URL"
|
values,
|
||||||
|
"nebula_service_url",
|
||||||
|
"NEBULA_SERVICE_URL",
|
||||||
|
DEFAULT_NEBULA_SERVICE_URL,
|
||||||
)
|
)
|
||||||
nebula_service_path = get_from_dict_or_env(
|
nebula_service_path = get_from_dict_or_env(
|
||||||
values, "nebula_service_path", "SYMBLAI_NEBULA_SERVICE_PATH"
|
values,
|
||||||
|
"nebula_service_path",
|
||||||
|
"NEBULA_SERVICE_PATH",
|
||||||
|
DEFAULT_NEBULA_SERVICE_PATH,
|
||||||
)
|
)
|
||||||
nebula_service_token = get_from_dict_or_env(
|
nebula_api_key = get_from_dict_or_env(
|
||||||
values, "nebula_service_token", "SYMBLAI_NEBULA_SERVICE_TOKEN"
|
values, "nebula_api_key", "NEBULA_SERVICE_API_KEY", ""
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(nebula_service_url) == 0:
|
|
||||||
nebula_service_url = DEFAULT_SYMBLAI_NEBULA_SERVICE_URL
|
|
||||||
if len(nebula_service_path) == 0:
|
|
||||||
nebula_service_path = DEFAULT_SYMBLAI_NEBULA_SERVICE_PATH
|
|
||||||
|
|
||||||
if nebula_service_url.endswith("/"):
|
if nebula_service_url.endswith("/"):
|
||||||
nebula_service_url = nebula_service_url[:-1]
|
nebula_service_url = nebula_service_url[:-1]
|
||||||
if not nebula_service_path.startswith("/"):
|
if not nebula_service_path.startswith("/"):
|
||||||
@ -94,7 +82,7 @@ class Nebula(LLM):
|
|||||||
nebula_service_endpoint = f"{nebula_service_url}{nebula_service_path}"
|
nebula_service_endpoint = f"{nebula_service_url}{nebula_service_path}"
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"ApiKey": f"Bearer {nebula_service_token}",
|
"ApiKey": "{nebula_api_key}",
|
||||||
}
|
}
|
||||||
requests.get(nebula_service_endpoint, headers=headers)
|
requests.get(nebula_service_endpoint, headers=headers)
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
@ -103,7 +91,7 @@ class Nebula(LLM):
|
|||||||
|
|
||||||
values["nebula_service_url"] = nebula_service_url
|
values["nebula_service_url"] = nebula_service_url
|
||||||
values["nebula_service_path"] = nebula_service_path
|
values["nebula_service_path"] = nebula_service_path
|
||||||
values["nebula_service_token"] = nebula_service_token
|
values["nebula_api_key"] = nebula_api_key
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@ -147,7 +135,7 @@ class Nebula(LLM):
|
|||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"ApiKey": f"Bearer {self.nebula_service_token}",
|
"ApiKey": f"{self.nebula_api_key}",
|
||||||
}
|
}
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
|
Loading…
Reference in New Issue
Block a user