Skip to content

Commit 634a03f

Browse files
committed
RHOAIENG-39073: Add priority class support
1 parent 8eac545 commit 634a03f

File tree

6 files changed

+393
-16
lines changed

6 files changed

+393
-16
lines changed

.github/workflows/rayjob_e2e_tests.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ jobs:
119119
kubectl create clusterrolebinding sdk-user-secret-manager --clusterrole=secret-manager --user=sdk-user
120120
kubectl create clusterrole workload-reader --verb=get,list,watch --resource=workloads
121121
kubectl create clusterrolebinding sdk-user-workload-reader --clusterrole=workload-reader --user=sdk-user
122+
kubectl create clusterrole workloadpriorityclass-reader --verb=get,list --resource=workloadpriorityclasses
123+
kubectl create clusterrolebinding sdk-user-workloadpriorityclass-reader --clusterrole=workloadpriorityclass-reader --user=sdk-user
122124
kubectl config use-context sdk-user
123125
124126
- name: Run RayJob E2E tests

src/codeflare_sdk/common/kueue/kueue.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
# limitations under the License.
1414

1515
from typing import Optional, List
16+
import logging
1617
from codeflare_sdk.common import _kube_api_error_handling
1718
from codeflare_sdk.common.kubernetes_cluster.auth import config_check, get_api_client
1819
from kubernetes import client
1920
from kubernetes.client.exceptions import ApiException
2021

2122
from ...common.utils import get_current_namespace
2223

24+
logger = logging.getLogger(__name__)
25+
2326

2427
def get_default_kueue_name(namespace: str) -> Optional[str]:
2528
"""
@@ -144,6 +147,50 @@ def local_queue_exists(namespace: str, local_queue_name: str) -> bool:
144147
return False
145148

146149

150+
def priority_class_exists(priority_class_name: str) -> Optional[bool]:
    """
    Report whether a Kueue WorkloadPriorityClass with the given name is present.

    The lookup uses the cluster-level custom object API (no namespace is
    passed), since WorkloadPriorityClass is a cluster-scoped resource.

    Args:
        priority_class_name (str):
            Name of the WorkloadPriorityClass to look up.

    Returns:
        Optional[bool]:
            True when the lookup succeeds, False when the API answers 404,
            and None when existence cannot be determined (any other API
            error, or any unexpected exception). The None case is logged as
            a warning rather than raised.
    """
    # Coordinates of the custom resource, kept together for readability.
    resource_ref = {
        "group": "kueue.x-k8s.io",
        "version": "v1beta1",
        "plural": "workloadpriorityclasses",
        "name": priority_class_name,
    }

    try:
        config_check()
        custom_objects = client.CustomObjectsApi(get_api_client())
        # Only success/failure of the GET matters; the returned object is unused.
        custom_objects.get_cluster_custom_object(**resource_ref)
    except client.ApiException as api_err:
        if api_err.status != 404:
            # e.g. 403 or 5xx: we cannot tell whether the class exists.
            logger.warning(
                f"Error checking WorkloadPriorityClass '{priority_class_name}': {api_err.reason}. "
                f"Cannot verify if it exists."
            )
            return None
        # 404 is the only definitive "does not exist" answer.
        return False
    except Exception as err:
        logger.warning(
            f"Unexpected error checking WorkloadPriorityClass '{priority_class_name}': {str(err)}. "
            f"Cannot verify if it exists."
        )
        return None

    return True
192+
193+
147194
def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
148195
"""
149196
Adds a local queue name label to the provided item.

src/codeflare_sdk/common/kueue/test_kueue.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
import os
2424
import filecmp
2525
from pathlib import Path
26-
from .kueue import list_local_queues, local_queue_exists, add_queue_label
26+
from .kueue import (
27+
list_local_queues,
28+
local_queue_exists,
29+
add_queue_label,
30+
priority_class_exists,
31+
)
2732

2833
parent = Path(__file__).resolve().parents[4] # project directory
2934
aw_dir = os.path.expanduser("~/.codeflare/resources/")
@@ -292,6 +297,52 @@ def test_add_queue_label_with_invalid_local_queue(mocker):
292297
add_queue_label(item, namespace, local_queue)
293298

294299

300+
def test_priority_class_exists_found(mocker):
    """priority_class_exists returns True when the cluster lookup succeeds."""
    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    lookup = custom_objects_cls.return_value.get_cluster_custom_object
    lookup.return_value = {"metadata": {"name": "high-priority"}}

    outcome = priority_class_exists("high-priority")

    assert outcome is True
308+
309+
310+
def test_priority_class_exists_not_found(mocker):
    """priority_class_exists returns False when the lookup raises a 404 ApiException."""
    from kubernetes.client import ApiException

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    lookup = custom_objects_cls.return_value.get_cluster_custom_object
    lookup.side_effect = ApiException(status=404)

    outcome = priority_class_exists("missing-priority")

    assert outcome is False
320+
321+
322+
def test_priority_class_exists_permission_denied(mocker):
    """priority_class_exists returns None when the lookup raises a 403 ApiException."""
    from kubernetes.client import ApiException

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    lookup = custom_objects_cls.return_value.get_cluster_custom_object
    lookup.side_effect = ApiException(status=403)

    outcome = priority_class_exists("some-priority")

    assert outcome is None
332+
333+
334+
def test_priority_class_exists_other_error(mocker):
    """priority_class_exists returns None when the lookup raises a 500 ApiException."""
    from kubernetes.client import ApiException

    mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
    custom_objects_cls = mocker.patch("kubernetes.client.CustomObjectsApi")
    lookup = custom_objects_cls.return_value.get_cluster_custom_object
    lookup.side_effect = ApiException(status=500)

    outcome = priority_class_exists("some-priority")

    assert outcome is None
344+
345+
295346
# Make sure to always keep this function last
296347
def test_cleanup():
297348
os.remove(f"{aw_dir}unit-test-cluster-kueue.yaml")

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
from typing import Dict, Any, Optional, Tuple, Union
2424

2525
from ray.runtime_env import RuntimeEnv
26-
from codeflare_sdk.common.kueue.kueue import get_default_kueue_name
26+
from codeflare_sdk.common.kueue.kueue import (
27+
get_default_kueue_name,
28+
priority_class_exists,
29+
)
2730
from codeflare_sdk.common.utils.constants import MOUNT_PATH
2831

2932
from codeflare_sdk.common.utils.utils import get_ray_image_for_python_version
@@ -69,6 +72,7 @@ def __init__(
6972
ttl_seconds_after_finished: int = 0,
7073
active_deadline_seconds: Optional[int] = None,
7174
local_queue: Optional[str] = None,
75+
priority_class: Optional[str] = None,
7276
):
7377
"""
7478
Initialize a RayJob instance.
@@ -86,11 +90,13 @@ def __init__(
8690
ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes (default: 0)
8791
active_deadline_seconds: Maximum time the job can run before being terminated (optional)
8892
local_queue: The Kueue LocalQueue to submit the job to (optional)
93+
priority_class: The Kueue WorkloadPriorityClass name for preemption control (optional).
8994
9095
Note:
9196
- True if cluster_config is provided (new cluster will be cleaned up)
9297
- False if cluster_name is provided (existing cluster will not be shut down)
9398
- User can explicitly set this value to override auto-detection
99+
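- Kueue labels (local_queue and priority_class) are applied only when cluster_config is provided (a new cluster is created); for existing clusters they are ignored and a warning is logged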
94100
"""
95101
if cluster_name is None and cluster_config is None:
96102
raise ValueError(
@@ -124,6 +130,7 @@ def __init__(
124130
self.ttl_seconds_after_finished = ttl_seconds_after_finished
125131
self.active_deadline_seconds = active_deadline_seconds
126132
self.local_queue = local_queue
133+
self.priority_class = priority_class
127134

128135
if namespace is None:
129136
detected_namespace = get_current_namespace()
@@ -165,6 +172,7 @@ def submit(self) -> str:
165172
# Validate configuration before submitting
166173
self._validate_ray_version_compatibility()
167174
self._validate_working_dir_entrypoint()
175+
self._validate_priority_class()
168176

169177
# Extract files from entrypoint and runtime_env working_dir
170178
files = extract_all_local_files(self)
@@ -243,12 +251,14 @@ def _build_rayjob_cr(self) -> Dict[str, Any]:
243251
# Extract files once and use for both runtime_env and submitter pod
244252
files = extract_all_local_files(self)
245253

254+
# Build Kueue labels - only for new clusters (lifecycled)
246255
labels = {}
247-
# If cluster_config is provided, use the local_queue from the cluster_config
256+
248257
if self._cluster_config is not None:
249258
if self.local_queue:
250259
labels["kueue.x-k8s.io/queue-name"] = self.local_queue
251260
else:
261+
# Auto-detect default queue for new clusters
252262
default_queue = get_default_kueue_name(self.namespace)
253263
if default_queue:
254264
labels["kueue.x-k8s.io/queue-name"] = default_queue
@@ -262,12 +272,23 @@ def _build_rayjob_cr(self) -> Dict[str, Any]:
262272
f"To fix this, please explicitly specify the 'local_queue' parameter."
263273
)
264274

265-
rayjob_cr["metadata"]["labels"] = labels
275+
if self.priority_class:
276+
labels["kueue.x-k8s.io/priority-class"] = self.priority_class
266277

267-
# When using Kueue (queue label present), start with suspend=true
268-
# Kueue will unsuspend the job once the workload is admitted
269-
if labels.get("kueue.x-k8s.io/queue-name"):
270-
rayjob_cr["spec"]["suspend"] = True
278+
# Apply labels to metadata
279+
if labels:
280+
rayjob_cr["metadata"]["labels"] = labels
281+
282+
# When using Kueue with lifecycled clusters, start with suspend=true
283+
# Kueue will unsuspend the job once the workload is admitted
284+
if labels.get("kueue.x-k8s.io/queue-name"):
285+
rayjob_cr["spec"]["suspend"] = True
286+
else:
287+
if self.local_queue or self.priority_class:
288+
logger.warning(
289+
f"Kueue labels (local_queue, priority_class) are ignored for RayJobs "
290+
f"targeting existing clusters. Kueue only manages RayJobs that create new clusters."
291+
)
271292

272293
# Add active deadline if specified
273294
if self.active_deadline_seconds:
@@ -450,6 +471,32 @@ def _validate_cluster_config_image(self):
450471
elif is_warning:
451472
warnings.warn(f"Cluster config image: {message}")
452473

474+
def _validate_priority_class(self):
    """
    Best-effort check that the configured priority class exists.

    Does nothing when no priority class is set. Otherwise the result of
    ``priority_class_exists`` decides the outcome:

    - False: the class is known to be missing, so a ValueError is raised.
    - None: existence could not be verified; a warning is logged and
      validation passes.
    - anything else: the class is treated as verified.

    Raises:
        ValueError: If the priority class is definitively known not to exist.
    """
    if not self.priority_class:
        return

    logger.debug(f"Validating priority class '{self.priority_class}'...")
    lookup_result = priority_class_exists(self.priority_class)

    if lookup_result is None:
        # Inconclusive lookup: do not block submission on it.
        logger.warning(
            f"Could not verify if priority class '{self.priority_class}' exists. "
            f"Proceeding with submission - Kueue will validate on admission."
        )
        return

    if lookup_result is False:
        # Definitive negative answer: fail validation.
        raise ValueError(
            f"Priority class '{self.priority_class}' does not exist"
        )

    logger.debug(f"Priority class '{self.priority_class}' verified.")
499+
453500
def _validate_working_dir_entrypoint(self):
454501
"""
455502
Validate entrypoint file configuration.

0 commit comments

Comments
 (0)