mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-28 14:05:02 +00:00
Added Databricks support to MLflow Callback (#7906)
Added a quick check to make integration easier with Databricks; another option would be to make a new class, but this seemed more straightfoward. cc: @liangz1 Can this be done in a more straightfoward way?
This commit is contained in:
parent
479cc086ba
commit
d1d691caa4
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -127,22 +128,27 @@ class MlflowLogger:
|
|||||||
|
|
||||||
def __init__(self, **kwargs: Any):
|
def __init__(self, **kwargs: Any):
|
||||||
self.mlflow = import_mlflow()
|
self.mlflow = import_mlflow()
|
||||||
tracking_uri = get_from_dict_or_env(
|
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
|
||||||
kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", ""
|
self.mlflow.set_tracking_uri("databricks")
|
||||||
)
|
self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
|
||||||
self.mlflow.set_tracking_uri(tracking_uri)
|
self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid)
|
||||||
|
|
||||||
# User can set other env variables described here
|
|
||||||
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
|
|
||||||
|
|
||||||
experiment_name = get_from_dict_or_env(
|
|
||||||
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
|
|
||||||
)
|
|
||||||
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
|
|
||||||
if self.mlf_exp is not None:
|
|
||||||
self.mlf_expid = self.mlf_exp.experiment_id
|
|
||||||
else:
|
else:
|
||||||
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
|
tracking_uri = get_from_dict_or_env(
|
||||||
|
kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", ""
|
||||||
|
)
|
||||||
|
self.mlflow.set_tracking_uri(tracking_uri)
|
||||||
|
|
||||||
|
# User can set other env variables described here
|
||||||
|
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
|
||||||
|
|
||||||
|
experiment_name = get_from_dict_or_env(
|
||||||
|
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
|
||||||
|
)
|
||||||
|
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
|
||||||
|
if self.mlf_exp is not None:
|
||||||
|
self.mlf_expid = self.mlf_exp.experiment_id
|
||||||
|
else:
|
||||||
|
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
|
||||||
|
|
||||||
self.start_run(kwargs["run_name"], kwargs["run_tags"])
|
self.start_run(kwargs["run_name"], kwargs["run_tags"])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user