@@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
 
     const int *cache_indices = params.cache_indices_ptr == nullptr ? nullptr
                                 : reinterpret_cast<int *>(params.cache_indices_ptr);
-    const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
+    const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
     // cache_index == params.pad_slot_id is defined as padding, so we exit early
     if (cache_index == params.pad_slot_id) {
         return;
@@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
     weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
     input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
-    typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
-        cache_index * params.ssm_states_batch_stride +
-        dim_id * kNRows * params.ssm_states_dim_stride;
+
+    typename Ktraits::state_t *ssm_states;
+    if (params.cache_enabled) {
+        // APC mode: ssm_states points to the base, we'll use absolute cache slots later
+        ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
+            dim_id * kNRows * params.ssm_states_dim_stride;
+    } else {
+        // Non-APC mode: offset by cache_index as before
+        ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
+            cache_index * params.ssm_states_batch_stride +
+            dim_id * kNRows * params.ssm_states_dim_stride;
+    }
 
     float D_val[kNRows] = {0};
     if (params.D_ptr != nullptr) {
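A minimal host-side sketch of the two addressing schemes this branch selects between (plain C++ with made-up stride values, not the kernel's real types): the non-APC path folds cache_index into the base pointer once, while the APC path keeps only the dim offset in the pointer and adds an absolute cache slot at each load or store.

    #include <cstddef>
    #include <cstdio>

    int main() {
        // Illustrative strides and indices only; real values come from SSMParamsBase.
        const size_t batch_stride = 4096, dim_stride = 16, dstate_stride = 1;
        const size_t cache_index = 3, dim_id = 2, kNRows = 1, state_idx = 5;

        // Non-APC: the slot (cache_index) is baked into the base pointer up front.
        size_t non_apc_elem = cache_index * batch_stride
                            + dim_id * kNRows * dim_stride
                            + state_idx * dstate_stride;

        // APC: the base keeps only the dim offset; a cache slot is applied per access.
        size_t apc_base = dim_id * kNRows * dim_stride;
        size_t apc_elem = apc_base + cache_index * batch_stride + state_idx * dstate_stride;

        // When the slot chosen later equals cache_index, both reach the same element.
        std::printf("%zu == %zu\n", non_apc_elem, apc_elem);
        return 0;
    }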
@@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     // }
 
     constexpr int kChunkSize = kNThreads * kNItems;
-    const int n_chunks = (seqlen + 2048 - 1) / 2048;
+
+    // Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
+    const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
+    const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
+
+    const int *batch_cache_indices = cache_indices != nullptr ?
+        cache_indices + batch_id * params.cache_indices_stride : nullptr;
+    const int *block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ?
+        reinterpret_cast<const int *>(params.block_idx_first_scheduled_token_ptr) : nullptr;
+    const int *block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ?
+        reinterpret_cast<const int *>(params.block_idx_last_scheduled_token_ptr) : nullptr;
+    const int *initial_state_idx = params.initial_state_idx_ptr != nullptr ?
+        reinterpret_cast<const int *>(params.initial_state_idx_ptr) : nullptr;
+
+    const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
+
     for (int chunk = 0; chunk < n_chunks; ++chunk) {
         input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
 
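A quick check of the ceil-division above (standalone sketch, the example lengths are assumptions): with APC enabled each loop iteration covers one cache block of tokens, so the chunk count is the number of blocks the sequence spans.

    #include <cassert>

    // Number of chunks needed to cover seqlen at a given chunk width.
    constexpr int ceil_div(int seqlen, int chunk) { return (seqlen + chunk - 1) / chunk; }

    int main() {
        assert(ceil_div(2500, 1024) == 3);  // APC with 1024-token blocks: last chunk is partial
        assert(ceil_div(2500, 2048) == 2);  // legacy 2048-wide chunking
        return 0;
    }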
@@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
             if constexpr (kIsVariableC) {
                 auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                 load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
-                    smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
+                    smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
                 if constexpr (!kIsVariableB) {
                     #pragma unroll
                     for (int r = 0; r < kNRows; ++r) {
@@ -242,16 +266,31 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                 for (int i = 0; i < kNItems; ++i) {
                     thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                                                  !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
-
                     if (seqlen % (kNItems * kNThreads) != 0) {  // So that the last state is correct
                         if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
                             thread_data[i] = make_float2(1.f, 0.f);
                         }
                     }
                 }
                 // Initialize running total
-
-                scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]) : 0.0);
+                scan_t running_prefix;
+                if (chunk > 0) {
+                    running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE];
+                } else {
+                    // Load initial state
+                    if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) {
+                        size_t state_offset = load_cache_slot * params.ssm_states_batch_stride +
+                                              r * params.ssm_states_dim_stride +
+                                              state_idx * params.ssm_states_dstate_stride;
+                        running_prefix = make_float2(1.0, float(ssm_states[state_offset]));
+                    } else if (has_initial_state) {
+                        // Non-APC mode: load from current batch position
+                        running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride]));
+                    } else {
+                        // No initial state
+                        running_prefix = make_float2(1.0, 0.0);
+                    }
+                }
 
                 SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                 typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
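A small sketch of the initial-state lookup the APC branch performs above (plain C++, hypothetical data; names mirror the kernel's arrays): initial_state_idx names the logical block of this request that holds the prior state, the request's row of cache_indices maps it to a physical slot, and the usual strides locate the element read into running_prefix.

    #include <cstddef>
    #include <vector>

    // Flat index into ssm_states that chunk 0 loads its prefix from in APC mode.
    size_t initial_state_elem(const std::vector<int> &batch_cache_indices,  // one request's row of cache_indices
                              int initial_state_block,                      // initial_state_idx[batch_id]
                              size_t r, size_t state_idx,
                              size_t batch_stride, size_t dim_stride, size_t dstate_stride) {
        const size_t load_cache_slot = batch_cache_indices[initial_state_block];
        return load_cache_slot * batch_stride + r * dim_stride + state_idx * dstate_stride;
    }

    int main() {
        std::vector<int> row = {7, 12, 3};  // hypothetical physical slots for one request
        // Prior state lives in logical block 1 -> physical slot 12.
        return initial_state_elem(row, 1, 0, 5, 4096, 16, 1) == 12 * 4096 + 5 ? 0 : 1;
    }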
@@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                 // There's a syncthreads in the scan op, so we don't need to sync here.
                 // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
                 if (threadIdx.x == 0) {
-                    smem_running_prefix[state_idx] = prefix_op.running_prefix;
-                    if (chunk == n_chunks - 1) {
+                    smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
+
+                    // Store state at the end of each chunk when cache is enabled
+                    if (params.cache_enabled && batch_cache_indices != nullptr) {
+
+                        size_t cache_slot;
+                        if (chunk == n_chunks - 1) {
+                            cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
+                        } else {
+                            cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
+                        }
+
+                        size_t state_offset = cache_slot * params.ssm_states_batch_stride +
+                                              r * params.ssm_states_dim_stride +
+                                              state_idx * params.ssm_states_dstate_stride;
+
+                        ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y);
+                    } else if (!params.cache_enabled && chunk == n_chunks - 1) {
+                        // Non-APC mode: store only final state at current batch position
                         ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
                     }
                 }
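A sketch of the slot selection used above (standalone C++; the array contents are made up, the names mirror the kernel's): every intermediate chunk writes its state to the cache block it just filled, counted from block_idx_first_scheduled_token, while the final, possibly partial, chunk writes to block_idx_last_scheduled_token.

    #include <cstdio>
    #include <vector>

    // Physical ssm_states slot a chunk's state is written to in APC mode.
    int pick_cache_slot(const std::vector<int> &batch_cache_indices,
                        int block_idx_first, int block_idx_last,
                        int chunk, int n_chunks) {
        const int block_idx = (chunk == n_chunks - 1) ? block_idx_last
                                                      : block_idx_first + chunk;
        return batch_cache_indices[block_idx];
    }

    int main() {
        std::vector<int> row = {7, 12, 3, 9};  // hypothetical slots for one request
        // Three chunks, first scheduled block 0, last scheduled block 2.
        for (int chunk = 0; chunk < 3; ++chunk)
            std::printf("chunk %d -> slot %d\n", chunk, pick_cache_slot(row, 0, 2, chunk, 3));
        return 0;
    }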
@@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                 }
             }
         }
-
         input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
             + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
         __syncthreads();
@@ -346,7 +401,9 @@ template<typename input_t, typename weight_t, typename state_t>
 void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
 
     #ifndef USE_ROCM
-        if (params.seqlen <= 128) {
+        if (params.cache_enabled && params.block_size == 1024) {
+            selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
+        } else if (params.seqlen <= 128) {
             selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
         } else if (params.seqlen <= 256) {
             selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
@@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
         selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
     }
     #else
-        if (params.seqlen <= 256) {
+        if (params.cache_enabled && params.block_size == 1024) {
+            selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
+        } else if (params.seqlen <= 256) {
             selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
         } else if (params.seqlen <= 512) {
             selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
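The 64-thread, 16-item launch pinned above is what makes the per-iteration chunk width match the cache block: 64 × 16 = 1024 tokens per chunk, so each chunk boundary in the scan loop coincides with a block boundary. A minimal compile-time statement of that assumption (a sketch, not code from this change):

    // kNThreads x kNItems must equal the cache block size for the per-chunk
    // state writes to land exactly on block boundaries.
    constexpr int kNThreads = 64, kNItems = 16, kBlockSize = 1024;
    static_assert(kNThreads * kNItems == kBlockSize,
                  "APC dispatch assumes the chunk width matches the cache block size");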
@@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase &params,
                         const std::optional<at::Tensor>& D,
                         const std::optional<at::Tensor>& delta_bias,
                         const torch::Tensor ssm_states,
-                        bool has_z,
+                        bool has_z,
                         bool delta_softplus,
                         const std::optional<at::Tensor>& query_start_loc,
                         const std::optional<at::Tensor>& cache_indices,
                         const std::optional<at::Tensor>& has_initial_state,
                         bool varlen,
-                        int64_t pad_slot_id) {
+                        int64_t pad_slot_id,
+                        int64_t block_size,
+                        const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
+                        const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
+                        const std::optional<torch::Tensor> &initial_state_idx) {
 
     // Reset the parameters
     memset(&params, 0, sizeof(params));
@@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase &params,
     params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
     params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
 
+    // Set cache parameters - cache is enabled if we have direct cache writing params
+    params.cache_enabled = block_idx_first_scheduled_token.has_value();
+    params.block_size = static_cast<int>(block_size);
+
+    // Set direct cache writing pointers
+    params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
+    params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
+    params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
 
     // All stride are in elements, not bytes.
     params.A_d_stride = A.stride(0);
@@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase &params,
         params.out_d_stride = out.stride(0);
 
         params.ssm_states_batch_stride = ssm_states.stride(0);
-        params.ssm_states_dim_stride = ssm_states.stride(1);
+        params.ssm_states_dim_stride = ssm_states.stride(1);
         params.ssm_states_dstate_stride = ssm_states.stride(2);
 
+        params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
+
     }
     else {
         if (!is_variable_B) {
@@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
         params.out_d_stride = out.stride(1);
 
         params.ssm_states_batch_stride = ssm_states.stride(0);
-        params.ssm_states_dim_stride = ssm_states.stride(1);
+        params.ssm_states_dim_stride = ssm_states.stride(1);
         params.ssm_states_dstate_stride = ssm_states.stride(2);
+
+        params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
     }
 }
 
@@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                         const torch::Tensor &ssm_states,
                         // used to identify padding entries if cache_indices provided
                         // in case of padding, the kernel will return early
-                        int64_t pad_slot_id) {
+                        int64_t pad_slot_id,
+                        int64_t block_size,
+                        const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
+                        const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
+                        const std::optional<torch::Tensor> &initial_state_idx) {
     auto input_type = u.scalar_type();
     auto weight_type = A.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
         auto cache_indices_ = cache_indices.value();
         TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
         TORCH_CHECK(cache_indices_.is_cuda());
-        CHECK_SHAPE(cache_indices_, batch_size);
+
+        // cache_indices can be either 1D (batch_size,) for non-APC mode
+        // or 2D (batch_size, max_positions) for APC mode
+        const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
+        if (is_apc_mode) {
+            TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
+            TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
+        } else {
+            CHECK_SHAPE(cache_indices_, batch_size);
+        }
     }
 
 
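A plain-C++ restatement of the shape contract enforced above (a sketch with a stand-in struct, not the ATen checks): non-APC callers pass one slot per request, APC callers pass one row of candidate slots per request.

    #include <cassert>
    #include <cstdint>

    struct IndicesShape { int64_t dim; int64_t size0; };  // stand-in for tensor metadata

    void check_cache_indices(const IndicesShape &idx, int64_t batch_size, bool is_apc_mode) {
        if (is_apc_mode) {
            assert(idx.dim == 2 && idx.size0 == batch_size);  // (batch_size, max_positions)
        } else {
            assert(idx.dim == 1 && idx.size0 == batch_size);  // (batch_size,)
        }
    }

    int main() {
        check_cache_indices({2, 8}, 8, true);   // APC: 2D, rows match batch
        check_cache_indices({1, 8}, 8, false);  // non-APC: flat (batch_size,)
        return 0;
    }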
@@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                        cache_indices,
                        has_initial_state,
                        varlen,
-                       pad_slot_id
+                       pad_slot_id,
+                       block_size,
+                       block_idx_first_scheduled_token,
+                       block_idx_last_scheduled_token,
+                       initial_state_idx
                        );
 
 