@@ -129,18 +129,29 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
129129
130130 let mut patch = MirPatch :: new ( body) ;
131131
132- // create temp to store second discriminant in, `_s` in example above
133- let second_discriminant_temp =
134- patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
132+ let ( second_discriminant_temp, second_operand) = if opt_data. need_hoist_discriminant {
133+ // create temp to store second discriminant in, `_s` in example above
134+ let second_discriminant_temp =
135+ patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
135136
136- patch. add_statement ( parent_end, StatementKind :: StorageLive ( second_discriminant_temp) ) ;
137+ patch. add_statement (
138+ parent_end,
139+ StatementKind :: StorageLive ( second_discriminant_temp) ,
140+ ) ;
137141
138- // create assignment of discriminant
139- patch. add_assign (
140- parent_end,
141- Place :: from ( second_discriminant_temp) ,
142- Rvalue :: Discriminant ( opt_data. child_place ) ,
143- ) ;
142+ // create assignment of discriminant
143+ patch. add_assign (
144+ parent_end,
145+ Place :: from ( second_discriminant_temp) ,
146+ Rvalue :: Discriminant ( opt_data. child_place ) ,
147+ ) ;
148+ (
149+ Some ( second_discriminant_temp) ,
150+ Operand :: Move ( Place :: from ( second_discriminant_temp) ) ,
151+ )
152+ } else {
153+ ( None , Operand :: Copy ( opt_data. child_place ) )
154+ } ;
144155
145156 // create temp to store inequality comparison between the two discriminants, `_t` in
146157 // example above
@@ -149,11 +160,9 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
149160 let comp_temp = patch. new_temp ( comp_res_type, opt_data. child_source . span ) ;
150161 patch. add_statement ( parent_end, StatementKind :: StorageLive ( comp_temp) ) ;
151162
152- // create inequality comparison between the two discriminants
153- let comp_rvalue = Rvalue :: BinaryOp (
154- nequal,
155- Box :: new ( ( parent_op. clone ( ) , Operand :: Move ( Place :: from ( second_discriminant_temp) ) ) ) ,
156- ) ;
163+ // create inequality comparison
164+ let comp_rvalue =
165+ Rvalue :: BinaryOp ( nequal, Box :: new ( ( parent_op. clone ( ) , second_operand) ) ) ;
157166 patch. add_statement (
158167 parent_end,
159168 StatementKind :: Assign ( Box :: new ( ( Place :: from ( comp_temp) , comp_rvalue) ) ) ,
@@ -189,8 +198,13 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
189198 TerminatorKind :: if_ ( Operand :: Move ( Place :: from ( comp_temp) ) , true_case, false_case) ,
190199 ) ;
191200
192- // generate StorageDead for the second_discriminant_temp not in use anymore
193- patch. add_statement ( parent_end, StatementKind :: StorageDead ( second_discriminant_temp) ) ;
201+ if let Some ( second_discriminant_temp) = second_discriminant_temp {
202+ // generate StorageDead for the second_discriminant_temp not in use anymore
203+ patch. add_statement (
204+ parent_end,
205+ StatementKind :: StorageDead ( second_discriminant_temp) ,
206+ ) ;
207+ }
194208
195209 // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
196210 // the switch
@@ -218,6 +232,7 @@ struct OptimizationData<'tcx> {
218232 child_place : Place < ' tcx > ,
219233 child_ty : Ty < ' tcx > ,
220234 child_source : SourceInfo ,
235+ need_hoist_discriminant : bool ,
221236}
222237
223238fn evaluate_candidate < ' tcx > (
@@ -231,70 +246,128 @@ fn evaluate_candidate<'tcx>(
231246 return None ;
232247 } ;
233248 let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
234- if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
235- // Someone could write code like this:
236- // ```rust
237- // let Q = val;
238- // if discriminant(P) == otherwise {
239- // let ptr = &mut Q as *mut _ as *mut u8;
240- // // It may be difficult for us to effectively determine whether values are valid.
241- // // Invalid values can come from all sorts of corners.
242- // unsafe { *ptr = 10; }
243- // }
244- //
245- // match P {
246- // A => match Q {
247- // A => {
248- // // code
249- // }
250- // _ => {
251- // // don't use Q
252- // }
253- // }
254- // _ => {
255- // // don't use Q
256- // }
257- // };
258- // ```
259- //
260- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant
261- // of an invalid value, which is UB.
262- // In order to fix this, **we would either need to show that the discriminant computation of
263- // `place` is computed in all branches**.
264- // FIXME(#95162) For the moment, we adopt a conservative approach and
265- // consider only the `otherwise` branch has no statements and an unreachable terminator.
266- return None ;
267- }
268249 let ( _, child) = targets. iter ( ) . next ( ) ?;
269- let child_terminator = & bbs[ child] . terminator ( ) ;
270- let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
271- & child_terminator. kind
250+
251+ let Terminator {
252+ kind : TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } ,
253+ source_info,
254+ } = bbs[ child] . terminator ( )
272255 else {
273256 return None ;
274257 } ;
275258 let child_ty = child_discr. ty ( body. local_decls ( ) , tcx) ;
276259 if child_ty != parent_ty {
277260 return None ;
278261 }
279- let Some ( StatementKind :: Assign ( boxed) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x. kind ) else {
262+
263+ // We only handle:
264+ // ```
265+ // bb4: {
266+ // _8 = discriminant((_3.1: Enum1));
267+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
268+ // }
269+ // ```
270+ // and
271+ // ```
272+ // bb2: {
273+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
274+ // }
275+ // ```
276+ if bbs[ child] . statements . len ( ) > 1 {
280277 return None ;
278+ }
279+
280+ // When thie BB has exactly one statement, this statement should be discriminant.
281+ let need_hoist_discriminant = bbs[ child] . statements . len ( ) == 1 ;
282+ let child_place = if need_hoist_discriminant {
283+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
284+ // Someone could write code like this:
285+ // ```rust
286+ // let Q = val;
287+ // if discriminant(P) == otherwise {
288+ // let ptr = &mut Q as *mut _ as *mut u8;
289+ // // It may be difficult for us to effectively determine whether values are valid.
290+ // // Invalid values can come from all sorts of corners.
291+ // unsafe { *ptr = 10; }
292+ // }
293+ //
294+ // match P {
295+ // A => match Q {
296+ // A => {
297+ // // code
298+ // }
299+ // _ => {
300+ // // don't use Q
301+ // }
302+ // }
303+ // _ => {
304+ // // don't use Q
305+ // }
306+ // };
307+ // ```
308+ //
309+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
310+ // invalid value, which is UB.
311+ // In order to fix this, **we would either need to show that the discriminant computation of
312+ // `place` is computed in all branches**.
313+ // FIXME(#95162) For the moment, we adopt a conservative approach and
314+ // consider only the `otherwise` branch has no statements and an unreachable terminator.
315+ return None ;
316+ }
317+ // Handle:
318+ // ```
319+ // bb4: {
320+ // _8 = discriminant((_3.1: Enum1));
321+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
322+ // }
323+ // ```
324+ let [
325+ Statement {
326+ kind : StatementKind :: Assign ( box ( _, Rvalue :: Discriminant ( child_place) ) ) ,
327+ ..
328+ } ,
329+ ] = bbs[ child] . statements . as_slice ( )
330+ else {
331+ return None ;
332+ } ;
333+ * child_place
334+ } else {
335+ // Handle:
336+ // ```
337+ // bb2: {
338+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
339+ // }
340+ // ```
341+ let Operand :: Copy ( child_place) = child_discr else {
342+ return None ;
343+ } ;
344+ * child_place
281345 } ;
282- let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
283- return None ;
346+ let destination = if need_hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( )
347+ {
348+ child_targets. otherwise ( )
349+ } else {
350+ targets. otherwise ( )
284351 } ;
285- let destination = child_targets. otherwise ( ) ;
286352
287353 // Verify that the optimization is legal for each branch
288354 for ( value, child) in targets. iter ( ) {
289- if !verify_candidate_branch ( & bbs[ child] , value, * child_place, destination) {
355+ if !verify_candidate_branch (
356+ & bbs[ child] ,
357+ value,
358+ child_place,
359+ destination,
360+ need_hoist_discriminant,
361+ ) {
290362 return None ;
291363 }
292364 }
293365 Some ( OptimizationData {
294366 destination,
295- child_place : * child_place ,
367+ child_place,
296368 child_ty,
297- child_source : child_terminator. source_info ,
369+ child_source : * source_info,
370+ need_hoist_discriminant,
298371 } )
299372}
300373
@@ -303,31 +376,48 @@ fn verify_candidate_branch<'tcx>(
303376 value : u128 ,
304377 place : Place < ' tcx > ,
305378 destination : BasicBlock ,
379+ need_hoist_discriminant : bool ,
306380) -> bool {
307- // In order for the optimization to be correct, the branch must...
308- // ...have exactly one statement
309- if let [ statement] = branch. statements . as_slice ( )
310- // ...assign the discriminant of `place` in that statement
311- && let StatementKind :: Assign ( boxed) = & statement. kind
312- && let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed
313- && * from_place == place
314- // ...make that assignment to a local
315- && discr_place. projection . is_empty ( )
316- // ...terminate on a `SwitchInt` that invalidates that local
317- && let TerminatorKind :: SwitchInt { discr : switch_op, targets, .. } =
318- & branch. terminator ( ) . kind
319- && * switch_op == Operand :: Move ( * discr_place)
320- // ...fall through to `destination` if the switch misses
321- && destination == targets. otherwise ( )
322- // ...have a branch for value `value`
323- && let mut iter = targets. iter ( )
324- && let Some ( ( target_value, _) ) = iter. next ( )
325- && target_value == value
326- // ...and have no more branches
327- && iter. next ( ) . is_none ( )
328- {
329- true
381+ // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
382+ let TerminatorKind :: SwitchInt { discr : switch_op, targets } = & branch. terminator ( ) . kind else {
383+ return false ;
384+ } ;
385+ if need_hoist_discriminant {
386+ // If we need hoist discriminant, the branch must have exactly one statement.
387+ let [ statement] = branch. statements . as_slice ( ) else {
388+ return false ;
389+ } ;
390+ // The statement must assign the discriminant of `place`.
391+ let StatementKind :: Assign ( box ( discr_place, Rvalue :: Discriminant ( from_place) ) ) =
392+ statement. kind
393+ else {
394+ return false ;
395+ } ;
396+ if from_place != place {
397+ return false ;
398+ }
399+ // The assignment must invalidate a local that terminate on a `SwitchInt`.
400+ if !discr_place. projection . is_empty ( ) || * switch_op != Operand :: Move ( discr_place) {
401+ return false ;
402+ }
330403 } else {
331- false
404+ // If we don't need hoist discriminant, the branch must not have any statements.
405+ if !branch. statements . is_empty ( ) {
406+ return false ;
407+ }
408+ // The place on `SwitchInt` must be the same.
409+ if * switch_op != Operand :: Copy ( place) {
410+ return false ;
411+ }
332412 }
413+ // It must fall through to `destination` if the switch misses.
414+ if destination != targets. otherwise ( ) {
415+ return false ;
416+ }
417+ // It must have exactly one branch for value `value` and have no more branches.
418+ let mut iter = targets. iter ( ) ;
419+ let ( Some ( ( target_value, _) ) , None ) = ( iter. next ( ) , iter. next ( ) ) else {
420+ return false ;
421+ } ;
422+ target_value == value
333423}
0 commit comments