@@ -517,6 +517,37 @@ def ccallback(f):
517517
518518class jit :
519519 def __init__ (self , function ):
520+ def get_rtlib_dir ():
521+ current_dir = os .path .dirname (os .path .abspath (__file__ ))
522+ return os .path .join (current_dir , ".." )
523+
524+ def get_type_info (arg ):
525+ # return_type -> (`type_format`, `variable type`, `array struct name`)
526+ # See: https://docs.python.org/3/c-api/arg.html for more info on type_format
527+ if arg == f64 :
528+ return ('d' , "double" , 'r64' )
529+ elif arg == f32 :
530+ return ('f' , "float" , 'r32' )
531+ elif arg == i64 :
532+ return ('l' , "long int" , 'i64' )
533+ elif arg == i32 :
534+ return ('i' , "int" , 'i32' )
535+ elif arg == bool :
536+ return ('p' , "bool" , '' )
537+ elif isinstance (arg , Array ):
538+ t = get_type_info (arg ._type )
539+ if t [2 ] == '' :
540+ raise NotImplementedError ("Type %r not implemented" % arg )
541+ return ('O' , ["PyArrayObject *" , "struct " + t [2 ]+ " *" , t [1 ]+ " *" ], '' )
542+ else :
543+ raise NotImplementedError ("Type %r not implemented" % arg )
544+
545+ def get_data_type (t ):
546+ if isinstance (t , list ):
547+ return t [0 ]
548+ else :
549+ return t + " "
550+
520551 self .fn_name = function .__name__
521552 # Get the source code of the function
522553 source_code = getsource (function )
@@ -530,6 +561,148 @@ def __init__(self, function):
530561 # Write the Python source code to the file
531562 file .write ("@ccallable" )
532563 file .write (source_code )
564+ # ----------------------------------------------------------------------
565+ types = function .__annotations__
566+ self .arg_type_formats = ""
567+ self .return_type = ""
568+ self .return_type_format = ""
569+ self .arg_types = {}
570+ counter = 1
571+ for t in types .keys ():
572+ if t == "return" :
573+ type = get_type_info (types [t ])
574+ self .return_type_format = type [0 ]
575+ self .return_type = type [1 ]
576+ else :
577+ type = get_type_info (types [t ])
578+ self .arg_type_formats += type [0 ]
579+ self .arg_types [counter ] = type [1 ]
580+ counter += 1
581+ # ----------------------------------------------------------------------
582+ # `arg_0`: used as the return variables
583+ # arguments are declared as `arg_1`, `arg_2`, ...
584+ variables_decl = ""
585+ if self .return_type != "" :
586+ variables_decl = "// Declare return variables and arguments\n "
587+ variables_decl += " " + get_data_type (self .return_type ) + "arg_" \
588+ + str (0 ) + ";\n "
589+ # ----------------------------------------------------------------------
590+ # `PyArray_AsCArray` is used to convert NumPy Arrays to C Arrays
591+ # `fill_array_details` contains arrays operations to be
592+ # performed on the arguments
593+ # `parse_args` are used to capture the args from CPython
594+ # `pass_args` are the args that are passed to the shared library function
595+ fill_array_details = ""
596+ parse_args = ""
597+ pass_args = ""
598+ numpy_init = ""
599+ for i , t in self .arg_types .items ():
600+ if i > 1 :
601+ parse_args += ", "
602+ pass_args += ", "
603+ if isinstance (t , list ):
604+ if numpy_init == "" :
605+ numpy_init = "// Initialize NumPy\n import_array();\n \n "
606+ fill_array_details += f"""\n
607+ // fill array details for args[{ i - 1 } ]
608+ if (PyArray_NDIM(arg_{ i } ) != 1) {{
609+ PyErr_SetString(PyExc_TypeError,
610+ "Only 1 dimension is implemented for now.");
611+ return NULL;
612+ }}
613+
614+ { t [1 ]} s_array_{ i } = malloc(sizeof(struct r64));
615+ {{
616+ { t [2 ]} array;
617+ // Create C arrays from numpy objects:
618+ PyArray_Descr *descr = PyArray_DescrFromType(PyArray_TYPE(arg_{ i } ));
619+ npy_intp dims[1];
620+ if (PyArray_AsCArray((PyObject **)&arg_{ i } , (void *)&array, dims, 1, descr) < 0) {{
621+ PyErr_SetString(PyExc_TypeError, "error converting to c array");
622+ return NULL;
623+ }}
624+
625+ s_array_{ i } ->data = array;
626+ s_array_{ i } ->n_dims = 1;
627+ s_array_{ i } ->dims[0].lower_bound = 0;
628+ s_array_{ i } ->dims[0].length = dims[0];
629+ s_array_{ i } ->is_allocated = false;
630+ }}"""
631+ pass_args += "s_array_" + str (i )
632+ else :
633+ pass_args += "arg_" + str (i )
634+ variables_decl += " " + get_data_type (t ) + "arg_" + str (i ) + ";\n "
635+ parse_args += "&arg_" + str (i )
636+
637+ if parse_args != "" :
638+ parse_args = f"""\n // Parse the arguments from Python
639+ if (!PyArg_ParseTuple(args, "{ self .arg_type_formats } ", { parse_args } )) {{
640+ return NULL;
641+ }}"""
642+
643+ # ----------------------------------------------------------------------
644+ # Handle the return variable if any; otherwise, return None
645+ fill_return_details = ""
646+ if self .return_type != "" :
647+ fill_return_details = f"""\n \n // Call the C function
648+ arg_0 = { self .fn_name } ({ pass_args } );
649+
650+ // Build and return the result as a Python object
651+ return Py_BuildValue("{ self .return_type_format } ", arg_0);"""
652+ else :
653+ fill_return_details = f"""{ self .fn_name } ({ pass_args } );
654+ Py_RETURN_NONE;"""
655+
656+ # ----------------------------------------------------------------------
657+ # Python wrapper for the Shared library
658+ template = f"""// Python headers
659+ #include <Python.h>
660+
661+ // NumPy C/API headers
662+ #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION // remove warnings
663+ #include <numpy/ndarrayobject.h>
664+
665+ // LPython generated C code
666+ #include "a.h"
667+
668+ // Define the Python module and method mappings
669+ static PyObject* define_module(PyObject* self, PyObject* args) {{
670+ { numpy_init } { variables_decl } { parse_args } \
671+ { fill_array_details } { fill_return_details }
672+ }}
673+
674+ // Define the module's method table
675+ static PyMethodDef module_methods[] = {{
676+ {{"{ self .fn_name } ", define_module, METH_VARARGS,
677+ "Handle arguments & return variable and call the function"}},
678+ {{NULL, NULL, 0, NULL}}
679+ }};
680+
681+ // Define the module initialization function
682+ static struct PyModuleDef module_def = {{
683+ PyModuleDef_HEAD_INIT,
684+ "lpython_jit_module",
685+ "Shared library to use LPython generated functions",
686+ -1,
687+ module_methods
688+ }};
689+
690+ PyMODINIT_FUNC PyInit_lpython_jit_module(void) {{
691+ PyObject* module;
692+
693+ // Create the module object
694+ module = PyModule_Create(&module_def);
695+ if (!module) {{
696+ return NULL;
697+ }}
698+
699+ return module;
700+ }}
701+ """
702+ # ----------------------------------------------------------------------
703+ # Write the C source code to the file
704+ with open ("a.c" , "w" ) as file :
705+ file .write (template )
533706
534707 # ----------------------------------------------------------------------
535708 # TODO: Use LLVM instead of C backend
0 commit comments