Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/rayjob_e2e_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ jobs:
kubectl create clusterrolebinding sdk-user-secret-manager --clusterrole=secret-manager --user=sdk-user
kubectl create clusterrole workload-reader --verb=get,list,watch --resource=workloads
kubectl create clusterrolebinding sdk-user-workload-reader --clusterrole=workload-reader --user=sdk-user
kubectl create clusterrole workloadpriorityclass-reader --verb=get,list --resource=workloadpriorityclasses
kubectl create clusterrolebinding sdk-user-workloadpriorityclass-reader --clusterrole=workloadpriorityclass-reader --user=sdk-user
kubectl config use-context sdk-user

- name: Run RayJob E2E tests
Expand Down
47 changes: 47 additions & 0 deletions src/codeflare_sdk/common/kueue/kueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# limitations under the License.

from typing import Optional, List
import logging
from codeflare_sdk.common import _kube_api_error_handling
from codeflare_sdk.common.kubernetes_cluster.auth import config_check, get_api_client
from kubernetes import client
from kubernetes.client.exceptions import ApiException

from ...common.utils import get_current_namespace

logger = logging.getLogger(__name__)


def get_default_kueue_name(namespace: str) -> Optional[str]:
"""
Expand Down Expand Up @@ -144,6 +147,50 @@ def local_queue_exists(namespace: str, local_queue_name: str) -> bool:
return False


def priority_class_exists(priority_class_name: str) -> Optional[bool]:
    """
    Report whether a Kueue WorkloadPriorityClass with the given name can be found.

    The lookup is a cluster-level custom-object GET; no namespace is passed
    because WorkloadPriorityClass is a cluster-scoped resource.

    Args:
        priority_class_name (str):
            Name of the WorkloadPriorityClass to look up.

    Returns:
        Optional[bool]:
            True when the GET succeeds, False when the API answers with 404,
            and None when the result cannot be determined (any other API
            error or any unexpected exception; both are logged as warnings).
    """
    try:
        config_check()
        custom_objects = client.CustomObjectsApi(get_api_client())
        # A successful GET by name is the existence proof; the payload is unused.
        custom_objects.get_cluster_custom_object(
            group="kueue.x-k8s.io",
            version="v1beta1",
            plural="workloadpriorityclasses",
            name=priority_class_name,
        )
    except client.ApiException as api_err:
        if api_err.status != 404:
            # Anything other than "not found" (e.g. 403, 500) is inconclusive.
            logger.warning(
                f"Error checking WorkloadPriorityClass '{priority_class_name}': {api_err.reason}. "
                f"Cannot verify if it exists."
            )
            return None
        return False
    except Exception as exc:
        # Best effort: never let a verification failure propagate to the caller.
        logger.warning(
            f"Unexpected error checking WorkloadPriorityClass '{priority_class_name}': {str(exc)}. "
            f"Cannot verify if it exists."
        )
        return None
    else:
        return True


def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
"""
Adds a local queue name label to the provided item.
Expand Down
53 changes: 52 additions & 1 deletion src/codeflare_sdk/common/kueue/test_kueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
import os
import filecmp
from pathlib import Path
from .kueue import list_local_queues, local_queue_exists, add_queue_label
from .kueue import (
list_local_queues,
local_queue_exists,
add_queue_label,
priority_class_exists,
)

parent = Path(__file__).resolve().parents[4] # project directory
aw_dir = os.path.expanduser("~/.codeflare/resources/")
Expand Down Expand Up @@ -292,6 +297,52 @@ def test_add_queue_label_with_invalid_local_queue(mocker):
add_queue_label(item, namespace, local_queue)


def test_priority_class_exists_found(mocker):
    """A successful cluster-scoped GET must be reported as True."""
    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    fake_api = custom_objects_cls.return_value
    fake_api.get_cluster_custom_object.return_value = {
        "metadata": {"name": "high-priority"}
    }

    outcome = priority_class_exists("high-priority")

    assert outcome is True


def test_priority_class_exists_not_found(mocker):
    """An ApiException with status 404 must be reported as False."""
    from kubernetes.client import ApiException

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    fake_api = custom_objects_cls.return_value
    fake_api.get_cluster_custom_object.side_effect = ApiException(status=404)

    outcome = priority_class_exists("missing-priority")

    assert outcome is False


def test_priority_class_exists_permission_denied(mocker):
    """An ApiException with status 403 is inconclusive and must yield None."""
    from kubernetes.client import ApiException

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    fake_api = custom_objects_cls.return_value
    fake_api.get_cluster_custom_object.side_effect = ApiException(status=403)

    outcome = priority_class_exists("some-priority")

    assert outcome is None


def test_priority_class_exists_other_error(mocker):
    """An ApiException with status 500 is inconclusive and must yield None."""
    from kubernetes.client import ApiException

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    fake_api = custom_objects_cls.return_value
    fake_api.get_cluster_custom_object.side_effect = ApiException(status=500)

    outcome = priority_class_exists("some-priority")

    assert outcome is None


# Make sure to always keep this function last
def test_cleanup():
os.remove(f"{aw_dir}unit-test-cluster-kueue.yaml")
Expand Down
61 changes: 54 additions & 7 deletions src/codeflare_sdk/ray/rayjobs/rayjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from typing import Dict, Any, Optional, Tuple, Union

from ray.runtime_env import RuntimeEnv
from codeflare_sdk.common.kueue.kueue import get_default_kueue_name
from codeflare_sdk.common.kueue.kueue import (
get_default_kueue_name,
priority_class_exists,
)
from codeflare_sdk.common.utils.constants import MOUNT_PATH

from codeflare_sdk.common.utils.utils import get_ray_image_for_python_version
Expand Down Expand Up @@ -69,6 +72,7 @@ def __init__(
ttl_seconds_after_finished: int = 0,
active_deadline_seconds: Optional[int] = None,
local_queue: Optional[str] = None,
priority_class: Optional[str] = None,
):
"""
Initialize a RayJob instance.
Expand All @@ -86,11 +90,13 @@ def __init__(
ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes (default: 0)
active_deadline_seconds: Maximum time the job can run before being terminated (optional)
local_queue: The Kueue LocalQueue to submit the job to (optional)
priority_class: The Kueue WorkloadPriorityClass name for preemption control (optional).

Note:
- True if cluster_config is provided (new cluster will be cleaned up)
- False if cluster_name is provided (existing cluster will not be shut down)
- User can explicitly set this value to override auto-detection
- Kueue labels (queue and priority) can be applied to both new and existing clusters
"""
if cluster_name is None and cluster_config is None:
raise ValueError(
Expand Down Expand Up @@ -124,6 +130,7 @@ def __init__(
self.ttl_seconds_after_finished = ttl_seconds_after_finished
self.active_deadline_seconds = active_deadline_seconds
self.local_queue = local_queue
self.priority_class = priority_class

if namespace is None:
detected_namespace = get_current_namespace()
Expand Down Expand Up @@ -165,6 +172,7 @@ def submit(self) -> str:
# Validate configuration before submitting
self._validate_ray_version_compatibility()
self._validate_working_dir_entrypoint()
self._validate_priority_class()

# Extract files from entrypoint and runtime_env working_dir
files = extract_all_local_files(self)
Expand Down Expand Up @@ -243,12 +251,14 @@ def _build_rayjob_cr(self) -> Dict[str, Any]:
# Extract files once and use for both runtime_env and submitter pod
files = extract_all_local_files(self)

# Build Kueue labels - only for new clusters (lifecycled)
labels = {}
# If cluster_config is provided, use the local_queue from the cluster_config

if self._cluster_config is not None:
if self.local_queue:
labels["kueue.x-k8s.io/queue-name"] = self.local_queue
else:
# Auto-detect default queue for new clusters
default_queue = get_default_kueue_name(self.namespace)
if default_queue:
labels["kueue.x-k8s.io/queue-name"] = default_queue
Expand All @@ -262,12 +272,23 @@ def _build_rayjob_cr(self) -> Dict[str, Any]:
f"To fix this, please explicitly specify the 'local_queue' parameter."
)

rayjob_cr["metadata"]["labels"] = labels
if self.priority_class:
labels["kueue.x-k8s.io/priority-class"] = self.priority_class

# When using Kueue (queue label present), start with suspend=true
# Kueue will unsuspend the job once the workload is admitted
if labels.get("kueue.x-k8s.io/queue-name"):
rayjob_cr["spec"]["suspend"] = True
# Apply labels to metadata
if labels:
rayjob_cr["metadata"]["labels"] = labels

# When using Kueue with lifecycled clusters, start with suspend=true
# Kueue will unsuspend the job once the workload is admitted
if labels.get("kueue.x-k8s.io/queue-name"):
rayjob_cr["spec"]["suspend"] = True
else:
if self.local_queue or self.priority_class:
logger.warning(
f"Kueue labels (local_queue, priority_class) are ignored for RayJobs "
f"targeting existing clusters. Kueue only manages RayJobs that create new clusters."
)

# Add active deadline if specified
if self.active_deadline_seconds:
Expand Down Expand Up @@ -450,6 +471,32 @@ def _validate_cluster_config_image(self):
elif is_warning:
warnings.warn(f"Cluster config image: {message}")

def _validate_priority_class(self):
    """
    Best-effort check that the configured priority class exists.

    Does nothing when no priority class is set. Otherwise the lookup result
    from ``priority_class_exists`` decides the outcome:

    - False: the class is known to be missing, so validation fails.
    - None: existence could not be verified; a warning is logged and
      validation passes.
    - anything else: the class was found; a debug message is logged.

    Raises:
        ValueError: If the lookup reports that the priority class does not exist.
    """
    if not self.priority_class:
        return

    logger.debug(f"Validating priority class '{self.priority_class}'...")
    lookup = priority_class_exists(self.priority_class)

    # Only a definitive "missing" answer blocks submission.
    if lookup is False:
        raise ValueError(f"Priority class '{self.priority_class}' does not exist")

    # Inconclusive lookup: warn, but do not block submission.
    if lookup is None:
        logger.warning(
            f"Could not verify if priority class '{self.priority_class}' exists. "
            f"Proceeding with submission - Kueue will validate on admission."
        )
        return

    logger.debug(f"Priority class '{self.priority_class}' verified.")

def _validate_working_dir_entrypoint(self):
"""
Validate entrypoint file configuration.
Expand Down
Loading
Loading