Skip to content

Commit eac8cef

Browse files
feat(RHOAIENG-26482): add gcs fault tolerance
1 parent 2eee6e2 commit eac8cef

File tree

4 files changed

+83
-2
lines changed

4 files changed

+83
-2
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .rayjob import RayJob, RayJobClusterConfig
22
from .status import RayJobDeploymentStatus, CodeflareRayJobStatus, RayJobInfo
3+
from .config import RayJobClusterConfig

src/codeflare_sdk/ray/rayjobs/config.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
"""
16-
The config sub-module contains the definition of the RayJobClusterConfigV2 dataclass,
16+
The config sub-module contains the definition of the RayJobClusterConfig dataclass,
1717
which is used to specify resource requirements and other details when creating a
1818
Cluster object.
1919
"""
@@ -139,6 +139,14 @@ class RayJobClusterConfig:
139139
A list of V1Volume objects to add to the Cluster
140140
volume_mounts:
141141
A list of V1VolumeMount objects to add to the Cluster
142+
enable_gcs_ft:
143+
A boolean indicating whether to enable GCS fault tolerance.
144+
redis_address:
145+
The address of the Redis server to use for GCS fault tolerance, required when enable_gcs_ft is True.
146+
redis_password_secret:
147+
Kubernetes secret reference containing Redis password. ex: {"name": "secret-name", "key": "password-key"}
148+
external_storage_namespace:
149+
The storage namespace to use for GCS fault tolerance. By default, KubeRay sets it to the UID of RayCluster.
142150
"""
143151

144152
head_cpu_requests: Union[int, str] = 2
@@ -165,8 +173,33 @@ class RayJobClusterConfig:
165173
annotations: Dict[str, str] = field(default_factory=dict)
166174
volumes: list[V1Volume] = field(default_factory=list)
167175
volume_mounts: list[V1VolumeMount] = field(default_factory=list)
176+
enable_gcs_ft: bool = False
177+
redis_address: Optional[str] = None
178+
redis_password_secret: Optional[Dict[str, str]] = None
179+
external_storage_namespace: Optional[str] = None
168180

169181
def __post_init__(self):
182+
if self.enable_gcs_ft:
183+
if not self.redis_address:
184+
raise ValueError(
185+
"redis_address must be provided when enable_gcs_ft is True"
186+
)
187+
188+
if self.redis_password_secret and not isinstance(
189+
self.redis_password_secret, dict
190+
):
191+
raise ValueError(
192+
"redis_password_secret must be a dictionary with 'name' and 'key' fields"
193+
)
194+
195+
if self.redis_password_secret and (
196+
"name" not in self.redis_password_secret
197+
or "key" not in self.redis_password_secret
198+
):
199+
raise ValueError(
200+
"redis_password_secret must contain both 'name' and 'key' fields"
201+
)
202+
170203
self._validate_types()
171204
self._memory_to_string()
172205
self._validate_gpu_config(self.head_accelerators)
@@ -251,6 +284,11 @@ def build_ray_cluster_spec(self, cluster_name: str) -> Dict[str, Any]:
251284
"workerGroupSpecs": [self._build_worker_group_spec(cluster_name)],
252285
}
253286

287+
# Add GCS fault tolerance if enabled
288+
if self.enable_gcs_ft:
289+
gcs_ft_options = self._build_gcs_ft_options()
290+
ray_cluster_spec["gcsFaultToleranceOptions"] = gcs_ft_options
291+
254292
return ray_cluster_spec
255293

256294
def _build_head_group_spec(self) -> Dict[str, Any]:
@@ -453,3 +491,25 @@ def _generate_volumes(self) -> list:
453491
def _build_env_vars(self) -> list:
454492
"""Build environment variables list."""
455493
return [V1EnvVar(name=key, value=value) for key, value in self.envs.items()]
494+
495+
def _build_gcs_ft_options(self) -> Dict[str, Any]:
496+
"""Build GCS fault tolerance options."""
497+
gcs_ft_options = {"redisAddress": self.redis_address}
498+
499+
if (
500+
hasattr(self, "external_storage_namespace")
501+
and self.external_storage_namespace
502+
):
503+
gcs_ft_options["externalStorageNamespace"] = self.external_storage_namespace
504+
505+
if hasattr(self, "redis_password_secret") and self.redis_password_secret:
506+
gcs_ft_options["redisPassword"] = {
507+
"valueFrom": {
508+
"secretKeyRef": {
509+
"name": self.redis_password_secret["name"],
510+
"key": self.redis_password_secret["key"],
511+
}
512+
}
513+
}
514+
515+
return gcs_ft_options

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def __init__(
140140
self.cluster_name = cluster_name
141141
logger.info(f"Using existing cluster: {self.cluster_name}")
142142

143-
# Initialize the KubeRay job API client
144143
self._api = RayjobApi()
145144

146145
logger.info(f"Initialized RayJob: {self.name} in namespace: {self.namespace}")

src/codeflare_sdk/ray/rayjobs/test_rayjob.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,3 +971,24 @@ def test_rayjob_user_override_shutdown_behavior(mocker):
971971
)
972972

973973
assert rayjob_override_priority.shutdown_after_job_finishes is True
974+
975+
976+
def test_build_ray_cluster_spec_with_gcs_ft(mocker):
977+
"""Test build_ray_cluster_spec with GCS fault tolerance enabled."""
978+
from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig
979+
980+
# Create a test cluster config with GCS FT enabled
981+
cluster_config = RayJobClusterConfig(
982+
enable_gcs_ft=True,
983+
redis_address="redis://redis-service:6379",
984+
external_storage_namespace="storage-ns",
985+
)
986+
987+
# Build the spec using the method on the cluster config
988+
spec = cluster_config.build_ray_cluster_spec("test-cluster")
989+
990+
# Verify GCS fault tolerance options
991+
assert "gcsFaultToleranceOptions" in spec
992+
gcs_ft = spec["gcsFaultToleranceOptions"]
993+
assert gcs_ft["redisAddress"] == "redis://redis-service:6379"
994+
assert gcs_ft["externalStorageNamespace"] == "storage-ns"

0 commit comments

Comments
 (0)