|
6 | 6 | from typing import TYPE_CHECKING |
7 | 7 |
|
8 | 8 | from aws_durable_execution_sdk_python.lambda_service import ( |
| 9 | + OperationAction, |
9 | 10 | OperationType, |
10 | 11 | OperationUpdate, |
11 | 12 | ) |
@@ -83,6 +84,7 @@ def _validate_operation_update( |
83 | 84 | update: OperationUpdate, execution: Execution |
84 | 85 | ) -> None: |
85 | 86 | """Validate a single operation update.""" |
| 87 | + CheckpointValidator._validate_inconsistent_operation_metadata(update, execution) |
86 | 88 | CheckpointValidator._validate_payload_sizes(update) |
87 | 89 | ValidActionsByOperationTypeValidator.validate( |
88 | 90 | update.operation_type, update.action |
@@ -127,43 +129,112 @@ def _validate_operation_status_transition( |
127 | 129 |
|
128 | 130 | raise InvalidParameterValueException(msg) |
129 | 131 |
|
| 132 | + @staticmethod |
| 133 | + def _validate_inconsistent_operation_metadata( |
| 134 | + update: OperationUpdate, execution: Execution |
| 135 | + ) -> None: |
| 136 | + """Validate that operation metadata is consistent with existing operation.""" |
| 137 | + current_state = None |
| 138 | + for operation in execution.operations: |
| 139 | + if operation.operation_id == update.operation_id: |
| 140 | + current_state = operation |
| 141 | + break |
| 142 | + |
| 143 | + if current_state is not None: |
| 144 | + if ( |
| 145 | + update.operation_type is not None |
| 146 | + and update.operation_type != current_state.operation_type |
| 147 | + ): |
| 148 | + msg: str = "Inconsistent operation type." |
| 149 | + raise InvalidParameterValueException(msg) |
| 150 | + |
| 151 | + if ( |
| 152 | + update.sub_type is not None |
| 153 | + and update.sub_type != current_state.sub_type |
| 154 | + ): |
| 155 | + msg_subtype: str = "Inconsistent operation subtype." |
| 156 | + raise InvalidParameterValueException(msg_subtype) |
| 157 | + |
| 158 | + if update.name is not None and update.name != current_state.name: |
| 159 | + msg_name: str = "Inconsistent operation name." |
| 160 | + raise InvalidParameterValueException(msg_name) |
| 161 | + |
| 162 | + if ( |
| 163 | + update.parent_id is not None |
| 164 | + and update.parent_id != current_state.parent_id |
| 165 | + ): |
| 166 | + msg_parent: str = "Inconsistent parent operation id." |
| 167 | + raise InvalidParameterValueException(msg_parent) |
| 168 | + |
130 | 169 | @staticmethod |
131 | 170 | def _validate_parent_id_and_duplicate_id( |
132 | 171 | updates: list[OperationUpdate], execution: Execution |
133 | 172 | ) -> None: |
134 | | - """Validate parent IDs and check for duplicate operation IDs.""" |
135 | | - operations_seen: MutableMapping[str, OperationUpdate] = {} |
| 173 | + """Validate parent IDs and check for duplicate operation IDs. |
| 174 | +
|
| 175 | + Validate that any provided parentId is valid, and also validate no duplicate operation is being |
| 176 | + updated at the same time (unless it is a STEP/CONTEXT starting + performing one more non-START action). |
| 177 | + """ |
| 178 | + operations_started: MutableMapping[str, OperationUpdate] = {} |
| 179 | + last_updates_seen: MutableMapping[str, OperationUpdate] = {} |
136 | 180 |
|
137 | 181 | for update in updates: |
138 | | - if update.operation_id in operations_seen: |
139 | | - msg: str = "Cannot update the same operation twice in a single request." |
140 | | - raise InvalidParameterValueException(msg) |
| 182 | + if CheckpointValidator._is_invalid_duplicate_update( |
| 183 | + update, last_updates_seen |
| 184 | + ): |
| 185 | + msg_duplicate: str = ( |
| 186 | + "Cannot checkpoint multiple operations with the same ID." |
| 187 | + ) |
| 188 | + raise InvalidParameterValueException(msg_duplicate) |
141 | 189 |
|
142 | 190 | if not CheckpointValidator._is_valid_parent_for_update( |
143 | | - execution, update, operations_seen |
| 191 | + execution, update, operations_started |
144 | 192 | ): |
145 | | - msg_invalid_parent: str = "Invalid parent operation id." |
| 193 | + msg_parent: str = "Invalid parent operation id." |
| 194 | + raise InvalidParameterValueException(msg_parent) |
| 195 | + |
| 196 | + if update.action == OperationAction.START: |
| 197 | + operations_started[update.operation_id] = update |
| 198 | + |
| 199 | + last_updates_seen[update.operation_id] = update |
| 200 | + |
| 201 | + @staticmethod |
| 202 | + def _is_invalid_duplicate_update( |
| 203 | + update: OperationUpdate, last_updates_seen: MutableMapping[str, OperationUpdate] |
| 204 | + ) -> bool: |
| 205 | + """Check if this is an invalid duplicate update.""" |
| 206 | + last_update = last_updates_seen.get(update.operation_id) |
| 207 | + if last_update is None: |
| 208 | + return False |
146 | 209 |
|
147 | | - raise InvalidParameterValueException(msg_invalid_parent) |
| 210 | + if last_update.operation_type in (OperationType.STEP, OperationType.CONTEXT): |
| 211 | + # Allow duplicate for STEP/CONTEXT if last was START and current is not START |
| 212 | + allow_duplicate = ( |
| 213 | + last_update.action == OperationAction.START |
| 214 | + and update.action != OperationAction.START |
| 215 | + ) |
| 216 | + return not allow_duplicate |
148 | 217 |
|
149 | | - operations_seen[update.operation_id] = update |
| 218 | + return True |
150 | 219 |
|
151 | 220 | @staticmethod |
152 | 221 | def _is_valid_parent_for_update( |
153 | 222 | execution: Execution, |
154 | 223 | update: OperationUpdate, |
155 | | - operations_seen: MutableMapping[str, OperationUpdate], |
| 224 | + operations_started: MutableMapping[str, OperationUpdate], |
156 | 225 | ) -> bool: |
157 | 226 | """Check if the parent ID is valid for the update.""" |
158 | 227 | parent_id = update.parent_id |
159 | 228 |
|
160 | 229 | if parent_id is None: |
161 | 230 | return True |
162 | 231 |
|
163 | | - if parent_id in operations_seen: |
164 | | - parent_update = operations_seen[parent_id] |
| 232 | + # Check if parent is in operations started in this batch |
| 233 | + if parent_id in operations_started: |
| 234 | + parent_update = operations_started[parent_id] |
165 | 235 | return parent_update.operation_type == OperationType.CONTEXT |
166 | 236 |
|
| 237 | + # Check if parent exists in current execution state |
167 | 238 | for operation in execution.operations: |
168 | 239 | if operation.operation_id == parent_id: |
169 | 240 | return operation.operation_type == OperationType.CONTEXT |
|
0 commit comments