Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/core/radix_sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,8 @@ void radix_sort_rows(int *rows, size_t n, const int *sparsity_IDs,
void parallel_radix_sort_rows(int *rows, size_t n, const int *sparsity_IDs,
const int *coeff_hashes, int *aux);

// Single-key LSD radix sort. Sorts indices by keys[indices[i]]
// ascending (unsigned order). Stable. aux must have space for n ints.
void radix_sort_by_key(int *indices, size_t n, const int *keys, int *aux);

#endif /* RADIX_SORT_H */
92 changes: 92 additions & 0 deletions src/core/radix_sort.c
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,98 @@ void radix_sort_rows(int *rows, size_t n, const int *sparsity_IDs,
}
}

// --- Single-key radix sort ---

static void insertion_sort_by_key(int *indices, size_t n,
const int *keys)
{
for (size_t i = 1; i < n; i++)
{
int idx = indices[i];
int idx_key = keys[idx];
size_t j = i;
while (j > 0 && keys[indices[j - 1]] > idx_key)
{
indices[j] = indices[j - 1];
j--;
}
indices[j] = idx;
}
}

void radix_sort_by_key(int *indices, size_t n,
const int *keys, int *aux)
{
if (n < 256)
{
insertion_sort_by_key(indices, n, keys);
return;
}

size_t counts[256];
int *src = indices, *dst = aux;

for (int pass = 0; pass < 4; pass++)
{
int shift = pass * 8;

// histogram
memset(counts, 0, 256 * sizeof(size_t));
for (size_t i = 0; i < n; i++)
{
unsigned byte =
((uint32_t) keys[src[i]] >> shift) & 0xFF;
counts[byte]++;
}

// skip pass if all values fall in one bucket
if (pass > 0)
{
int skip = 0;
for (int b = 0; b < 256; b++)
{
if (counts[b] == n)
{
skip = 1;
break;
}
}
if (skip)
{
continue;
}
}

// prefix sum
size_t total = 0;
for (int b = 0; b < 256; b++)
{
size_t c = counts[b];
counts[b] = total;
total += c;
}

// scatter (forward — stable)
for (size_t i = 0; i < n; i++)
{
unsigned byte =
((uint32_t) keys[src[i]] >> shift) & 0xFF;
dst[counts[byte]++] = src[i];
}

// swap src and dst
int *tmp = src;
src = dst;
dst = tmp;
}

// if result ended up in aux, copy back to indices
if (src != indices)
{
memcpy(indices, src, n * sizeof(int));
}
}

// --- Parallel radix sort infrastructure ---

typedef struct
Expand Down
31 changes: 5 additions & 26 deletions src/explorers/Primal_propagation.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "SimpleReductions.h"
#include "State.h"
#include "Workspace.h"
#include "radix_sort.h"
#include "utils.h"

static PresolveStatus update_lb_within_propagation(double new_lb, double *lb,
Expand Down Expand Up @@ -782,27 +783,6 @@ PresolveStatus propagate_primal(Problem *prob, bool finite_bound_tightening)
return UNCHANGED;
}

// Static variable to hold col_sizes (for qsort)
static const int *global_col_sizes;

int compare_col_len(const void *a, const void *b)
{
int idx_a = *(const int *) a;
int idx_b = *(const int *) b;

if (global_col_sizes[idx_a] < global_col_sizes[idx_b])
{
return -1;
}
if (global_col_sizes[idx_a] > global_col_sizes[idx_b])
{
return 1;
}

// if equal len, sort by index
return idx_a - idx_b;
}

void remove_redundant_bounds(Constraints *constraints)
{
const Matrix *A = constraints->A;
Expand Down Expand Up @@ -830,9 +810,8 @@ void remove_redundant_bounds(Constraints *constraints)
col_order[ii] = ii;
}

global_col_sizes = col_sizes;
qsort(col_order, n_cols, sizeof(int), compare_col_len);
global_col_sizes = NULL;
int *aux = constraints->state->work->radix_aux;
radix_sort_by_key(col_order, n_cols, col_sizes, aux);

for (jj = 0; jj < n_cols; jj++)
{
Expand Down Expand Up @@ -863,7 +842,7 @@ void remove_redundant_bounds(Constraints *constraints)
bounds))
{
remove_finite_lb_from_activities(&col, acts, bounds[ii].lb);
DEBUG(bounds[ii].lb = -INF);
bounds[ii].lb = -INF;
UPDATE_TAG(col_tags[ii], C_TAG_LB_INF);
}
}
Expand All @@ -875,7 +854,7 @@ void remove_redundant_bounds(Constraints *constraints)
bounds))
{
remove_finite_ub_from_activities(&col, acts, bounds[ii].ub);
DEBUG(bounds[ii].ub = INF);
bounds[ii].ub = INF;
UPDATE_TAG(col_tags[ii], C_TAG_UB_INF);
}
}
Expand Down
141 changes: 141 additions & 0 deletions tests/test_radix_sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,141 @@ static char *test_14_radix_sort()
return 0;
}

// ---- Tests for radix_sort_by_key ----

// helper: check indices sorted ascending by keys[indices[i]]
static int is_sorted_by_single_key(const int *indices, int n, const int *keys)
{
for (int i = 1; i < n; i++)
{
unsigned int prev = (unsigned int) keys[indices[i - 1]];
unsigned int curr = (unsigned int) keys[indices[i]];
if (prev > curr)
{
return 0;
}
}
return 1;
}

/* Test 15: basic ascending sort by single key */
static char *test_15_radix_sort()
{
int n = 5;
int indices[] = {0, 1, 2, 3, 4};
int keys[] = {30, 10, 50, 20, 40};
int aux[5];

radix_sort_by_key(indices, (size_t) n, keys, aux);
mu_assert("should be sorted by key", is_sorted_by_single_key(indices, n, keys));
mu_assert("should be permutation", is_permutation(indices, n));
// expected: 1(10), 3(20), 0(30), 4(40), 2(50)
mu_assert("pos 0", indices[0] == 1);
mu_assert("pos 1", indices[1] == 3);
mu_assert("pos 2", indices[2] == 0);
mu_assert("pos 3", indices[3] == 4);
mu_assert("pos 4", indices[4] == 2);
return 0;
}

/* Test 16: already-sorted input */
static char *test_16_radix_sort()
{
int n = 5;
int indices[] = {0, 1, 2, 3, 4};
int keys[] = {1, 2, 3, 4, 5};
int aux[5];

radix_sort_by_key(indices, (size_t) n, keys, aux);
mu_assert("should be sorted", is_sorted_by_single_key(indices, n, keys));
for (int i = 0; i < n; i++)
{
mu_assert("order preserved", indices[i] == i);
}
return 0;
}

/* Test 17: reverse-sorted input */
static char *test_17_radix_sort()
{
int n = 5;
int indices[] = {0, 1, 2, 3, 4};
int keys[] = {50, 40, 30, 20, 10};
int aux[5];

radix_sort_by_key(indices, (size_t) n, keys, aux);
mu_assert("should be sorted", is_sorted_by_single_key(indices, n, keys));
mu_assert("first", indices[0] == 4);
mu_assert("last", indices[4] == 0);
return 0;
}

/* Test 18: duplicate keys — stability preserves input order */
static char *test_18_radix_sort()
{
int n = 6;
int indices[] = {0, 1, 2, 3, 4, 5};
int keys[] = {5, 5, 5, 10, 10, 10};
int aux[6];

radix_sort_by_key(indices, (size_t) n, keys, aux);
mu_assert("should be sorted", is_sorted_by_single_key(indices, n, keys));
mu_assert("should be permutation", is_permutation(indices, n));
// stable: group key=5 keeps order 0,1,2; group key=10 keeps 3,4,5
mu_assert("stable 0", indices[0] == 0);
mu_assert("stable 1", indices[1] == 1);
mu_assert("stable 2", indices[2] == 2);
mu_assert("stable 3", indices[3] == 3);
mu_assert("stable 4", indices[4] == 4);
mu_assert("stable 5", indices[5] == 5);
return 0;
}

/* Test 19: large input (n=300) forces radix path */
static char *test_19_radix_sort()
{
int n = 300;
int indices[300];
int keys[300];
int aux[300];

srand(54321);
for (int i = 0; i < n; i++)
{
indices[i] = i;
keys[i] = rand() % 100;
}

radix_sort_by_key(indices, (size_t) n, keys, aux);
mu_assert("should be sorted", is_sorted_by_single_key(indices, n, keys));
mu_assert("should be permutation", is_permutation(indices, n));
return 0;
}

/* Test 20: large input with all identical keys (skip-pass opt) */
static char *test_20_radix_sort()
{
int n = 300;
int indices[300];
int keys[300];
int aux[300];

for (int i = 0; i < n; i++)
{
indices[i] = i;
keys[i] = 42;
}

radix_sort_by_key(indices, (size_t) n, keys, aux);
mu_assert("should be sorted", is_sorted_by_single_key(indices, n, keys));
// stable: identical keys preserve input order 0,1,...,n-1
for (int i = 0; i < n; i++)
{
mu_assert("stable order", indices[i] == i);
}
return 0;
}

static const char *all_tests_radix_sort()
{
mu_run_test(test_1_radix_sort, counter_radix_sort);
Expand All @@ -448,6 +583,12 @@ static const char *all_tests_radix_sort()
mu_run_test(test_12_radix_sort, counter_radix_sort);
mu_run_test(test_13_radix_sort, counter_radix_sort);
mu_run_test(test_14_radix_sort, counter_radix_sort);
mu_run_test(test_15_radix_sort, counter_radix_sort);
mu_run_test(test_16_radix_sort, counter_radix_sort);
mu_run_test(test_17_radix_sort, counter_radix_sort);
mu_run_test(test_18_radix_sort, counter_radix_sort);
mu_run_test(test_19_radix_sort, counter_radix_sort);
mu_run_test(test_20_radix_sort, counter_radix_sort);

return 0;
}
Expand Down