From d1d691caa475a2e21688250d047a4bbcb31a40c1 Mon Sep 17 00:00:00 2001 From: Rithwik Ediga Lakhamsani <81988348+rithwik-db@users.noreply.github.com> Date: Tue, 25 Jul 2023 18:23:54 -0700 Subject: [PATCH] Added Databricks support to MLflow Callback (#7906) Added a quick check to make integration easier with Databricks; another option would be to make a new class, but this seemed more straightfoward. cc: @liangz1 Can this be done in a more straightfoward way? --- .../langchain/callbacks/mlflow_callback.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/libs/langchain/langchain/callbacks/mlflow_callback.py b/libs/langchain/langchain/callbacks/mlflow_callback.py index 553404b07f5..baa297af544 100644 --- a/libs/langchain/langchain/callbacks/mlflow_callback.py +++ b/libs/langchain/langchain/callbacks/mlflow_callback.py @@ -1,3 +1,4 @@ +import os import random import string import tempfile @@ -127,22 +128,27 @@ class MlflowLogger: def __init__(self, **kwargs: Any): self.mlflow = import_mlflow() - tracking_uri = get_from_dict_or_env( - kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", "" - ) - self.mlflow.set_tracking_uri(tracking_uri) - - # User can set other env variables described here - # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server - - experiment_name = get_from_dict_or_env( - kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME" - ) - self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name) - if self.mlf_exp is not None: - self.mlf_expid = self.mlf_exp.experiment_id + if "DATABRICKS_RUNTIME_VERSION" in os.environ: + self.mlflow.set_tracking_uri("databricks") + self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id() + self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid) else: - self.mlf_expid = self.mlflow.create_experiment(experiment_name) + tracking_uri = get_from_dict_or_env( + kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", "" + ) + self.mlflow.set_tracking_uri(tracking_uri) + + # User can set other env variables described here + # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server + + experiment_name = get_from_dict_or_env( + kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME" + ) + self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name) + if self.mlf_exp is not None: + self.mlf_expid = self.mlf_exp.experiment_id + else: + self.mlf_expid = self.mlflow.create_experiment(experiment_name) self.start_run(kwargs["run_name"], kwargs["run_tags"])