import random
from hera.workflows import DAG, SecretEnv, Workflow, script
@script(
    # we use the `runner` constructor so that Hera automatically infers the path to our method, and assembles a
    # container for us, which will run remotely, on Argo / K8s.
    constructor="runner",
    # here we set the environment variable that contains the W&B API key
    env=SecretEnv(name="WANDB_API_KEY", secret_name="wandb-api-key", secret_key="wandb-api-key"),
)
def train_model(project_name: str, learning_rate: float, architecture: str, dataset: str, epochs: int) -> None:
    """Simulate a training run and log mock metrics to Weights & Biases.

    No model is actually trained: the accuracy and loss values are synthetic and only demonstrate how metrics
    are reported to W&B.

    Args:
        project_name: W&B project under which the run is logged.
        learning_rate: Hyperparameter recorded in the run config.
        architecture: Model architecture label recorded in the run config.
        dataset: Dataset label recorded in the run config.
        epochs: Exclusive upper bound of the simulated epoch loop, also recorded in the run config. The loop
            starts at epoch 2, so values of 2 or less log no metrics.
    """
    # TODO: move this import outside of the script. It is added here so that `wandb` does not need to be added as a
    # dependency to Hera
    import wandb

    # start a new wandb run to track this script
    wandb.init(
        # set the wandb project where this run will be logged
        project=project_name,
        # track hyperparameters and run metadata
        config={
            "learning_rate": learning_rate,
            "architecture": architecture,
            "dataset": dataset,
            "epochs": epochs,
        },
    )

    # simulate training. The `epochs` argument drives the loop directly, so that the number of simulated epochs
    # always matches the value recorded in the run config above.
    offset = random.random() / 5
    # the loop starts at 2, so `epoch` is never zero in the divisions below
    for epoch in range(2, epochs):
        # mock accuracy and loss, this is where you'd actually have a model training
        acc = 1 - 2**-epoch - random.random() / epoch - offset
        loss = 2**-epoch + random.random() / epoch + offset

        # log metrics to wandb
        wandb.log({"acc": acc, "loss": loss})

    # [optional] finish the wandb run, necessary in notebooks
    wandb.finish()
with Workflow(name="wandb-ml-monitoring", entrypoint="train") as w:
    with DAG(name="train"):
        # Every run below logs to the same W&B project but with a different learning rate. Each invocation of
        # `train_model` starts its own wandb run, so the runs can be compared side by side within the project to
        # see how the hyperparameter choice affects training performance.
        #
        # note that you can easily turn this into a hyperparameter search grid by looping over the results of
        # `itertools.combinations` or `itertools.product` instead!
        for task_name, learning_rate in (
            ("train-1", 1e-1),
            ("train-2", 1e-2),
            ("train-3", 1e-3),
        ):
            train_model(
                # `name` refers to the Argo Workflows task that gets automatically created upon invoking
                # `train_model` inside a DAG context. That's because only a Task can be a child of a DAG.
                # Alternatively, you can use Steps, in which case these invocations would create a WorkflowStep!
                name=task_name,
                arguments={
                    "project_name": "wandb-ml-monitoring",
                    "learning_rate": learning_rate,
                    "architecture": "CNN",
                    "dataset": "MNIST",
                    "epochs": 10,
                },
            )