Run a JAXJob

Run a Kueue scheduled JAXJob

This page shows how to leverage Kueue’s scheduling and resource management capabilities when running Trainer JAXJobs.

This guide is for batch users who have a basic understanding of Kueue. For more information, see Kueue’s overview.

Before you begin

See Administer cluster quotas for details on the initial cluster setup.

See the Trainer installation guide.

The minimum required Trainer version is v1.9.0.

You can modify the Kueue configuration of an installed release to include JAXJobs as an allowed workload.
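
As a rough sketch of what that change might look like, the fragment below enables the JAXJob integration in the Kueue controller configuration. It assumes the `config.kueue.x-k8s.io/v1beta1` Configuration API and the `kubeflow.org/jaxjob` framework name. Check both against the Kueue release you have installed, and keep any other frameworks your cluster already lists.

```yaml
# Fragment of the Kueue controller configuration (assumed API version).
apiVersion: config.kueue.x-k8s.io/v1beta1
kind: Configuration
integrations:
  frameworks:
    - "batch/job"            # example of an existing entry; keep your own list
    - "kubeflow.org/jaxjob"  # assumed framework name for JAXJobs
```

The Kueue controller typically needs a restart before it picks up a configuration change.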

JAXJob definition

a. Queue selection

The target local queue should be specified in the metadata.labels section of the JAXJob configuration.

metadata:
  labels:
    kueue.x-k8s.io/queue-name: user-queue
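
The label only takes effect if a LocalQueue with that name exists in the JAXJob's namespace. A minimal sketch of such a queue follows. The ClusterQueue name `cluster-queue` is an assumption for illustration, and the API version may differ in your Kueue release; use whatever your administrator configured during cluster setup.

```yaml
# Hypothetical LocalQueue matching the queue-name label above.
apiVersion: kueue.x-k8s.io/v1beta1
kind: LocalQueue
metadata:
  namespace: default           # must be the namespace the JAXJob is created in
  name: user-queue             # referenced by kueue.x-k8s.io/queue-name
spec:
  clusterQueue: cluster-queue  # assumed ClusterQueue name
```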

b. Optionally set the suspend field in the JAXJob

spec:
  runPolicy:
    suspend: true

By default, Kueue will set suspend to true via webhook and unsuspend it when the JAXJob is admitted.

Sample JAXJob

This example is based on https://github.com/kubeflow/trainer/blob/da11d1116c29322c481d0b8f174df8d6f05004aa/examples/jax/cpu-demo/demo.yaml.

apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  name: jax-simple
  namespace: default
  labels:
    kueue.x-k8s.io/queue-name: user-queue
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: jax
              image: docker.io/kubeflow/jaxjob-simple:latest
              command:
                - "python3"
                - "train.py"
              imagePullPolicy: Always
              resources:
                requests:
                  cpu: 1
                  memory: "200Mi"