Update ArgillaCallbackHandler as of latest argilla release (#9043)

Hi @agola11, or whoever is reviewing this PR 😄 

## What's in this PR?

As of the latest Argilla release, several things have been changed and
refactored to make some workflows easier. One of those changes, among
others, is how everything is pushed to Argilla: there is no longer any
need to call `push_to_argilla` on a `FeedbackDataset` once
`push_to_argilla` has been called on it for the first time, or once it
has been retrieved with `from_argilla`.

We also add some class variables (URLs and default values) so that they
are easy to update in one place if they change in the future, and so that
the `warnings.warn` messages are lighter to read in the code.

P.S. Regarding the Twitter/X mention, feel free to mention either
https://twitter.com/argilla_io or https://twitter.com/alvarobartt, or
both if applicable; otherwise, just use the first Twitter/X handle.
This commit is contained in:
Alvaro Bartolome 2023-08-10 19:59:46 +02:00 committed by GitHub
parent 8d351bfc20
commit 08a0741d82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,8 @@ import os
import warnings import warnings
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from packaging.version import parse
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -51,6 +53,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
"Argilla, no doubt about it." "Argilla, no doubt about it."
""" """
REPO_URL = "https://github.com/argilla-io/argilla"
ISSUES_URL = f"{REPO_URL}/issues"
BLOG_URL = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501
DEFAULT_API_URL = "http://localhost:6900"
DEFAULT_API_KEY = "argilla.apikey"
def __init__( def __init__(
self, self,
dataset_name: str, dataset_name: str,
@ -58,23 +67,22 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
api_url: Optional[str] = None, api_url: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
) -> None: ) -> None:
"""Initializes the `ArgillaCallbackHandler`. f"""Initializes the `ArgillaCallbackHandler`.
Args: Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset` exist in advance. If you need help on how to create a `FeedbackDataset`
in Argilla, please visit in Argilla, please visit {self.BLOG_URL}.
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
workspace_name: name of the workspace in Argilla where the specified workspace_name: name of the workspace in Argilla where the specified
`FeedbackDataset` lives in. Defaults to `None`, which means that the `FeedbackDataset` lives in. Defaults to `None`, which means that the
default workspace will be used. default workspace will be used.
api_url: URL of the Argilla Server that we want to use, and where the api_url: URL of the Argilla Server that we want to use, and where the
`FeedbackDataset` lives in. Defaults to `None`, which means that either `FeedbackDataset` lives in. Defaults to `None`, which means that either
`ARGILLA_API_URL` environment variable or the default `ARGILLA_API_URL` environment variable or `{self.DEFAULT_API_URL}` will
http://localhost:6900 will be used. be used.
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
means that either `ARGILLA_API_KEY` environment variable or the default means that either `ARGILLA_API_KEY` environment variable or the default
`argilla.apikey` will be used. `{self.DEFAULT_API_KEY}` will be used.
Raises: Raises:
ImportError: if the `argilla` package is not installed. ImportError: if the `argilla` package is not installed.
@ -87,41 +95,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Import Argilla (not via `import_argilla` to keep hints in IDEs) # Import Argilla (not via `import_argilla` to keep hints in IDEs)
try: try:
import argilla as rg # noqa: F401 import argilla as rg # noqa: F401
self.ARGILLA_VERSION = rg.__version__
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"To use the Argilla callback manager you need to have the `argilla` " "To use the Argilla callback manager you need to have the `argilla` "
"Python package installed. Please install it with `pip install argilla`" "Python package installed. Please install it with `pip install argilla`"
) )
# Check whether the Argilla version is compatible
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
raise ImportError(
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
"upgrade `argilla` with `pip install --upgrade argilla`."
)
# Show a warning message if Argilla will assume the default values will be used # Show a warning message if Argilla will assume the default values will be used
if api_url is None and os.getenv("ARGILLA_API_URL") is None: if api_url is None and os.getenv("ARGILLA_API_URL") is None:
warnings.warn( warnings.warn(
( (
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not" "Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
" set, it will default to `http://localhost:6900`." f" set, it will default to `{self.DEFAULT_API_URL}`."
), ),
) )
if api_key is None and os.getenv("ARGILLA_API_KEY") is None: if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
warnings.warn( warnings.warn(
( (
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not" "Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
" set, it will default to `argilla.apikey`." f" set, it will default to `{self.DEFAULT_API_KEY}`."
), ),
) )
# Connect to Argilla with the provided credentials, if applicable # Connect to Argilla with the provided credentials, if applicable
try: try:
rg.init( rg.init(api_key=api_key, api_url=api_url)
api_key=api_key,
api_url=api_url,
)
except Exception as e: except Exception as e:
raise ConnectionError( raise ConnectionError(
f"Could not connect to Argilla with exception: '{e}'.\n" f"Could not connect to Argilla with exception: '{e}'.\n"
"Please check your `api_key` and `api_url`, and make sure that " "Please check your `api_key` and `api_url`, and make sure that "
"the Argilla server is up and running. If the problem persists " "the Argilla server is up and running. If the problem persists "
"please report it to https://github.com/argilla-io/argilla/issues " f"please report it to {self.ISSUES_URL} as an `integration` issue."
"with the label `langchain`."
) from e ) from e
# Set the Argilla variables # Set the Argilla variables
@ -130,46 +144,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Retrieve the `FeedbackDataset` from Argilla (without existing records) # Retrieve the `FeedbackDataset` from Argilla (without existing records)
try: try:
extra_args = {}
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
warnings.warn(
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
" higher is recommended.",
UserWarning,
)
extra_args = {"with_records": False}
self.dataset = rg.FeedbackDataset.from_argilla( self.dataset = rg.FeedbackDataset.from_argilla(
name=self.dataset_name, name=self.dataset_name,
workspace=self.workspace_name, workspace=self.workspace_name,
with_records=False, **extra_args,
) )
except Exception as e: except Exception as e:
raise FileNotFoundError( raise FileNotFoundError(
"`FeedbackDataset` retrieval from Argilla failed with exception:" f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
f" '{e}'.\nPlease check that the dataset with" f"\nPlease check that the dataset with name={self.dataset_name} in the"
f" name={self.dataset_name} in the"
f" workspace={self.workspace_name} exists in advance. If you need help" f" workspace={self.workspace_name} exists in advance. If you need help"
" on how to create a `langchain`-compatible `FeedbackDataset` in" " on how to create a `langchain`-compatible `FeedbackDataset` in"
" Argilla, please visit" f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501 f" please report it to {self.ISSUES_URL} as an `integration` issue."
" If the problem persists please report it to"
" https://github.com/argilla-io/argilla/issues with the label"
" `langchain`."
) from e ) from e
supported_fields = ["prompt", "response"] supported_fields = ["prompt", "response"]
if supported_fields != [field.name for field in self.dataset.fields]: if supported_fields != [field.name for field in self.dataset.fields]:
raise ValueError( raise ValueError(
f"`FeedbackDataset` with name={self.dataset_name} in the" f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
f" workspace={self.workspace_name} " f"{self.workspace_name} had fields that are not supported yet for the"
"had fields that are not supported yet for the `langchain` integration." f"`langchain` integration. Supported fields are: {supported_fields},"
" Supported fields are: " f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
f"{supported_fields}, and the current `FeedbackDataset` fields are" " For more information on how to create a `langchain`-compatible"
f" {[field.name for field in self.dataset.fields]}. " f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
"For more information on how to create a `langchain`-compatible"
" `FeedbackDataset` in Argilla, please visit"
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501
) )
self.prompts: Dict[str, List[str]] = {} self.prompts: Dict[str, List[str]] = {}
warnings.warn( warnings.warn(
( (
"The `ArgillaCallbackHandler` is currently in beta and is subject to " "The `ArgillaCallbackHandler` is currently in beta and is subject to"
"change based on updates to `langchain`. Please report any issues to " " change based on updates to `langchain`. Please report any issues to"
"https://github.com/argilla-io/argilla/issues with the tag `langchain`." f" {self.ISSUES_URL} as an `integration` issue."
), ),
) )
@ -205,12 +220,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
] ]
) )
# Push the records to Argilla
self.dataset.push_to_argilla()
# Pop current run from `self.runs` # Pop current run from `self.runs`
self.prompts.pop(str(kwargs["run_id"])) self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_llm_error( def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None:
@ -278,15 +294,16 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
] ]
) )
# Push the records to Argilla
self.dataset.push_to_argilla()
# Pop current run from `self.runs` # Pop current run from `self.runs`
if str(kwargs["parent_run_id"]) in self.prompts: if str(kwargs["parent_run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["parent_run_id"])) self.prompts.pop(str(kwargs["parent_run_id"]))
if str(kwargs["run_id"]) in self.prompts: if str(kwargs["run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["run_id"])) self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_chain_error( def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None: