Skip to content

Commit bfd3be7

Browse files
feat: add llvm implementation of quick_sort
1 parent fc75883 commit bfd3be7

File tree

6 files changed

+726
-0
lines changed

6 files changed

+726
-0
lines changed

pydatastructs/linear_data_structures/_backend/cpp/algorithms/algorithms.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
static PyMethodDef algorithms_PyMethodDef[] = {
77
{"quick_sort", (PyCFunction) quick_sort,
88
METH_VARARGS | METH_KEYWORDS, ""},
9+
{"quick_sort_llvm", (PyCFunction)quick_sort_llvm,
10+
METH_VARARGS | METH_KEYWORDS, ""},
911
{"bubble_sort", (PyCFunction) bubble_sort,
1012
METH_VARARGS | METH_KEYWORDS, ""},
1113
{"bubble_sort_llvm", (PyCFunction)bubble_sort_llvm,

pydatastructs/linear_data_structures/_backend/cpp/algorithms/llvm_algorithms.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ def get_bubble_sort_ptr(dtype: str) -> int:
4141

4242
return _materialize(dtype)
4343

44+
45+
def get_quick_sort_ptr(dtype: str) -> int:
46+
dtype = dtype.lower().strip()
47+
if dtype not in _SUPPORTED:
48+
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
49+
50+
return _materialize_quick(dtype)
51+
4452
def _build_bubble_sort_ir(dtype: str) -> str:
4553
if dtype not in _SUPPORTED:
4654
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
@@ -131,6 +139,134 @@ def _build_bubble_sort_ir(dtype: str) -> str:
131139

132140
return str(mod)
133141

142+
143+
def _build_quick_sort_ir(dtype: str) -> str:
144+
if dtype not in _SUPPORTED:
145+
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
146+
147+
T, _ = _SUPPORTED[dtype]
148+
i32 = ir.IntType(32)
149+
i64 = ir.IntType(64)
150+
151+
mod = ir.Module(name=f"quick_sort_{dtype}_module")
152+
fn_name = f"quick_sort_{dtype}"
153+
154+
# void quick_sort(T* arr, int32 low, int32 high)
155+
fn_ty = ir.FunctionType(ir.VoidType(), [T.as_pointer(), i32, i32])
156+
fn = ir.Function(mod, fn_ty, name=fn_name)
157+
arr, low, high = fn.args
158+
arr.name, low.name, high.name = "arr", "low", "high"
159+
160+
entry = fn.append_basic_block("entry")
161+
part = fn.append_basic_block("partition")
162+
exit = fn.append_basic_block("exit")
163+
164+
b = ir.IRBuilder(entry)
165+
166+
# if (low < high)
167+
cond = b.icmp_signed("<", low, high)
168+
b.cbranch(cond, part, exit)
169+
170+
# --- Partition block
171+
b.position_at_end(part)
172+
173+
# pivot = arr[high]
174+
high_64 = b.sext(high, i64)
175+
pivot_ptr = b.gep(arr, [high_64])
176+
pivot = b.load(pivot_ptr, name="pivot")
177+
178+
# i = low - 1
179+
i = b.alloca(i32, name="i")
180+
i_init = b.sub(low, ir.Constant(i32, 1))
181+
b.store(i_init, i)
182+
183+
# j = low
184+
j = b.alloca(i32, name="j")
185+
b.store(low, j)
186+
187+
loop = fn.append_basic_block("loop")
188+
after_loop = fn.append_basic_block("after_loop")
189+
body = fn.append_basic_block("body")
190+
swap = fn.append_basic_block("swap")
191+
skip_swap = fn.append_basic_block("skip_swap")
192+
193+
b.branch(loop)
194+
195+
# --- Loop: while (j < high)
196+
b.position_at_end(loop)
197+
j_val = b.load(j)
198+
cond = b.icmp_signed("<", j_val, high)
199+
b.cbranch(cond, body, after_loop)
200+
201+
# --- Body
202+
b.position_at_end(body)
203+
j64 = b.sext(j_val, i64)
204+
elem_ptr = b.gep(arr, [j64])
205+
elem = b.load(elem_ptr, name="elem")
206+
207+
if isinstance(T, ir.IntType):
208+
cmp = b.icmp_signed("<=", elem, pivot)
209+
else:
210+
cmp = b.fcmp_ordered("<=", elem, pivot, fastmath=True)
211+
212+
b.cbranch(cmp, swap, skip_swap)
213+
214+
# --- Swap block
215+
b.position_at_end(swap)
216+
i_val = b.load(i)
217+
i_next = b.add(i_val, ir.Constant(i32, 1))
218+
b.store(i_next, i)
219+
220+
i64v = b.sext(i_next, i64)
221+
iptr = b.gep(arr, [i64v])
222+
ival = b.load(iptr)
223+
# swap arr[i] and arr[j]
224+
b.store(elem, iptr)
225+
b.store(ival, elem_ptr)
226+
227+
b.branch(skip_swap)
228+
229+
# --- Skip swap
230+
b.position_at_end(skip_swap)
231+
j_next = b.add(j_val, ir.Constant(i32, 1))
232+
b.store(j_next, j)
233+
b.branch(loop)
234+
235+
# --- After loop
236+
b.position_at_end(after_loop)
237+
i_val = b.load(i)
238+
i_plus1 = b.add(i_val, ir.Constant(i32, 1))
239+
240+
i64v = b.sext(i_plus1, i64)
241+
iptr = b.gep(arr, [i64v])
242+
ival = b.load(iptr)
243+
244+
# swap arr[i+1] and arr[high]
245+
b.store(pivot, iptr)
246+
b.store(ival, pivot_ptr)
247+
248+
# Now i+1 is the partition index
249+
pi = i_plus1
250+
251+
# Recursive calls:
252+
# quick_sort(arr, low, pi - 1)
253+
low_call = low
254+
high_call1 = b.sub(pi, ir.Constant(i32, 1))
255+
b.call(fn, [arr, low_call, high_call1])
256+
257+
# quick_sort(arr, pi + 1, high)
258+
low_call2 = b.add(pi, ir.Constant(i32, 1))
259+
high_call2 = high
260+
b.call(fn, [arr, low_call2, high_call2])
261+
262+
b.branch(exit)
263+
264+
# --- Exit
265+
b.position_at_end(exit)
266+
b.ret_void()
267+
268+
return str(mod)
269+
134270
def _materialize(dtype: str) -> int:
135271
_ensure_target_machine()
136272

@@ -167,3 +303,42 @@ def _materialize(dtype: str) -> int:
167303

168304
except Exception as e:
169305
raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}")
306+
307+
308+
def _materialize_quick(dtype: str) -> int:
309+
_ensure_target_machine()
310+
311+
key = f"quick_{dtype}"
312+
if key in _fn_ptr_cache:
313+
return _fn_ptr_cache[key]
314+
315+
try:
316+
llvm_ir = _build_quick_sort_ir(dtype)
317+
mod = binding.parse_assembly(llvm_ir)
318+
mod.verify()
319+
320+
try:
321+
pm = binding.ModulePassManager()
322+
pm.add_instruction_combining_pass()
323+
pm.add_reassociate_pass()
324+
pm.add_gvn_pass()
325+
pm.add_cfg_simplification_pass()
326+
pm.run(mod)
327+
except AttributeError:
328+
pass
329+
330+
engine = binding.create_mcjit_compiler(mod, _target_machine)
331+
engine.finalize_object()
332+
engine.run_static_constructors()
333+
334+
addr = engine.get_function_address(f"quick_sort_{dtype}")
335+
if not addr:
336+
raise RuntimeError(f"Failed to get address for quick_sort_{dtype}")
337+
338+
_fn_ptr_cache[key] = addr
339+
_engines[key] = engine
340+
341+
return addr
342+
343+
except Exception as e:
344+
raise RuntimeError(f"Failed to materialize quick sort function for dtype {dtype}: {e}")

pydatastructs/linear_data_structures/_backend/cpp/algorithms/quadratic_time_sort.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ static PyObject* bubble_sort_llvm(PyObject* self, PyObject* args, PyObject* kwds
612612
Py_INCREF(arr_obj);
613613
return arr_obj;
614614
}
615+
615616
// Selection Sort
616617
static PyObject* selection_sort_impl(PyObject* array, size_t lower, size_t upper,
617618
PyObject* comp) {

0 commit comments

Comments
 (0)