Commit 00b31a3

[V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
1 parent 73444b7 commit 00b31a3

16 files changed: +442 −153 lines
csrc/mamba/mamba_ssm/selective_scan.h

Lines changed: 7 additions & 1 deletion
@@ -24,6 +24,8 @@ struct SSMParamsBase {
     int64_t pad_slot_id;
 
     bool delta_softplus;
+    bool cache_enabled;
+    int block_size;
 
     index_t A_d_stride;
     index_t A_dstate_stride;
@@ -46,8 +48,9 @@ struct SSMParamsBase {
     index_t out_z_batch_stride;
     index_t out_z_d_stride;
     index_t ssm_states_batch_stride;
-    index_t ssm_states_dim_stride;
+    index_t ssm_states_dim_stride;
     index_t ssm_states_dstate_stride;
+    index_t cache_indices_stride;
 
     // Common data pointers.
     void *__restrict__ A_ptr;
@@ -66,6 +69,9 @@ struct SSMParamsBase {
     void *__restrict__ cache_indices_ptr;
     void *__restrict__ has_initial_state_ptr;
 
+    void *__restrict__ block_idx_first_scheduled_token_ptr;  // (batch,) - first block to write
+    void *__restrict__ block_idx_last_scheduled_token_ptr;   // (batch,) - last block to write
+    void *__restrict__ initial_state_idx_ptr;                // (batch,) - index of the initial state to use
 };

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 113 additions & 21 deletions
@@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
 
     const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
         : reinterpret_cast<int *>(params.cache_indices_ptr);
-    const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
+    const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
     // cache_index == params.pad_slot_id is defined as padding, so we exit early
     if (cache_index == params.pad_slot_id){
         return;
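
Aside, not part of the diff: a minimal host-side sketch of the pad-slot early exit above. The sentinel value and slot numbers are invented for illustration; the real sentinel arrives through params.pad_slot_id.

#include <cstdio>

int main() {
  const int pad_slot_id = -1;  // invented sentinel; the kernel reads params.pad_slot_id
  const int cache_indices[] = {4, 7, pad_slot_id, 2};  // entry 2 is padding
  for (int batch_id = 0; batch_id < 4; ++batch_id) {
    const int cache_index = cache_indices[batch_id];
    if (cache_index == pad_slot_id) continue;  // mirrors the kernel's early return
    printf("batch %d -> cache slot %d\n", batch_id, cache_index);
  }
  return 0;
}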
@@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
     weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
     input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
-    typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
-        cache_index * params.ssm_states_batch_stride +
-        dim_id * kNRows * params.ssm_states_dim_stride;
+
+    typename Ktraits::state_t *ssm_states;
+    if (params.cache_enabled) {
+        // APC mode: ssm_states points to the base, we'll use absolute cache slots later
+        ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
+            dim_id * kNRows * params.ssm_states_dim_stride;
+    } else {
+        // Non-APC mode: offset by cache_index as before
+        ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
+            cache_index * params.ssm_states_batch_stride +
+            dim_id * kNRows * params.ssm_states_dim_stride;
+    }
 
     float D_val[kNRows] = {0};
     if (params.D_ptr != nullptr) {
@@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     // }
 
     constexpr int kChunkSize = kNThreads * kNItems;
-    const int n_chunks = (seqlen + 2048 - 1) / 2048;
+
+    // Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
+    const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
+    const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
+
+    const int* batch_cache_indices = cache_indices != nullptr ?
+        cache_indices + batch_id * params.cache_indices_stride : nullptr;
+    const int* block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ?
+        reinterpret_cast<const int*>(params.block_idx_first_scheduled_token_ptr) : nullptr;
+    const int* block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ?
+        reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
+    const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
+        reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
+
+    const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
+
     for (int chunk = 0; chunk < n_chunks; ++chunk) {
         input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
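
Aside, not part of the diff: a host-side sketch of the chunk-to-slot bookkeeping these pointers enable, assuming block_size 1024 and an invented per-request block table. Full chunks write consecutive logical blocks starting at the first scheduled block; the final, possibly partial, chunk writes the last scheduled block; the initial state is read from the block named by initial_state_idx.

#include <cstdio>
#include <vector>

int main() {
  const int block_size = 1024;                                  // params.block_size
  const int seqlen = 2500;                                      // tokens scheduled this step
  const int n_chunks = (seqlen + block_size - 1) / block_size;  // ceil-div, = 3

  // One request's row of the 2D cache_indices tensor: logical block -> physical slot.
  std::vector<int> batch_cache_indices = {7, 3, 9, 5};
  const int block_idx_first_scheduled = 1;  // first block written this step
  const int block_idx_last_scheduled = 3;   // last block written this step
  const int initial_state_idx = 0;          // block holding the cached prefix state

  printf("load initial state from slot %d\n", batch_cache_indices[initial_state_idx]);

  for (int chunk = 0; chunk < n_chunks; ++chunk) {
    const int block = (chunk == n_chunks - 1) ? block_idx_last_scheduled
                                              : block_idx_first_scheduled + chunk;
    printf("chunk %d writes its state to slot %d\n", chunk, batch_cache_indices[block]);
  }
  return 0;
}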

@@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
             if constexpr (kIsVariableC) {
                 auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                 load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
-                    smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 ));
+                    smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
                 if constexpr (!kIsVariableB) {
                     #pragma unroll
                     for (int r = 0; r < kNRows; ++r) {
@@ -242,16 +266,31 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
             for (int i = 0; i < kNItems; ++i) {
                 thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                     !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
-
                 if (seqlen % (kNItems * kNThreads) != 0) {  // So that the last state is correct
                     if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
                         thread_data[i] = make_float2(1.f, 0.f);
                     }
                 }
             }
             // Initialize running total
-
-            scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);
+            scan_t running_prefix;
+            if (chunk > 0) {
+                running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE];
+            } else {
+                // Load initial state
+                if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) {
+                    size_t state_offset = load_cache_slot * params.ssm_states_batch_stride +
+                        r * params.ssm_states_dim_stride +
+                        state_idx * params.ssm_states_dstate_stride;
+                    running_prefix = make_float2(1.0, float(ssm_states[state_offset]));
+                } else if (has_initial_state) {
+                    // Non-APC mode: load from current batch position
+                    running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride]));
+                } else {
+                    // No initial state
+                    running_prefix = make_float2(1.0, 0.0);
+                }
+            }
 
             SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
             typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
             // There's a syncthreads in the scan op, so we don't need to sync here.
             // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
             if (threadIdx.x == 0) {
-                smem_running_prefix[state_idx] = prefix_op.running_prefix;
-                if (chunk == n_chunks - 1) {
+                smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
+
+                // Store state at the end of each chunk when cache is enabled
+                if (params.cache_enabled && batch_cache_indices != nullptr) {
+
+                    size_t cache_slot;
+                    if (chunk == n_chunks - 1) {
+                        cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
+                    } else {
+                        cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
+                    }
+
+                    size_t state_offset = cache_slot * params.ssm_states_batch_stride +
+                        r * params.ssm_states_dim_stride +
+                        state_idx * params.ssm_states_dstate_stride;
+
+                    ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y);
+                } else if (!params.cache_enabled && chunk == n_chunks - 1) {
+                    // Non-APC mode: store only final state at current batch position
                     ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
                 }
             }
@@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
             }
         }
     }
-
     input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
         + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
     __syncthreads();
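
Aside, not part of the diff: the two addressing modes in one runnable comparison. Non-APC bakes the batch slot into the ssm_states pointer once; APC keeps a base pointer and adds an absolute state_offset per access, precisely so that different chunks can target different physical slots. Shapes, strides, and indices are invented; the layout assumed is a contiguous (slots, dim, dstate).

#include <cassert>
#include <cstddef>

int main() {
  const size_t dim = 4, dstate = 8;
  const size_t batch_stride = dim * dstate;   // params.ssm_states_batch_stride
  const size_t dim_stride = dstate;           // params.ssm_states_dim_stride
  const size_t dstate_stride = 1;             // params.ssm_states_dstate_stride
  float ssm_states_buf[16 * 4 * 8] = {};      // 16 physical slots

  const size_t cache_slot = 9, r = 2, state_idx = 5;

  // Non-APC: the slot offset is folded into the pointer up front.
  float* non_apc = ssm_states_buf + cache_slot * batch_stride + r * dim_stride;
  float* a = non_apc + state_idx * dstate_stride;

  // APC: base pointer plus a per-access state_offset, as in the kernel above.
  float* apc_base = ssm_states_buf + r * dim_stride;
  const size_t state_offset = cache_slot * batch_stride + state_idx * dstate_stride;
  float* b = apc_base + state_offset;

  assert(a == b);  // both modes reach the same element for the same slot
  return 0;
}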
@@ -346,7 +401,9 @@ template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {

#ifndef USE_ROCM
-    if (params.seqlen <= 128) {
+    if (params.cache_enabled && params.block_size == 1024) {
+        selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
+    } else if (params.seqlen <= 128) {
        selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
    } else if (params.seqlen <= 256) {
        selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
@@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
        selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
    }
#else
-    if (params.seqlen <= 256) {
+    if (params.cache_enabled && params.block_size == 1024) {
+        selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
+    } else if (params.seqlen <= 256) {
        selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
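
Aside, not part of the diff: why <64, 16> pairs with block_size 1024. The kernel processes kNThreads * kNItems tokens per chunk iteration (the kChunkSize seen earlier), so this launch makes one chunk cover exactly one 1024-token block. A compile-time sketch with an invented helper name:

#include <cstdio>

// Stand-in for the launch template parameters <kNThreads, kNItems>.
template <int kNThreads, int kNItems>
constexpr int chunk_capacity() { return kNThreads * kNItems; }

static_assert(chunk_capacity<64, 16>() == 1024,
              "64 threads x 16 items each cover one 1024-token block");

int main() {
  printf("tokens per chunk for <64, 16>: %d\n", chunk_capacity<64, 16>());
  return 0;
}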
@@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase &params,
                        const std::optional<at::Tensor>& D,
                        const std::optional<at::Tensor>& delta_bias,
                        const torch::Tensor ssm_states,
-                       bool has_z,
+                       bool has_z,
                        bool delta_softplus,
                        const std::optional<at::Tensor>& query_start_loc,
                        const std::optional<at::Tensor>& cache_indices,
                        const std::optional<at::Tensor>& has_initial_state,
                        bool varlen,
-                       int64_t pad_slot_id) {
+                       int64_t pad_slot_id,
+                       int64_t block_size,
+                       const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
+                       const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
+                       const std::optional<torch::Tensor> &initial_state_idx) {
 
    // Reset the parameters
    memset(&params, 0, sizeof(params));
@@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase &params,
    params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
    params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
 
+   // Set cache parameters - cache is enabled if we have direct cache writing params
+   params.cache_enabled = block_idx_first_scheduled_token.has_value();
+   params.block_size = static_cast<int>(block_size);
+
+   // Set direct cache writing pointers
+   params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
+   params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
+   params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
 
    // All stride are in elements, not bytes.
    params.A_d_stride = A.stride(0);
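
Aside, not part of the diff: the gating convention in miniature. APC is on exactly when the optional block-index tensors are supplied, and absent optionals degrade to null pointers. The types and names below are stand-ins, not the torch API.

#include <cstdio>
#include <optional>
#include <vector>

struct DemoParams {  // stand-in for SSMParamsBase
  bool cache_enabled = false;
  const int* block_idx_first_scheduled_token_ptr = nullptr;
};

void demo_set_params(DemoParams& p,
                     const std::optional<std::vector<int>>& first_block) {
  p.cache_enabled = first_block.has_value();
  p.block_idx_first_scheduled_token_ptr =
      first_block.has_value() ? first_block->data() : nullptr;
}

int main() {
  DemoParams p;
  demo_set_params(p, std::nullopt);               // non-APC call site
  printf("cache_enabled=%d\n", p.cache_enabled);  // prints 0
  std::optional<std::vector<int>> first = std::vector<int>{3, 0};
  demo_set_params(p, first);                      // APC call site
  printf("cache_enabled=%d\n", p.cache_enabled);  // prints 1
  return 0;
}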
@@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase &params,
        params.out_d_stride = out.stride(0);
 
        params.ssm_states_batch_stride = ssm_states.stride(0);
-       params.ssm_states_dim_stride = ssm_states.stride(1);
+       params.ssm_states_dim_stride = ssm_states.stride(1);
        params.ssm_states_dstate_stride = ssm_states.stride(2);
 
+       params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
+
    }
    else{
        if (!is_variable_B) {
@@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
        params.out_d_stride = out.stride(1);
 
        params.ssm_states_batch_stride = ssm_states.stride(0);
-       params.ssm_states_dim_stride = ssm_states.stride(1);
+       params.ssm_states_dim_stride = ssm_states.stride(1);
        params.ssm_states_dstate_stride = ssm_states.stride(2);
+
+       params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
    }
 }
 
@@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                        const torch::Tensor &ssm_states,
                        // used to identify padding entries if cache_indices provided
                        // in case of padding, the kernel will return early
-                       int64_t pad_slot_id) {
+                       int64_t pad_slot_id,
+                       int64_t block_size,
+                       const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
+                       const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
+                       const std::optional<torch::Tensor> &initial_state_idx) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
        auto cache_indices_ = cache_indices.value();
        TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(cache_indices_.is_cuda());
-       CHECK_SHAPE(cache_indices_, batch_size);
+
+       // cache_indices can be either 1D (batch_size,) for non-APC mode
+       // or 2D (batch_size, max_positions) for APC mode
+       const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
+       if (is_apc_mode) {
+           TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
+           TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
+       } else {
+           CHECK_SHAPE(cache_indices_, batch_size);
+       }
    }
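
Aside, not part of the diff: the two accepted cache_indices layouts, with invented data. In APC mode each request owns a row of logical-block-to-physical-slot mappings, and the kernel walks into its row via cache_indices_stride.

#include <cstdio>
#include <vector>

int main() {
  const int batch_size = 2, max_positions = 3;

  // Non-APC, shape (batch_size,): one destination slot per request.
  std::vector<int> flat = {4, 11};

  // APC, shape (batch_size, max_positions), row-major; the row stride plays the
  // role of params.cache_indices_stride.
  std::vector<int> apc = {7, 3, 9,    // request 0's block table
                          5, 8, 2};   // request 1's block table
  const int cache_indices_stride = max_positions;

  for (int batch_id = 0; batch_id < batch_size; ++batch_id) {
    const int* row = apc.data() + batch_id * cache_indices_stride;  // batch_cache_indices
    printf("request %d: 1D slot %d, 2D row [%d %d %d]\n",
           batch_id, flat[batch_id], row[0], row[1], row[2]);
  }
  return 0;
}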

@@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                      cache_indices,
                      has_initial_state,
                      varlen,
-                     pad_slot_id
+                     pad_slot_id,
+                     block_size,
+                     block_idx_first_scheduled_token,
+                     block_idx_last_scheduled_token,
+                     initial_state_idx
    );
csrc/ops.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -321,17 +321,19 @@ void dynamic_per_token_scaled_fp8_quant(
321321
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
322322
std::optional<torch::Tensor> const& scale_ub);
323323

324-
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
325-
const torch::Tensor& A, const torch::Tensor& B,
326-
const torch::Tensor& C,
327-
const std::optional<torch::Tensor>& D_,
328-
const std::optional<torch::Tensor>& z_,
329-
const std::optional<torch::Tensor>& delta_bias_,
330-
bool delta_softplus,
331-
const std::optional<torch::Tensor>& query_start_loc,
332-
const std::optional<torch::Tensor>& cache_indices,
333-
const std::optional<torch::Tensor>& has_initial_state,
334-
const torch::Tensor& ssm_states, int64_t pad_slot_id);
324+
void selective_scan_fwd(
325+
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
326+
const torch::Tensor& B, const torch::Tensor& C,
327+
const std::optional<torch::Tensor>& D_,
328+
const std::optional<torch::Tensor>& z_,
329+
const std::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
330+
const std::optional<torch::Tensor>& query_start_loc,
331+
const std::optional<torch::Tensor>& cache_indices,
332+
const std::optional<torch::Tensor>& has_initial_state,
333+
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
334+
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
335+
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
336+
const std::optional<torch::Tensor>& initial_state_idx);
335337

336338
torch::Tensor dynamic_4bit_int_moe_cpu(
337339
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,

csrc/torch_bindings.cpp

Lines changed: 5 additions & 1 deletion
@@ -611,7 +611,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "Tensor? cache_indices,"
      "Tensor? has_initial_state,"
      "Tensor! ssm_states,"
-     "int pad_slot_id) -> ()");
+     "int pad_slot_id,"
+     "int block_size,"
+     "Tensor? block_idx_first_scheduled_token,"
+     "Tensor? block_idx_last_scheduled_token,"
+     "Tensor? initial_state_idx) -> ()");
  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
 
  // Hadamard transforms

tests/kernels/mamba/test_mamba_ssm.py

Lines changed: 15 additions & 0 deletions
@@ -179,6 +179,10 @@ def selective_scan_opcheck_fn(
     has_initial_state=None,
     ssm_states=None,
     pad_slot_id=PAD_SLOT_ID,
+    block_size=2048,
+    block_idx_first_scheduled_token=None,
+    block_idx_last_scheduled_token=None,
+    initial_state_idx=None,
 ):
     """if return_last_state is True, returns (out, last_state)
     last_state has shape (batch, dim, dstate).
@@ -223,6 +227,10 @@ def selective_scan_opcheck_fn(
             has_initial_state,
             ssm_states,
             pad_slot_id,
+            block_size,
+            block_idx_first_scheduled_token,
+            block_idx_last_scheduled_token,
+            initial_state_idx,
         ),
         test_utils=["test_schema", "test_faketensor"],
     )
@@ -338,6 +346,11 @@ def test_selective_scan(
         has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
         if c > 0
         else None,
+        pad_slot_id=PAD_SLOT_ID,
+        block_size=2048,
+        block_idx_first_scheduled_token=None,
+        block_idx_last_scheduled_token=None,
+        initial_state_idx=None,
     )
     outs.append(out)
     if len(outs) > 1:
@@ -372,6 +385,7 @@ def test_selective_scan(
         delta_bias=delta_bias,
         delta_softplus=delta_softplus,
         ssm_states=state,
+        block_size=2048,
     )

@@ -586,6 +600,7 @@ def test_selective_scan_varlen(
         padded_state_indices,
         has_initial_state,
         prev_state,
+        block_size=2048,
     )