Improvements in Nebula LLM (#9226)
- Description: Adds improvements to the Nebula LLM: failed requests are now retried automatically, and more generation parameters are supported. A conversation no longer has to be passed in the LLM object; it is supplied through the prompt instead. Examples are updated.
- Issue: N/A
- Dependencies: N/A
- Tag maintainer: @baskaryan
- Twitter handle: symbldotai

Co-authored-by: toshishjawale <toshish@symbl.ai>
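With this change, the instruction and the conversation both travel in the prompt itself: _call treats the first line as the instruction and everything after the first newline as the conversation. A minimal usage sketch based on the updated examples (the API key value is a placeholder; the service URL and path fall back to the built-in defaults or to the NEBULA_SERVICE_URL and NEBULA_SERVICE_PATH environment variables):

    from langchain.llms import Nebula

    # The key can also come from the NEBULA_API_KEY environment variable.
    llm = Nebula(nebula_api_key="<your_api_key>", max_retries=10)

    instruction = "Identify the main objectives mentioned in this conversation."
    conversation = "Alex: Yesterday I wrapped up the dashboard UI.\nRhea: I finished the API docs."

    # First line = instruction, remainder = conversation.
    print(llm(f"{instruction}\n{conversation}"))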
libs/langchain/langchain/llms/symblai_nebula.py

@@ -1,8 +1,17 @@
+import json
 import logging
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional

 import requests
-from langchain.pydantic_v1 import Extra, root_validator
+from pydantic import Extra, root_validator
+from requests import ConnectTimeout, ReadTimeout, RequestException
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
@@ -19,7 +28,7 @@ class Nebula(LLM):
     """Nebula Service models.

     To use, you should have the environment variable ``NEBULA_SERVICE_URL``,
-    ``NEBULA_SERVICE_PATH`` and ``NEBULA_SERVICE_API_KEY`` set with your Nebula
+    ``NEBULA_SERVICE_PATH`` and ``NEBULA_API_KEY`` set with your Nebula
     Service, or pass it as a named parameter to the constructor.

     Example:
@@ -28,9 +37,9 @@ class Nebula(LLM):
             from langchain.llms import Nebula

             nebula = Nebula(
-                nebula_service_url="SERVICE_URL",
-                nebula_service_path="SERVICE_ROUTE",
-                nebula_api_key="SERVICE_TOKEN",
+                nebula_service_url="NEBULA_SERVICE_URL",
+                nebula_service_path="NEBULA_SERVICE_PATH",
+                nebula_api_key="NEBULA_API_KEY",
             )
     """  # noqa: E501

@@ -38,14 +47,19 @@ class Nebula(LLM):
     model_kwargs: Optional[dict] = None

     """Optional"""
     nebula_service_url: Optional[str] = None
     nebula_service_path: Optional[str] = None
     nebula_api_key: Optional[str] = None
-    conversation: str = ""
-    return_scores: Optional[str] = "false"
-    max_new_tokens: Optional[int] = 2048
-    top_k: Optional[float] = 2
-    penalty_alpha: Optional[float] = 0.1
+    model: Optional[str] = None
+    max_new_tokens: Optional[int] = 128
+    temperature: Optional[float] = 0.6
+    top_p: Optional[float] = 0.95
+    repetition_penalty: Optional[float] = 1.0
+    top_k: Optional[int] = 0
+    penalty_alpha: Optional[float] = 0.0
+    stop_sequences: Optional[List[str]] = None
+    max_retries: Optional[int] = 10

     class Config:
         """Configuration for this pydantic object."""
@@ -68,7 +82,7 @@ class Nebula(LLM):
             DEFAULT_NEBULA_SERVICE_PATH,
         )
         nebula_api_key = get_from_dict_or_env(
-            values, "nebula_api_key", "NEBULA_SERVICE_API_KEY", ""
+            values, "nebula_api_key", "NEBULA_API_KEY", None
         )

         if nebula_service_url.endswith("/"):
@@ -76,25 +90,24 @@ class Nebula(LLM):
         if not nebula_service_path.startswith("/"):
             nebula_service_path = "/" + nebula_service_path

-        """ TODO: Future login"""
-        """
-        try:
-            nebula_service_endpoint = f"{nebula_service_url}{nebula_service_path}"
-            headers = {
-                "Content-Type": "application/json",
-                "ApiKey": "{nebula_api_key}",
-            }
-            requests.get(nebula_service_endpoint, headers=headers)
-        except requests.exceptions.RequestException as e:
-            raise ValueError(e)
-        """
-
         values["nebula_service_url"] = nebula_service_url
         values["nebula_service_path"] = nebula_service_path
         values["nebula_api_key"] = nebula_api_key

         return values

     @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling Nebula API."""
+        return {
+            "max_new_tokens": self.max_new_tokens,
+            "temperature": self.temperature,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "repetition_penalty": self.repetition_penalty,
+            "penalty_alpha": self.penalty_alpha,
+        }
+
+    @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
@@ -103,7 +116,6 @@ class Nebula(LLM):
             "nebula_service_url": self.nebula_service_url,
             "nebula_service_path": self.nebula_service_path,
             **{"model_kwargs": _model_kwargs},
-            "conversation": self.conversation,
         }

     @property
@@ -111,6 +123,25 @@ class Nebula(LLM):
         """Return type of llm."""
         return "nebula"

+    def _invocation_params(
+        self, stop_sequences: Optional[List[str]], **kwargs: Any
+    ) -> dict:
+        params = self._default_params
+        if self.stop_sequences is not None and stop_sequences is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop_sequences is not None:
+            params["stop_sequences"] = self.stop_sequences
+        else:
+            params["stop_sequences"] = stop_sequences
+        return {**params, **kwargs}
+
+    @staticmethod
+    def _process_response(response: Any, stop: Optional[List[str]]) -> str:
+        text = response["output"]["text"]
+        if stop:
+            text = enforce_stop_tokens(text, stop)
+        return text
+
     def _call(
         self,
         prompt: str,
@@ -128,57 +159,84 @@ class Nebula(LLM):
             .. code-block:: python
                 response = nebula("Tell me a joke.")
         """
-        _model_kwargs = self.model_kwargs or {}
-
-        nebula_service_endpoint = f"{self.nebula_service_url}{self.nebula_service_path}"
-
-        headers = {
-            "Content-Type": "application/json",
-            "ApiKey": f"{self.nebula_api_key}",
-        }
-
-        body = {
-            "prompt": {
-                "instruction": prompt,
-                "conversation": {"text": f"{self.conversation}"},
-            },
-            "return_scores": self.return_scores,
-            "max_new_tokens": self.max_new_tokens,
-            "top_k": self.top_k,
-            "penalty_alpha": self.penalty_alpha,
-        }
-
-        if len(self.conversation) == 0:
-            raise ValueError("Error conversation is empty.")
-
-        logger.debug(f"NEBULA _model_kwargs: {_model_kwargs}")
-        logger.debug(f"NEBULA body: {body}")
-        logger.debug(f"NEBULA kwargs: {kwargs}")
-        logger.debug(f"NEBULA conversation: {self.conversation}")
-
-        # call API
-        try:
-            response = requests.post(
-                nebula_service_endpoint, headers=headers, json=body
-            )
-        except requests.exceptions.RequestException as e:
-            raise ValueError(f"Error raised by inference endpoint: {e}")
-
-        logger.debug(f"NEBULA response: {response}")
-
-        if response.status_code != 200:
-            raise ValueError(
-                f"Error returned by service, status code {response.status_code}"
-            )
-
-        """ get the result """
-        text = response.text
-
-        """ enforce stop """
-        if stop is not None:
-            # This is required since the stop tokens
-            # are not enforced by the model parameters
-            text = enforce_stop_tokens(text, stop)
-
-        return text
+        params = self._invocation_params(stop, **kwargs)
+        prompt = prompt.strip()
+        if "\n" in prompt:
+            instruction = prompt.split("\n")[0]
+            conversation = "\n".join(prompt.split("\n")[1:])
+        else:
+            raise ValueError("Prompt must contain instruction and conversation.")
+
+        response = completion_with_retry(
+            self,
+            instruction=instruction,
+            conversation=conversation,
+            params=params,
+            url=f"{self.nebula_service_url}{self.nebula_service_path}",
+        )
+        _stop = params.get("stop_sequences")
+        return self._process_response(response, _stop)
+
+
+def make_request(
+    self: Nebula,
+    instruction: str,
+    conversation: str,
+    url: str = f"{DEFAULT_NEBULA_SERVICE_URL}{DEFAULT_NEBULA_SERVICE_PATH}",
+    params: Dict = {},
+) -> Any:
+    """Generate text from the model."""
+    headers = {
+        "Content-Type": "application/json",
+        "ApiKey": f"{self.nebula_api_key}",
+    }
+
+    body = {
+        "prompt": {
+            "instruction": instruction,
+            "conversation": {"text": f"{conversation}"},
+        }
+    }
+
+    # add params to body
+    for key, value in params.items():
+        body[key] = value
+
+    # make request
+    response = requests.post(url, headers=headers, json=body)
+
+    if response.status_code != 200:
+        raise Exception(
+            f"Request failed with status code {response.status_code}"
+            f" and message {response.text}"
+        )
+
+    return json.loads(response.text)
+
+
+def _create_retry_decorator(llm: Nebula) -> Callable[[Any], Any]:
+    min_seconds = 4
+    max_seconds = 10
+    # Wait 2^x * 1 second between each retry starting with
+    # 4 seconds, then up to 10 seconds, then 10 seconds afterward
+    max_retries = llm.max_retries if llm.max_retries is not None else 3
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type((RequestException, ConnectTimeout, ReadTimeout))
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def completion_with_retry(llm: Nebula, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm)

+    @retry_decorator
+    def _completion_with_retry(**_kwargs: Any) -> Any:
+        return make_request(llm, **_kwargs)
+
+    return _completion_with_retry(**kwargs)
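A note on the retry layer above: the tenacity decorator is built per call inside completion_with_retry, so each Nebula instance's max_retries setting is honored. A minimal standalone sketch of the same backoff policy (call_with_retry and flaky_fetch are invented names for illustration, and ConnectionError stands in for the requests exception types):

    import logging

    from tenacity import (
        before_sleep_log,
        retry,
        retry_if_exception_type,
        stop_after_attempt,
        wait_exponential,
    )

    logger = logging.getLogger(__name__)

    def call_with_retry(max_retries, func, **kwargs):
        # Same policy as _create_retry_decorator: exponential backoff
        # clamped to 4-10 seconds, retrying only transient error types.
        decorator = retry(
            reraise=True,
            stop=stop_after_attempt(max_retries),
            wait=wait_exponential(multiplier=1, min=4, max=10),
            retry=retry_if_exception_type(ConnectionError),  # stand-in exception
            before_sleep=before_sleep_log(logger, logging.WARNING),
        )
        return decorator(func)(**kwargs)

    attempts = 0

    def flaky_fetch():
        # Invented flaky call: fails twice, then succeeds.
        global attempts
        attempts += 1
        if attempts < 3:
            raise ConnectionError("transient failure")
        return "ok"

    print(call_with_retry(10, flaky_fetch))  # two logged retries, then "ok"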
libs/langchain/tests/integration_tests/llms/test_symblai_nebula.py

@@ -1,56 +1,46 @@
 """Test Nebula API wrapper."""
 from langchain import LLMChain, PromptTemplate
 from langchain.llms.symblai_nebula import Nebula


 def test_symblai_nebula_call() -> None:
     """Test valid call to Nebula."""
-    conversation = """Speaker 1: Thank you for calling ABC, company.Speaker 1: My name
-is Mary.Speaker 1: How may I help you?Speaker 2: Today?Speaker 1: All right,
-Madam.Speaker 1: I really apologize for this inconvenient.Speaker 1: I will be happy
-to assist you in this matter.Speaker 1: Could you please offer me Yuri your account
-number?Speaker 1: Alright Madam, thank you very much.Speaker 1: Let me check that
-for confirmation.Speaker 1: Did you say 534 00 365?Speaker 2: 48?Speaker 1: Very good
-man.Speaker 1: Now for verification purposes, can I please get your full?Speaker
-2: Name?Speaker 1: Alright, thank you.Speaker 1: Very much.Speaker 1: Madam.Speaker
-1: Can I, please get your birthdate now?Speaker 1: I am sorry madam.Speaker 1: I
-didn't make this clear is for verification.Speaker 1: Purposes is the company
-request.Speaker 1: The system requires me, your name, your complete name and your
-date of.Speaker 2: Birth.Speaker 2: Alright, thank you very much madam.Speaker 1:
-All right.Speaker 1: Thank you very much, Madam.Speaker 1: Thank you for that
-information.Speaker 1: Let me check what happens.Speaker 2: Here.Speaker 1: So
-according to our data space them, you did pay your last bill last August the 12,
-which was two days ago in one of our Affiliated payment centers.Speaker 1: So, at the
-moment you currently, We have zero balance.Speaker 1: So however, the bill that you
-received was generated a week before you made the pavement, this is reason why you
-already make this payment, have not been reflected yet.Speaker 1: So what we do in
-this case, you just simply disregard the amount indicated in the field and you
-continue to enjoy our service man.Speaker 1: Sure, Madam.Speaker 1: And I am sure
-you need your cell phone for everything for life, right?Speaker 1: So I really
-apologize for this inconvenience.Speaker 1: And let me tell you that delays in the
-bill is usually caused by delays in our Courier Service.Speaker 1: That is to say
-that it'''s a problem, not with the company, but with a courier service, For a more
-updated, feel of your account, you can visit our website and log into your account,
-and they'''re in the system.Speaker 1: On the website, you are going to have the
-possibility to pay the bill.Speaker 1: That is more.Speaker 2: Updated.Speaker 2:
-Of course, Madam I can definitely assist you with that.Speaker 2: Once you have,
-you want to see your bill updated, please go to www.hsn BC campus, any.com after
-that.Speaker 2: You will see in the tale.Speaker 1: All right corner.Speaker 1: So
-you're going to see a pay now button.Speaker 1: Please click on the pay now button
-and the serve.Speaker 1: The system is going to ask you for personal
-information.Speaker 1: Such as your first name, your ID account, your the number of
-your account, your email address, and your phone number once you complete this personal
-information."""
-    llm = Nebula(
-        conversation=conversation,
-    )
+    conversation = """Sam: Good morning, team! Let's keep this standup concise.
+We'll go in the usual order: what you did yesterday,
+what you plan to do today, and any blockers. Alex, kick us off.
+Alex: Morning! Yesterday, I wrapped up the UI for the user dashboard.
+The new charts and widgets are now responsive.
+I also had a sync with the design team to ensure the final touchups are in
+line with the brand guidelines. Today, I'll start integrating the frontend with
+the new API endpoints Rhea was working on.
+The only blocker is waiting for some final API documentation,
+but I guess Rhea can update on that.
+Rhea: Hey, all! Yep, about the API documentation - I completed the majority of
+the backend work for user data retrieval yesterday.
+The endpoints are mostly set up, but I need to do a bit more testing today.
+I'll finalize the API documentation by noon, so that should unblock Alex.
+After that, I’ll be working on optimizing the database queries
+for faster data fetching. No other blockers on my end.
+Sam: Great, thanks Rhea. Do reach out if you need any testing assistance
+or if there are any hitches with the database.
+Now, my update: Yesterday, I coordinated with the client to get clarity
+on some feature requirements. Today, I'll be updating our project roadmap
+and timelines based on their feedback. Additionally, I'll be sitting with
+the QA team in the afternoon for preliminary testing.
+Blocker: I might need both of you to be available for a quick call
+in case the client wants to discuss the changes live.
+Alex: Sounds good, Sam. Just let us know a little in advance for the call.
+Rhea: Agreed. We can make time for that.
+Sam: Perfect! Let's keep the momentum going. Reach out if there are any
+sudden issues or support needed. Have a productive day!
+Alex: You too.
+Rhea: Thanks, bye!"""
+    llm = Nebula(nebula_api_key="<your_api_key>")

-    template = """Identify the {count} main objectives or goals mentioned in this
-context concisely in less points. Emphasize on key intents."""
-    prompt = PromptTemplate.from_template(template)
+    instruction = """Identify the main objectives mentioned in this
+conversation."""
+    prompt = PromptTemplate.from_template("{instruction}\n{conversation}")

     llm_chain = LLMChain(prompt=prompt, llm=llm)
-    output = llm_chain.run(count="five")
+    output = llm_chain.run(instruction=instruction, conversation=conversation)
     assert isinstance(output, str)
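For reference, the chain in the updated test renders exactly the single-string prompt that _call expects: the template "{instruction}\n{conversation}" puts the instruction on the first line and the conversation below it. A rough sketch of the rendering step (toy strings, no API call):

    from langchain import PromptTemplate

    prompt = PromptTemplate.from_template("{instruction}\n{conversation}")
    print(prompt.format(instruction="Summarize.", conversation="A: hi\nB: hello"))
    # Summarize.
    # A: hi
    # B: hello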