55import mlir .execution_engine
66import mlir .passmanager
77from mlir import ir
8+ from mlir import runtime as rt
89from mlir .dialects import arith , bufferization , func , sparse_tensor , tensor
910
1011import numpy as np
1314from ._common import fn_cache
1415from ._core import CWD , DEBUG , MLIR_C_RUNNER_UTILS , ctx
1516from ._dtypes import DType , Index , asdtype
16- from ._memref import make_memref_ctype , ranked_memref_from_np
1717
1818
1919def _hold_self_ref_in_ret (fn ):
@@ -108,7 +108,7 @@ def free_tensor(tensor_shaped):
108108 @classmethod
109109 def assemble (cls , module , arr : np .ndarray ) -> ctypes .c_void_p :
110110 assert arr .ndim == 2
111- data = ranked_memref_from_np (arr .flatten ())
111+ data = rt . get_ranked_memref_descriptor (arr .flatten ())
112112 out = ctypes .c_void_p ()
113113 module .invoke (
114114 "assemble" ,
@@ -121,14 +121,14 @@ def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
121121 def disassemble (cls , module : ir .Module , ptr : ctypes .c_void_p , dtype : type [DType ]) -> np .ndarray :
122122 class Dense (ctypes .Structure ):
123123 _fields_ = [
124- ("data" , make_memref_ctype ( dtype , 1 )),
124+ ("data" , rt . make_nd_memref_descriptor ( 1 , dtype . to_ctype () )),
125125 ("data_len" , np .ctypeslib .c_intp ),
126126 ("shape_x" , np .ctypeslib .c_intp ),
127127 ("shape_y" , np .ctypeslib .c_intp ),
128128 ]
129129
130130 def to_np (self ) -> np .ndarray :
131- data = self .data . to_numpy ( )[: self .data_len ]
131+ data = rt . ranked_memref_to_numpy ([ self .data ] )[: self .data_len ]
132132 return data .reshape ((self .shape_x , self .shape_y ))
133133
134134 arr = Dense ()
@@ -141,8 +141,107 @@ def to_np(self) -> np.ndarray:
141141
142142
143143class COOFormat :
144- # TODO: implement
145- ...
144+ @fn_cache
145+ def get_module (shape : tuple [int ], values_dtype : type [DType ], index_dtype : type [DType ]):
146+ with ir .Location .unknown (ctx ):
147+ module = ir .Module .create ()
148+ values_dtype = values_dtype .get_mlir_type ()
149+ index_dtype = index_dtype .get_mlir_type ()
150+ index_width = getattr (index_dtype , "width" , 0 )
151+ compressed_lvl = sparse_tensor .EncodingAttr .build_level_type (
152+ sparse_tensor .LevelFormat .compressed , [sparse_tensor .LevelProperty .non_unique ]
153+ )
154+ levels = (compressed_lvl , sparse_tensor .LevelFormat .singleton )
155+ ordering = ir .AffineMap .get_permutation ([0 , 1 ])
156+ encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
157+ coo_shaped = ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
158+
159+ tensor_1d_index = tensor .RankedTensorType .get ([ir .ShapedType .get_dynamic_size ()], index_dtype )
160+ tensor_2d_index = tensor .RankedTensorType .get ([ir .ShapedType .get_dynamic_size (), len (shape )], index_dtype )
161+ tensor_1d_values = tensor .RankedTensorType .get ([ir .ShapedType .get_dynamic_size ()], values_dtype )
162+
163+ with ir .InsertionPoint (module .body ):
164+
165+ @func .FuncOp .from_py_func (tensor_1d_index , tensor_2d_index , tensor_1d_values )
166+ def assemble (pos , index , values ):
167+ return sparse_tensor .assemble (coo_shaped , (pos , index ), values )
168+
169+ @func .FuncOp .from_py_func (coo_shaped )
170+ def disassemble (tensor_shaped ):
171+ nse = sparse_tensor .number_of_entries (tensor_shaped )
172+ pos = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 2 )], index_dtype )
173+ index = tensor .EmptyOp ([nse , 2 ], index_dtype )
174+ values = tensor .EmptyOp ([nse ], values_dtype )
175+ pos , index , values , pos_len , index_len , values_len = sparse_tensor .disassemble (
176+ (tensor_1d_index , tensor_2d_index ),
177+ tensor_1d_values ,
178+ (index_dtype , index_dtype ),
179+ index_dtype ,
180+ tensor_shaped ,
181+ (pos , index ),
182+ values ,
183+ )
184+ shape_consts = [arith .constant (index_dtype , s ) for s in shape ]
185+ return pos , index , values , pos_len , index_len , values_len , * shape_consts
186+
187+ @func .FuncOp .from_py_func (coo_shaped )
188+ def free_tensor (tensor_shaped ):
189+ bufferization .dealloc_tensor (tensor_shaped )
190+
191+ assemble .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
192+ disassemble .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
193+ free_tensor .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
194+ if DEBUG :
195+ (CWD / "coo_module.mlir" ).write_text (str (module ))
196+ pm = mlir .passmanager .PassManager .parse ("builtin.module(sparsifier{create-sparse-deallocs=1})" )
197+ pm .run (module .operation )
198+ if DEBUG :
199+ (CWD / "coo_module_opt.mlir" ).write_text (str (module ))
200+
201+ module = mlir .execution_engine .ExecutionEngine (module , opt_level = 2 , shared_libs = [MLIR_C_RUNNER_UTILS ])
202+ return (module , coo_shaped )
203+
204+ @classmethod
205+ def assemble (cls , module : ir .Module , arr : sps .coo_array ) -> ctypes .c_void_p :
206+ out = ctypes .c_void_p ()
207+ module .invoke (
208+ "assemble" ,
209+ ctypes .pointer (
210+ ctypes .pointer (rt .get_ranked_memref_descriptor (np .array ([0 , arr .size ], dtype = arr .coords [0 ].dtype )))
211+ ),
212+ ctypes .pointer (ctypes .pointer (rt .get_ranked_memref_descriptor (np .stack (arr .coords , axis = 1 )))),
213+ ctypes .pointer (ctypes .pointer (rt .get_ranked_memref_descriptor (arr .data ))),
214+ ctypes .pointer (out ),
215+ )
216+ return out
217+
218+ @classmethod
219+ def disassemble (cls , module : ir .Module , ptr : ctypes .c_void_p , dtype : type [DType ]) -> sps .coo_array :
220+ class Coo (ctypes .Structure ):
221+ _fields_ = [
222+ ("pos" , rt .make_nd_memref_descriptor (1 , Index .to_ctype ())),
223+ ("index" , rt .make_nd_memref_descriptor (2 , Index .to_ctype ())),
224+ ("values" , rt .make_nd_memref_descriptor (1 , dtype .to_ctype ())),
225+ ("pos_len" , np .ctypeslib .c_intp ),
226+ ("index_len" , np .ctypeslib .c_intp ),
227+ ("values_len" , np .ctypeslib .c_intp ),
228+ ("shape_x" , np .ctypeslib .c_intp ),
229+ ("shape_y" , np .ctypeslib .c_intp ),
230+ ]
231+
232+ def to_sps (self ) -> sps .coo_array :
233+ pos = rt .ranked_memref_to_numpy ([self .pos ])[: self .pos_len ]
234+ index = rt .ranked_memref_to_numpy ([self .index ])[pos [0 ] : pos [1 ]]
235+ values = rt .ranked_memref_to_numpy ([self .values ])[: self .values_len ]
236+ return sps .coo_array ((values , index .T ), shape = (self .shape_x , self .shape_y ))
237+
238+ arr = Coo ()
239+ module .invoke (
240+ "disassemble" ,
241+ ctypes .pointer (ctypes .pointer (arr )),
242+ ctypes .pointer (ptr ),
243+ )
244+ return arr .to_sps ()
146245
147246
148247class CSRFormat :
@@ -207,9 +306,9 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
207306 out = ctypes .c_void_p ()
208307 module .invoke (
209308 "assemble" ,
210- ctypes .pointer (ctypes .pointer (ranked_memref_from_np (arr .indptr ))),
211- ctypes .pointer (ctypes .pointer (ranked_memref_from_np (arr .indices ))),
212- ctypes .pointer (ctypes .pointer (ranked_memref_from_np (arr .data ))),
309+ ctypes .pointer (ctypes .pointer (rt . get_ranked_memref_descriptor (arr .indptr ))),
310+ ctypes .pointer (ctypes .pointer (rt . get_ranked_memref_descriptor (arr .indices ))),
311+ ctypes .pointer (ctypes .pointer (rt . get_ranked_memref_descriptor (arr .data ))),
213312 ctypes .pointer (out ),
214313 )
215314 return out
@@ -218,9 +317,9 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
218317 def disassemble (cls , module : ir .Module , ptr : ctypes .c_void_p , dtype : type [DType ]) -> sps .csr_array :
219318 class Csr (ctypes .Structure ):
220319 _fields_ = [
221- ("pos" , make_memref_ctype ( Index , 1 )),
222- ("crd" , make_memref_ctype ( Index , 1 )),
223- ("data" , make_memref_ctype ( dtype , 1 )),
320+ ("pos" , rt . make_nd_memref_descriptor ( 1 , Index . to_ctype () )),
321+ ("crd" , rt . make_nd_memref_descriptor ( 1 , Index . to_ctype () )),
322+ ("data" , rt . make_nd_memref_descriptor ( 1 , dtype . to_ctype () )),
224323 ("pos_len" , np .ctypeslib .c_intp ),
225324 ("crd_len" , np .ctypeslib .c_intp ),
226325 ("data_len" , np .ctypeslib .c_intp ),
@@ -229,9 +328,9 @@ class Csr(ctypes.Structure):
229328 ]
230329
231330 def to_sps (self ) -> sps .csr_array :
232- pos = self .pos . to_numpy ( )[: self .pos_len ]
233- crd = self .crd . to_numpy ( )[: self .crd_len ]
234- data = self .data . to_numpy ( )[: self .data_len ]
331+ pos = rt . ranked_memref_to_numpy ([ self .pos ] )[: self .pos_len ]
332+ crd = rt . ranked_memref_to_numpy ([ self .crd ] )[: self .crd_len ]
333+ data = rt . ranked_memref_to_numpy ([ self .data ] )[: self .data_len ]
235334 return sps .csr_array ((data , crd , pos ), shape = (self .shape_x , self .shape_y ))
236335
237336 arr = Csr ()
@@ -257,9 +356,16 @@ def asarray(obj) -> Tensor:
257356
258357 # TODO: support other scipy formats
259358 if _is_scipy_sparse_obj (obj ):
260- format_class = CSRFormat
261- # This can be int32 or int64
262- index_dtype = asdtype (obj .indptr .dtype )
359+ if obj .format == "csr" :
360+ format_class = CSRFormat
361+ # This can be int32 or int64
362+ index_dtype = asdtype (obj .indptr .dtype )
363+ elif obj .format == "coo" :
364+ format_class = COOFormat
365+ # This can be int32 or int64
366+ index_dtype = asdtype (obj .coords [0 ].dtype )
367+ else :
368+ raise Exception (f"{ obj .format } SciPy format not supported." )
263369 elif _is_numpy_obj (obj ):
264370 format_class = DenseFormat
265371 index_dtype = Index
0 commit comments