From 39bf831a19397bf1dc2ce4718a6543f626a19e12 Mon Sep 17 00:00:00 2001 From: dianqk Date: Tue, 23 Dec 2025 22:25:11 +0800 Subject: [PATCH] New MIR Pass: SsaRangePropagation --- compiler/rustc_middle/src/mir/statement.rs | 6 + compiler/rustc_mir_transform/src/lib.rs | 2 + .../rustc_mir_transform/src/ssa_range_prop.rs | 203 ++++++++++++++++++ ...a_range.on_assert.SsaRangePropagation.diff | 69 ++++++ .../ssa_range.on_if.SsaRangePropagation.diff | 63 ++++++ ...ssa_range.on_if_2.SsaRangePropagation.diff | 20 ++ ...sa_range.on_match.SsaRangePropagation.diff | 33 +++ ..._range.on_match_2.SsaRangePropagation.diff | 26 +++ tests/mir-opt/range/ssa_range.rs | 70 ++++++ 9 files changed, 492 insertions(+) create mode 100644 compiler/rustc_mir_transform/src/ssa_range_prop.rs create mode 100644 tests/mir-opt/range/ssa_range.on_assert.SsaRangePropagation.diff create mode 100644 tests/mir-opt/range/ssa_range.on_if.SsaRangePropagation.diff create mode 100644 tests/mir-opt/range/ssa_range.on_if_2.SsaRangePropagation.diff create mode 100644 tests/mir-opt/range/ssa_range.on_match.SsaRangePropagation.diff create mode 100644 tests/mir-opt/range/ssa_range.on_match_2.SsaRangePropagation.diff create mode 100644 tests/mir-opt/range/ssa_range.rs diff --git a/compiler/rustc_middle/src/mir/statement.rs b/compiler/rustc_middle/src/mir/statement.rs index 1ba1ae3e1531d..2ee1d53cabd53 100644 --- a/compiler/rustc_middle/src/mir/statement.rs +++ b/compiler/rustc_middle/src/mir/statement.rs @@ -374,6 +374,12 @@ impl<'tcx> Place<'tcx> { self.projection.iter().any(|elem| elem.is_indirect()) } + /// Returns `true` if the `Place` always refers to the same memory region + /// whatever the state of the program. + pub fn is_stable_offset(&self) -> bool { + self.projection.iter().all(|elem| elem.is_stable_offset()) + } + /// Returns `true` if this `Place`'s first projection is `Deref`. /// /// This is useful because for MIR phases `AnalysisPhase::PostCleanup` and later, diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 701d7ff854a75..55f8d1bfbded1 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -198,6 +198,7 @@ declare_passes! { mod single_use_consts : SingleUseConsts; mod sroa : ScalarReplacementOfAggregates; mod strip_debuginfo : StripDebugInfo; + mod ssa_range_prop: SsaRangePropagation; mod unreachable_enum_branching : UnreachableEnumBranching; mod unreachable_prop : UnreachablePropagation; mod validate : Validator; @@ -743,6 +744,7 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<' &simplify::SimplifyLocals::AfterGVN, &match_branches::MatchBranchSimplification, &dataflow_const_prop::DataflowConstProp, + &ssa_range_prop::SsaRangePropagation, &single_use_consts::SingleUseConsts, &o1(simplify_branches::SimplifyConstCondition::AfterConstProp), &jump_threading::JumpThreading, diff --git a/compiler/rustc_mir_transform/src/ssa_range_prop.rs b/compiler/rustc_mir_transform/src/ssa_range_prop.rs new file mode 100644 index 0000000000000..752dd2eecf101 --- /dev/null +++ b/compiler/rustc_mir_transform/src/ssa_range_prop.rs @@ -0,0 +1,203 @@ +use rustc_abi::WrappingRange; +use rustc_const_eval::interpret::Scalar; +use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::graph::dominators::Dominators; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::{BasicBlock, Body, Location, Operand, Place, TerminatorKind, *}; +use rustc_middle::ty::{TyCtxt, TypingEnv}; +use rustc_span::DUMMY_SP; + +use crate::ssa::SsaLocals; + +pub(super) struct SsaRangePropagation; + +impl<'tcx> crate::MirPass<'tcx> for SsaRangePropagation { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 1 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let typing_env = body.typing_env(tcx); + let ssa = SsaLocals::new(tcx, body, typing_env); + // Clone dominators because we need them while mutating the body. + let dominators = body.basic_blocks.dominators().clone(); + let mut range_set = + RangeSet::new(tcx, typing_env, body, &ssa, &body.local_decls, dominators); + + let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec(); + for bb in reverse_postorder { + let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb]; + range_set.visit_basic_block_data(bb, data); + } + } + + fn is_required(&self) -> bool { + false + } +} + +struct RangeSet<'tcx, 'body, 'a> { + tcx: TyCtxt<'tcx>, + typing_env: TypingEnv<'tcx>, + ssa: &'a SsaLocals, + local_decls: &'body LocalDecls<'tcx>, + dominators: Dominators, + /// Known ranges at each locations. + ranges: FxHashMap, Vec<(Location, WrappingRange)>>, + /// Determines if the basic block has a single unique predecessor. + unique_predecessors: DenseBitSet, +} + +impl<'tcx, 'body, 'a> RangeSet<'tcx, 'body, 'a> { + fn new( + tcx: TyCtxt<'tcx>, + typing_env: TypingEnv<'tcx>, + body: &Body<'tcx>, + ssa: &'a SsaLocals, + local_decls: &'body LocalDecls<'tcx>, + dominators: Dominators, + ) -> Self { + let predecessors = body.basic_blocks.predecessors(); + let mut unique_predecessors = DenseBitSet::new_empty(body.basic_blocks.len()); + for (bb, _) in body.basic_blocks.iter_enumerated() { + if predecessors[bb].len() == 1 { + unique_predecessors.insert(bb); + } + } + RangeSet { + tcx, + typing_env, + ssa, + local_decls, + dominators, + ranges: FxHashMap::default(), + unique_predecessors, + } + } + + /// Create a new known range at the location. + fn insert_range(&mut self, place: Place<'tcx>, location: Location, range: WrappingRange) { + self.ranges.entry(place).or_default().push((location, range)); + } + + /// Get the known range at the location. + fn get_range(&self, place: &Place<'tcx>, location: Location) -> Option { + let Some(ranges) = self.ranges.get(place) else { + return None; + }; + // FIXME: This should use the intersection of all valid ranges. + let (_, range) = + ranges.iter().find(|(range_loc, _)| range_loc.dominates(location, &self.dominators))?; + Some(*range) + } + + fn try_as_constant( + &mut self, + place: Place<'tcx>, + location: Location, + ) -> Option> { + if let Some(range) = self.get_range(&place, location) + && range.start == range.end + { + let ty = place.ty(self.local_decls, self.tcx).ty; + let layout = self.tcx.layout_of(self.typing_env.as_query_input(ty)).ok()?; + let value = ConstValue::Scalar(Scalar::from_uint(range.start, layout.size)); + let const_ = Const::Val(value, ty); + return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ }); + } + None + } + + fn simplify_operand( + &mut self, + operand: &mut Operand<'tcx>, + location: Location, + ) -> Result<(), Option>> { + let Some(place) = operand.place() else { + return Ok(()); + }; + let Some(const_) = self.try_as_constant(place, location) else { + if self.is_ssa(place) { + return Err(Some(place)); + } else { + return Err(None); + } + }; + *operand = Operand::Constant(Box::new(const_)); + Ok(()) + } + + fn is_ssa(&self, place: Place<'tcx>) -> bool { + self.ssa.is_ssa(place.local) && place.is_stable_offset() + } +} + +impl<'tcx> MutVisitor<'tcx> for RangeSet<'tcx, '_, '_> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) { + let _ = self.simplify_operand(operand, location); + } + + fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) { + match &mut terminator.kind { + TerminatorKind::Assert { cond, expected, target, .. } => { + if let Err(Some(place)) = self.simplify_operand(cond, location) { + let successor = Location { block: *target, statement_index: 0 }; + if location.block != successor.block + && self.unique_predecessors.contains(successor.block) + { + let val = *expected as u128; + let range = WrappingRange { start: val, end: val }; + self.insert_range(place, successor, range); + } + } + } + TerminatorKind::SwitchInt { discr, targets } => { + if let Err(Some(place)) = self.simplify_operand(discr, location) + && targets.all_targets().len() < 8 + { + let mut distinct_targets: FxHashMap = FxHashMap::default(); + for (_, target) in targets.iter() { + let targets = distinct_targets.entry(target).or_default(); + if *targets == 0 { + *targets = 1; + } else { + *targets = 2; + } + } + for (val, target) in targets.iter() { + if distinct_targets[&target] != 1 { + continue; + } + let successor = Location { block: target, statement_index: 0 }; + if location.block != successor.block + && self.unique_predecessors.contains(successor.block) + { + let range = WrappingRange { start: val, end: val }; + self.insert_range(place, successor, range); + } + } + + let otherwise = Location { block: targets.otherwise(), statement_index: 0 }; + if place.ty(self.local_decls, self.tcx).ty.is_bool() + && let [val] = targets.all_values() + && location.block != otherwise.block + && self.unique_predecessors.contains(otherwise.block) + { + let range = if val.get() == 0 { + WrappingRange { start: 1, end: 1 } + } else { + WrappingRange { start: 0, end: 0 } + }; + self.insert_range(place, otherwise, range); + } + } + } + _ => {} + } + } +} diff --git a/tests/mir-opt/range/ssa_range.on_assert.SsaRangePropagation.diff b/tests/mir-opt/range/ssa_range.on_assert.SsaRangePropagation.diff new file mode 100644 index 0000000000000..ae3f49a8847b0 --- /dev/null +++ b/tests/mir-opt/range/ssa_range.on_assert.SsaRangePropagation.diff @@ -0,0 +1,69 @@ +- // MIR for `on_assert` before SsaRangePropagation ++ // MIR for `on_assert` after SsaRangePropagation + + fn on_assert(_1: usize, _2: &[u8]) -> u8 { + debug i => _1; + debug v => _2; + let mut _0: u8; + let _3: (); + let mut _4: bool; + let mut _5: usize; + let mut _6: usize; + let mut _7: &[u8]; + let mut _8: !; + let _9: usize; + let mut _10: usize; + let mut _11: bool; + scope 1 (inlined core::slice::::len) { + scope 2 (inlined std::ptr::metadata::<[u8]>) { + } + } + + bb0: { + StorageLive(_3); + nop; + StorageLive(_5); + _5 = copy _1; + nop; + StorageLive(_7); + _7 = &(*_2); + _6 = PtrMetadata(copy _2); + StorageDead(_7); + _4 = Lt(copy _1, copy _6); + switchInt(copy _4) -> [0: bb2, otherwise: bb1]; + } + + bb1: { + nop; + StorageDead(_5); + _3 = const (); + nop; + StorageDead(_3); + StorageLive(_9); + _9 = copy _1; + _10 = copy _6; +- _11 = copy _4; +- assert(copy _4, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable]; ++ _11 = const true; ++ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable]; + } + + bb2: { + nop; + StorageDead(_5); + StorageLive(_8); + _8 = panic(const "assertion failed: i < v.len()") -> unwind unreachable; + } + + bb3: { + _0 = copy (*_2)[_1]; + StorageDead(_9); + return; + } + } + + ALLOC0 (size: 29, align: 1) { + 0x00 │ 61 73 73 65 72 74 69 6f 6e 20 66 61 69 6c 65 64 │ assertion failed + 0x10 │ 3a 20 69 20 3c 20 76 2e 6c 65 6e 28 29 │ : i < v.len() + } + diff --git a/tests/mir-opt/range/ssa_range.on_if.SsaRangePropagation.diff b/tests/mir-opt/range/ssa_range.on_if.SsaRangePropagation.diff new file mode 100644 index 0000000000000..2493e069edd4c --- /dev/null +++ b/tests/mir-opt/range/ssa_range.on_if.SsaRangePropagation.diff @@ -0,0 +1,63 @@ +- // MIR for `on_if` before SsaRangePropagation ++ // MIR for `on_if` after SsaRangePropagation + + fn on_if(_1: usize, _2: &[u8]) -> u8 { + debug i => _1; + debug v => _2; + let mut _0: u8; + let mut _3: bool; + let mut _4: usize; + let mut _5: usize; + let mut _6: &[u8]; + let _7: usize; + let mut _8: usize; + let mut _9: bool; + scope 1 (inlined core::slice::::len) { + scope 2 (inlined std::ptr::metadata::<[u8]>) { + } + } + + bb0: { + nop; + StorageLive(_4); + _4 = copy _1; + nop; + StorageLive(_6); + _6 = &(*_2); + _5 = PtrMetadata(copy _2); + StorageDead(_6); + _3 = Lt(copy _1, copy _5); + switchInt(copy _3) -> [0: bb3, otherwise: bb1]; + } + + bb1: { + nop; + StorageDead(_4); + StorageLive(_7); + _7 = copy _1; + _8 = copy _5; +- _9 = copy _3; +- assert(copy _3, "index out of bounds: the length is {} but the index is {}", copy _5, copy _1) -> [success: bb2, unwind unreachable]; ++ _9 = const true; ++ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _5, copy _1) -> [success: bb2, unwind unreachable]; + } + + bb2: { + _0 = copy (*_2)[_1]; + StorageDead(_7); + goto -> bb4; + } + + bb3: { + nop; + StorageDead(_4); + _0 = const 0_u8; + goto -> bb4; + } + + bb4: { + nop; + return; + } + } + diff --git a/tests/mir-opt/range/ssa_range.on_if_2.SsaRangePropagation.diff b/tests/mir-opt/range/ssa_range.on_if_2.SsaRangePropagation.diff new file mode 100644 index 0000000000000..8a957238c8453 --- /dev/null +++ b/tests/mir-opt/range/ssa_range.on_if_2.SsaRangePropagation.diff @@ -0,0 +1,20 @@ +- // MIR for `on_if_2` before SsaRangePropagation ++ // MIR for `on_if_2` after SsaRangePropagation + + fn on_if_2(_1: bool) -> bool { + let mut _0: bool; + + bb0: { + switchInt(copy _1) -> [1: bb2, otherwise: bb1]; + } + + bb1: { + goto -> bb2; + } + + bb2: { + _0 = copy _1; + return; + } + } + diff --git a/tests/mir-opt/range/ssa_range.on_match.SsaRangePropagation.diff b/tests/mir-opt/range/ssa_range.on_match.SsaRangePropagation.diff new file mode 100644 index 0000000000000..f91ac7090dde3 --- /dev/null +++ b/tests/mir-opt/range/ssa_range.on_match.SsaRangePropagation.diff @@ -0,0 +1,33 @@ +- // MIR for `on_match` before SsaRangePropagation ++ // MIR for `on_match` after SsaRangePropagation + + fn on_match(_1: u8) -> u8 { + debug i => _1; + let mut _0: u8; + + bb0: { + switchInt(copy _1) -> [1: bb3, 2: bb2, otherwise: bb1]; + } + + bb1: { + _0 = const 0_u8; + goto -> bb4; + } + + bb2: { +- _0 = copy _1; ++ _0 = const 2_u8; + goto -> bb4; + } + + bb3: { +- _0 = copy _1; ++ _0 = const 1_u8; + goto -> bb4; + } + + bb4: { + return; + } + } + diff --git a/tests/mir-opt/range/ssa_range.on_match_2.SsaRangePropagation.diff b/tests/mir-opt/range/ssa_range.on_match_2.SsaRangePropagation.diff new file mode 100644 index 0000000000000..53433d9fe4ba1 --- /dev/null +++ b/tests/mir-opt/range/ssa_range.on_match_2.SsaRangePropagation.diff @@ -0,0 +1,26 @@ +- // MIR for `on_match_2` before SsaRangePropagation ++ // MIR for `on_match_2` after SsaRangePropagation + + fn on_match_2(_1: u8) -> u8 { + debug i => _1; + let mut _0: u8; + + bb0: { + switchInt(copy _1) -> [1: bb2, 2: bb2, otherwise: bb1]; + } + + bb1: { + _0 = const 0_u8; + goto -> bb3; + } + + bb2: { + _0 = copy _1; + goto -> bb3; + } + + bb3: { + return; + } + } + diff --git a/tests/mir-opt/range/ssa_range.rs b/tests/mir-opt/range/ssa_range.rs new file mode 100644 index 0000000000000..964d9b97cf92d --- /dev/null +++ b/tests/mir-opt/range/ssa_range.rs @@ -0,0 +1,70 @@ +//@ test-mir-pass: SsaRangePropagation +//@ compile-flags: -Zmir-enable-passes=+GVN,+Inline --crate-type=lib -Cpanic=abort + +#![feature(custom_mir, core_intrinsics)] + +use std::intrinsics::mir::*; + +// EMIT_MIR ssa_range.on_if.SsaRangePropagation.diff +pub fn on_if(i: usize, v: &[u8]) -> u8 { + // CHECK-LABEL: fn on_if( + // CHECK: assert(const true + if i < v.len() { v[i] } else { 0 } +} + +// EMIT_MIR ssa_range.on_assert.SsaRangePropagation.diff +pub fn on_assert(i: usize, v: &[u8]) -> u8 { + // CHECK-LABEL: fn on_assert( + // CHECK: assert(const true + assert!(i < v.len()); + v[i] +} + +// EMIT_MIR ssa_range.on_match.SsaRangePropagation.diff +pub fn on_match(i: u8) -> u8 { + // CHECK-LABEL: fn on_match( + // CHECK: switchInt(copy _1) -> [1: [[BB_V1:bb.*]], 2: [[BB_V2:bb.*]], + // CHECK: [[BB_V2]]: { + // CHECK-NEXT: _0 = const 2_u8; + // CHECK: [[BB_V1]]: { + // CHECK-NEXT: _0 = const 1_u8; + match i { + 1 => i, + 2 => i, + _ => 0, + } +} + +// EMIT_MIR ssa_range.on_match_2.SsaRangePropagation.diff +pub fn on_match_2(i: u8) -> u8 { + // CHECK-LABEL: fn on_match_2( + // CHECK: switchInt(copy _1) -> [1: [[BB:bb.*]], 2: [[BB]], + // CHECK: [[BB]]: { + // CHECK-NEXT: _0 = copy _1; + match i { + 1 | 2 => i, + _ => 0, + } +} + +// EMIT_MIR ssa_range.on_if_2.SsaRangePropagation.diff +#[custom_mir(dialect = "runtime", phase = "post-cleanup")] +pub fn on_if_2(a: bool) -> bool { + // CHECK-LABEL: fn on_if_2( + // CHECK: _0 = copy _1; + mir! { + { + match a { + true => bb2, + _ => bb1 + } + } + bb1 = { + Goto(bb2) + } + bb2 = { + RET = a; + Return() + } + } +}