Dask

This example showcases how to run Dask within an Argo workflow submitted via Hera.

from hera.workflows import (
    Steps,
    Workflow,
    script,
)


@script(image="ghcr.io/dask/dask:latest")
def dask_computation(namespace: str = "default", n_workers: int = 1) -> None:
    import subprocess

    # this is required because the `dask[distributed]` and `dask-kubernetes` packages are not included
    # in the base image by default. Ideally, you'd have a package in your organization that takes care
    # of the following cluster details :)
    subprocess.run(
        ["pip", "install", "dask-kubernetes", "dask[distributed]"], stdout=subprocess.PIPE, universal_newlines=True
    )

    import dask.array as da
    from dask.distributed import Client
    from dask_kubernetes.classic import KubeCluster, make_pod_spec

    # spin up an ephemeral Dask cluster as Kubernetes pods in the given namespace
    cluster = KubeCluster(
        pod_template=make_pod_spec(
            image="ghcr.io/dask/dask:latest",
            memory_limit="4G",
            memory_request="2G",
            cpu_limit=1,
            cpu_request=1,
        ),
        namespace=namespace,
        n_workers=n_workers,
    )

    # once the `Client` is initialized, all subsequent Dask calls are implicitly executed against the cluster
    client = Client(cluster)
    array = da.ones((1000, 1000, 1000))
    print("Array mean = {array_mean}, expected = 1.0".format(array_mean=array.mean().compute()))
    client.close()


with Workflow(generate_name="dask-", entrypoint="s") as w:
    with Steps(name="s"):
        dask_computation()
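
To actually run this example, the workflow has to be submitted to an Argo server. Below is a minimal sketch using Hera's WorkflowsService client; the host https://localhost:2746 is an assumption, so substitute the address of your own Argo server:

from hera.workflows import WorkflowsService

# assumption: replace this host with the address of your Argo server deployment
w.workflows_service = WorkflowsService(host="https://localhost:2746")
w.create()  # submits the `Workflow` object defined above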
The workflow above renders to the following Argo Workflow manifest:

apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: dask-
spec:
  entrypoint: s
  templates:
  - name: s
    steps:
    - - name: dask-computation
        template: dask-computation
  - inputs:
      parameters:
      - default: default
        name: namespace
      - default: '1'
        name: n_workers
    name: dask-computation
    script:
      command:
      - python
      image: ghcr.io/dask/dask:latest
      source: |-
        import os
        import sys
        sys.path.append(os.getcwd())
        import json
        try: n_workers = json.loads(r'''{{inputs.parameters.n_workers}}''')
        except: n_workers = r'''{{inputs.parameters.n_workers}}'''
        try: namespace = json.loads(r'''{{inputs.parameters.namespace}}''')
        except: namespace = r'''{{inputs.parameters.namespace}}'''

        import subprocess
        subprocess.run(['pip', 'install', 'dask-kubernetes', 'dask[distributed]'], stdout=subprocess.PIPE, universal_newlines=True)
        import dask.array as da
        from dask.distributed import Client
        from dask_kubernetes.classic import KubeCluster, make_pod_spec
        cluster = KubeCluster(pod_template=make_pod_spec(image='ghcr.io/dask/dask:latest', memory_limit='4G', memory_request='2G', cpu_limit=1, cpu_request=1), namespace=namespace, n_workers=n_workers)
        client = Client(cluster)
        array = da.ones((1000, 1000, 1000))
        print('Array mean = {array_mean}, expected = 1.0'.format(array_mean=array.mean().compute()))
        client.close()
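
The manifest above can also be produced programmatically; a minimal sketch using Workflow.to_yaml, which assumes PyYAML is available in your environment:

# render the `Workflow` object defined above into the Argo YAML manifest shown here
print(w.to_yaml())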
