@@ -41,6 +41,14 @@ def get_bubble_sort_ptr(dtype: str) -> int:
4141
4242 return _materialize (dtype )
4343
44+
45+ def get_quick_sort_ptr (dtype : str ) -> int :
46+ dtype = dtype .lower ().strip ()
47+ if dtype not in _SUPPORTED :
48+ raise ValueError (f"Unsupported dtype '{ dtype } '. Supported: { list (_SUPPORTED )} " )
49+
50+ return _materialize_quick (dtype )
51+
4452def _build_bubble_sort_ir (dtype : str ) -> str :
4553 if dtype not in _SUPPORTED :
4654 raise ValueError (f"Unsupported dtype '{ dtype } '. Supported: { list (_SUPPORTED )} " )
@@ -131,6 +139,134 @@ def _build_bubble_sort_ir(dtype: str) -> str:
131139
132140 return str (mod )
133141
142+
143+ def _build_quick_sort_ir (dtype : str ) -> str :
144+ if dtype not in _SUPPORTED :
145+ raise ValueError (f"Unsupported dtype '{ dtype } '. Supported: { list (_SUPPORTED )} " )
146+
147+ T , _ = _SUPPORTED [dtype ]
148+ i32 = ir .IntType (32 )
149+ i64 = ir .IntType (64 )
150+
151+ mod = ir .Module (name = f"quick_sort_{ dtype } _module" )
152+ fn_name = f"quick_sort_{ dtype } "
153+
154+ # void quick_sort(T* arr, int32 low, int32 high)
155+ fn_ty = ir .FunctionType (ir .VoidType (), [T .as_pointer (), i32 , i32 ])
156+ fn = ir .Function (mod , fn_ty , name = fn_name )
157+ arr , low , high = fn .args
158+ arr .name , low .name , high .name = "arr" , "low" , "high"
159+
160+ entry = fn .append_basic_block ("entry" )
161+ part = fn .append_basic_block ("partition" )
162+ exit = fn .append_basic_block ("exit" )
163+
164+ b = ir .IRBuilder (entry )
165+
166+ # if (low < high)
167+ cond = b .icmp_signed ("<" , low , high )
168+ b .cbranch (cond , part , exit )
169+
170+ # --- Partition block
171+ b .position_at_end (part )
172+
173+ # pivot = arr[high]
174+ high_64 = b .sext (high , i64 )
175+ pivot_ptr = b .gep (arr , [high_64 ])
176+ pivot = b .load (pivot_ptr , name = "pivot" )
177+
178+ # i = low - 1
179+ i = b .alloca (i32 , name = "i" )
180+ i_init = b .sub (low , ir .Constant (i32 , 1 ))
181+ b .store (i_init , i )
182+
183+ # j = low
184+ j = b .alloca (i32 , name = "j" )
185+ b .store (low , j )
186+
187+ loop = fn .append_basic_block ("loop" )
188+ after_loop = fn .append_basic_block ("after_loop" )
189+ body = fn .append_basic_block ("body" )
190+ swap = fn .append_basic_block ("swap" )
191+ skip_swap = fn .append_basic_block ("skip_swap" )
192+
193+ b .branch (loop )
194+
195+ # --- Loop: while (j < high)
196+ b .position_at_end (loop )
197+ j_val = b .load (j )
198+ cond = b .icmp_signed ("<" , j_val , high )
199+ b .cbranch (cond , body , after_loop )
200+
201+ # --- Body
202+ b .position_at_end (body )
203+ j64 = b .sext (j_val , i64 )
204+ elem_ptr = b .gep (arr , [j64 ])
205+ elem = b .load (elem_ptr , name = "elem" )
206+
207+ if isinstance (T , ir .IntType ):
208+ cmp = b .icmp_signed ("<=" , elem , pivot )
209+ else :
210+ cmp = b .fcmp_ordered ("<=" , elem , pivot , fastmath = True )
211+
212+ b .cbranch (cmp , swap , skip_swap )
213+
214+ # --- Swap block
215+ b .position_at_end (swap )
216+ i_val = b .load (i )
217+ i_next = b .add (i_val , ir .Constant (i32 , 1 ))
218+ b .store (i_next , i )
219+
220+ i64v = b .sext (i_next , i64 )
221+ iptr = b .gep (arr , [i64v ])
222+ ival = b .load (iptr )
223+ # swap arr[i] and arr[j]
224+ b .store (elem , iptr )
225+ b .store (ival , elem_ptr )
226+
227+ b .branch (skip_swap )
228+
229+ # --- Skip swap
230+ b .position_at_end (skip_swap )
231+ j_next = b .add (j_val , ir .Constant (i32 , 1 ))
232+ b .store (j_next , j )
233+ b .branch (loop )
234+
235+ # --- After loop
236+ b .position_at_end (after_loop )
237+ i_val = b .load (i )
238+ i_plus1 = b .add (i_val , ir .Constant (i32 , 1 ))
239+
240+ i64v = b .sext (i_plus1 , i64 )
241+ iptr = b .gep (arr , [i64v ])
242+ ival = b .load (iptr )
243+
244+ # swap arr[i+1] and arr[high]
245+ b .store (pivot , iptr )
246+ b .store (ival , pivot_ptr )
247+
248+ # Now i+1 is the partition index
249+ pi = i_plus1
250+
251+ # Recursive calls:
252+ # quick_sort(arr, low, pi - 1)
253+ low_call = low
254+ high_call1 = b .sub (pi , ir .Constant (i32 , 1 ))
255+ b .call (fn , [arr , low_call , high_call1 ])
256+
257+ # quick_sort(arr, pi + 1, high)
258+ low_call2 = b .add (pi , ir .Constant (i32 , 1 ))
259+ high_call2 = high
260+ b .call (fn , [arr , low_call2 , high_call2 ])
261+
262+ b .branch (exit )
263+
264+ # --- Exit
265+ b .position_at_end (exit )
266+ b .ret_void ()
267+
268+ return str (mod )
269+
134270def _materialize (dtype : str ) -> int :
135271 _ensure_target_machine ()
136272
@@ -167,3 +303,42 @@ def _materialize(dtype: str) -> int:
167303
168304 except Exception as e :
169305 raise RuntimeError (f"Failed to materialize function for dtype { dtype } : { e } " )
306+
307+
308+ def _materialize_quick (dtype : str ) -> int :
309+ _ensure_target_machine ()
310+
311+ key = f"quick_{ dtype } "
312+ if key in _fn_ptr_cache :
313+ return _fn_ptr_cache [key ]
314+
315+ try :
316+ llvm_ir = _build_quick_sort_ir (dtype )
317+ mod = binding .parse_assembly (llvm_ir )
318+ mod .verify ()
319+
320+ try :
321+ pm = binding .ModulePassManager ()
322+ pm .add_instruction_combining_pass ()
323+ pm .add_reassociate_pass ()
324+ pm .add_gvn_pass ()
325+ pm .add_cfg_simplification_pass ()
326+ pm .run (mod )
327+ except AttributeError :
328+ pass
329+
330+ engine = binding .create_mcjit_compiler (mod , _target_machine )
331+ engine .finalize_object ()
332+ engine .run_static_constructors ()
333+
334+ addr = engine .get_function_address (f"quick_sort_{ dtype } " )
335+ if not addr :
336+ raise RuntimeError (f"Failed to get address for quick_sort_{ dtype } " )
337+
338+ _fn_ptr_cache [key ] = addr
339+ _engines [key ] = engine
340+
341+ return addr
342+
343+ except Exception as e :
344+ raise RuntimeError (f"Failed to materialize quick sort function for dtype { dtype } : { e } " )
0 commit comments