Skip to content

Commit f9ce8c1

Browse files
add llvm implementation for insertion sort
1 parent fc75883 commit f9ce8c1

File tree

5 files changed

+669
-0
lines changed

5 files changed

+669
-0
lines changed

pydatastructs/linear_data_structures/_backend/cpp/algorithms/algorithms.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ static PyMethodDef algorithms_PyMethodDef[] = {
1414
METH_VARARGS | METH_KEYWORDS, ""},
1515
{"insertion_sort", (PyCFunction) insertion_sort,
1616
METH_VARARGS | METH_KEYWORDS, ""},
17+
{"insertion_sort_llvm", (PyCFunction)insertion_sort_llvm,
18+
METH_VARARGS | METH_KEYWORDS, ""},
1719
{"is_ordered", (PyCFunction) is_ordered,
1820
METH_VARARGS | METH_KEYWORDS, ""},
1921
{"linear_search", (PyCFunction) linear_search,

pydatastructs/linear_data_structures/_backend/cpp/algorithms/llvm_algorithms.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,122 @@ def _materialize(dtype: str) -> int:
167167

168168
except Exception as e:
169169
raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}")
170+
171+
172+
def get_insertion_sort_ptr(dtype: str) -> int:
    """Return the address of the JIT-compiled insertion sort for ``dtype``.

    The dtype name is lower-cased and stripped of surrounding whitespace
    before it is looked up in ``_SUPPORTED``.

    Raises:
        ValueError: if the normalized name is not a key of ``_SUPPORTED``.
    """
    dtype = dtype.lower().strip()
    if dtype in _SUPPORTED:
        # Compilation (and any caching) is delegated to the materializer.
        return _materialize_insertion(dtype)

    raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
179+
180+
181+
def _build_insertion_sort_ir(dtype: str) -> str:
    """Build textual LLVM IR for an in-place ascending insertion sort.

    The generated module defines one function,
    ``void insertion_sort_<dtype>(T* arr, i32 n)``, where ``T`` is the first
    element of ``_SUPPORTED[dtype]``. Integer element types are compared as
    signed values; any other type is compared with an ordered floating-point
    comparison.

    Control-flow graph of the emitted function::

        entry       -> exit (n <= 1) | outer
        outer       -> inner (i < n) | exit
        inner       -> inner.loop            ; key = arr[i], j0 = i - 1
        inner.loop  -> inner.cmp (j >= 0) | outer.latch
        inner.cmp   -> inner.latch (arr[j] > key) | outer.latch
        inner.latch -> inner.loop            ; arr[j + 1] = arr[j], j -= 1
        outer.latch -> outer                 ; arr[j + 1] = key, i += 1

    Args:
        dtype: key into ``_SUPPORTED`` selecting the element type.

    Returns:
        The module rendered as LLVM assembly text.

    Raises:
        ValueError: if ``dtype`` is not a key of ``_SUPPORTED``.
    """
    if dtype not in _SUPPORTED:
        raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

    T, _ = _SUPPORTED[dtype]
    i32 = ir.IntType(32)
    i64 = ir.IntType(64)
    zero32 = ir.Constant(i32, 0)
    one32 = ir.Constant(i32, 1)
    one64 = ir.Constant(i64, 1)

    mod = ir.Module(name=f"insertion_sort_{dtype}_module")
    fn_name = f"insertion_sort_{dtype}"

    fn_ty = ir.FunctionType(ir.VoidType(), [T.as_pointer(), i32])
    fn = ir.Function(mod, fn_ty, name=fn_name)

    arr, n = fn.args
    arr.name, n.name = "arr", "n"

    # Every block is created up front so that all branch targets exist
    # before any terminator referring to them is emitted.
    b_entry = fn.append_basic_block("entry")
    b_outer = fn.append_basic_block("outer")
    b_inner = fn.append_basic_block("inner")
    b_inner_loop = fn.append_basic_block("inner.loop")
    b_inner_cmp = fn.append_basic_block("inner.cmp")
    b_inner_latch = fn.append_basic_block("inner.latch")
    b_outer_latch = fn.append_basic_block("outer.latch")
    b_exit = fn.append_basic_block("exit")

    # entry: arrays of length 0 or 1 (or a negative n) need no work.
    b = ir.IRBuilder(b_entry)
    cond_trivial = b.icmp_signed("<=", n, one32)
    b.cbranch(cond_trivial, b_exit, b_outer)

    # outer: for (i = 1; i < n; ++i)
    b.position_at_end(b_outer)
    i_phi = b.phi(i32, name="i")
    i_phi.add_incoming(one32, b_entry)  # start at 1
    cond_outer = b.icmp_signed("<", i_phi, n)
    b.cbranch(cond_outer, b_inner, b_exit)

    # inner: key = arr[i]; j = i - 1
    b.position_at_end(b_inner)
    i64_idx = b.sext(i_phi, i64)
    key_ptr = b.gep(arr, [i64_idx], inbounds=True)
    key_val = b.load(key_ptr, name="key")
    j_init = b.sub(i_phi, one32, name="j.init")
    b.branch(b_inner_loop)

    # inner.loop: while (j >= 0 ...)
    # j must be a phi node: rebinding a Python variable does not create a
    # loop-carried value in the generated IR.
    b.position_at_end(b_inner_loop)
    j_phi = b.phi(i32, name="j")
    j_phi.add_incoming(j_init, b_inner)
    cond_j = b.icmp_signed(">=", j_phi, zero32)
    b.cbranch(cond_j, b_inner_cmp, b_outer_latch)

    # inner.cmp: ... && arr[j] > key
    # arr[j] is only loaded after j >= 0 has been established.
    b.position_at_end(b_inner_cmp)
    j64 = b.sext(j_phi, i64)
    arr_j_ptr = b.gep(arr, [j64], inbounds=True)
    arr_j_val = b.load(arr_j_ptr)
    if isinstance(T, ir.IntType):
        cmp = b.icmp_signed(">", arr_j_val, key_val)
    else:
        # Ordered comparison: false whenever either operand is NaN.
        cmp = b.fcmp_ordered(">", arr_j_val, key_val)
    b.cbranch(cmp, b_inner_latch, b_outer_latch)

    # inner.latch: arr[j + 1] = arr[j]; --j
    b.position_at_end(b_inner_latch)
    shift_ptr = b.gep(arr, [b.add(j64, one64)], inbounds=True)
    b.store(arr_j_val, shift_ptr)
    j_next = b.sub(j_phi, one32, name="j.next")
    j_phi.add_incoming(j_next, b_inner_latch)
    b.branch(b_inner_loop)

    # outer.latch: arr[j + 1] = key; ++i
    # j_phi is usable here because both predecessors run through inner.loop.
    b.position_at_end(b_outer_latch)
    insert_idx = b.sext(b.add(j_phi, one32), i64)
    b.store(key_val, b.gep(arr, [insert_idx], inbounds=True))
    i_next = b.add(i_phi, one32, name="i.next")
    i_phi.add_incoming(i_next, b_outer_latch)
    b.branch(b_outer)

    b.position_at_end(b_exit)
    b.ret_void()

    return str(mod)
262+
263+
264+
def _materialize_insertion(dtype: str) -> int:
    """JIT-compile the insertion sort for ``dtype`` and return its address.

    The IR from ``_build_insertion_sort_ir`` is parsed, verified and compiled
    with MCJIT. The resulting address is memoized in ``_fn_ptr_cache`` and the
    engine is stored in ``_engines``.

    Args:
        dtype: element type name accepted by ``_build_insertion_sort_ir``.

    Returns:
        Address of the compiled ``insertion_sort_<dtype>`` function.

    Raises:
        RuntimeError: if building, verifying or compiling the IR fails, or if
            the engine reports no address for the function. The original
            exception is attached as ``__cause__``.
    """
    _ensure_target_machine()

    name = f"insertion_sort_{dtype}"
    # Key by the full function name rather than the bare dtype.
    # NOTE(review): _fn_ptr_cache / _engines appear to be shared with the
    # other materializer(s) in this module; a dtype-only key could return
    # another algorithm's pointer or overwrite its engine. Confirm against
    # the rest of the file.
    if name in _fn_ptr_cache:
        return _fn_ptr_cache[name]

    try:
        llvm_ir = _build_insertion_sort_ir(dtype)
        mod = binding.parse_assembly(llvm_ir)
        mod.verify()

        engine = binding.create_mcjit_compiler(mod, _target_machine)
        engine.finalize_object()
        engine.run_static_constructors()

        addr = engine.get_function_address(name)
        if not addr:
            raise RuntimeError(f"Failed to get address for {name}")

        _fn_ptr_cache[name] = addr
        # Keep a reference to the engine alongside the cached address.
        _engines[name] = engine
        return addr
    except Exception as e:
        raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}") from e

0 commit comments

Comments
 (0)