Skip to content

Commit 3d61ec0

Browse files
authored
Merge pull request #2175 from Smit-create/i-2169
C: Support callbacks
2 parents 126a3b5 + 8ba9d30 commit 3d61ec0

30 files changed

+141
-35
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,8 @@ RUN(NAME global_syms_04 LABELS cpython llvm c wasm wasm_x64)
695695
RUN(NAME global_syms_05 LABELS cpython llvm c)
696696
RUN(NAME global_syms_06 LABELS cpython llvm c)
697697

698-
RUN(NAME callback_01 LABELS cpython llvm)
698+
RUN(NAME callback_01 LABELS cpython llvm c)
699+
RUN(NAME callback_02 LABELS cpython llvm c)
699700

700701
# Intrinsic Functions
701702
RUN(NAME intrinsics_01 LABELS cpython llvm NOFAST) # any

integration_tests/callback_02.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from lpython import ccallback, i32, Callable
2+
3+
# test issue 2169
4+
5+
def foo(x : i32) -> i32:
6+
return x**2
7+
8+
def bar(func : Callable[[i32], i32], arg : i32) -> i32:
9+
return func(arg)
10+
11+
@ccallback
12+
def entry_point() -> None:
13+
z: i32 = 5
14+
x: i32 = bar(foo, z)
15+
assert z**2 == x
16+
17+
18+
entry_point()

src/libasr/codegen/asr_to_c.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,8 +739,8 @@ R"(
739739
array_types_decls.insert(0, "struct dimension_descriptor\n"
740740
"{\n int32_t lower_bound, length;\n};\n");
741741
}
742-
743-
src = to_include + head + array_types_decls + unit_src +
742+
forward_decl_functions += "\n\n";
743+
src = to_include + head + array_types_decls + forward_decl_functions + unit_src +
744744
ds_funcs_defined + util_funcs_defined;
745745
if (!emit_headers.empty()) {
746746
std::string to_includes_1 = "";

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
137137
std::map<int32_t, std::string> gotoid2name;
138138
std::map<std::string, std::string> emit_headers;
139139
std::string array_types_decls;
140+
std::string forward_decl_functions;
140141

141142
// Output configuration:
142143
// Use std::string or char*
@@ -493,18 +494,18 @@ R"(#include <stdio.h>
493494
}
494495

495496
// Returns the declaration, no semi colon at the end
496-
std::string get_function_declaration(const ASR::Function_t &x, bool &has_typevar) {
497+
std::string get_function_declaration(const ASR::Function_t &x, bool &has_typevar, bool is_pointer=false) {
497498
template_for_Kokkos.clear();
498499
template_number = 0;
499500
std::string sub, inl, static_attr;
500501

501502
// This helps to check if the function is generic.
502503
// If it is generic we skip the codegen for that function.
503504
has_typevar = false;
504-
if (ASRUtils::get_FunctionType(x)->m_inline) {
505+
if (ASRUtils::get_FunctionType(x)->m_inline && !is_pointer) {
505506
inl = "inline __attribute__((always_inline)) ";
506507
}
507-
if( ASRUtils::get_FunctionType(x)->m_static ) {
508+
if( ASRUtils::get_FunctionType(x)->m_static && !is_pointer) {
508509
static_attr = "static ";
509510
}
510511
if (x.m_return_var) {
@@ -526,30 +527,47 @@ R"(#include <stdio.h>
526527
f_type->m_deftype == ASR::deftypeType::Implementation) {
527528
sym_name = "_xx_internal_" + sym_name + "_xx";
528529
}
529-
std::string func = static_attr + inl + sub + sym_name + "(";
530+
std::string func = static_attr + inl + sub;
531+
if (is_pointer) {
532+
func += "(*" + sym_name + ")(";
533+
} else {
534+
func += sym_name + "(";
535+
}
530536
bracket_open++;
531537
for (size_t i=0; i<x.n_args; i++) {
532-
ASR::Variable_t *arg = ASRUtils::EXPR2VAR(x.m_args[i]);
533-
LCOMPILERS_ASSERT(ASRUtils::is_arg_dummy(arg->m_intent));
534-
if (ASR::is_a<ASR::TypeParameter_t>(*arg->m_type)) {
535-
has_typevar = true;
536-
bracket_open--;
537-
return "";
538-
}
539-
if( is_c ) {
540-
CDeclarationOptions c_decl_options;
541-
c_decl_options.pre_initialise_derived_type = false;
542-
func += self().convert_variable_decl(*arg, &c_decl_options);
538+
ASR::symbol_t *sym = ASRUtils::symbol_get_past_external(
539+
ASR::down_cast<ASR::Var_t>(x.m_args[i])->m_v);
540+
if (ASR::is_a<ASR::Variable_t>(*sym)) {
541+
ASR::Variable_t *arg = ASR::down_cast<ASR::Variable_t>(sym);
542+
LCOMPILERS_ASSERT(ASRUtils::is_arg_dummy(arg->m_intent));
543+
if( is_c ) {
544+
CDeclarationOptions c_decl_options;
545+
c_decl_options.pre_initialise_derived_type = false;
546+
func += self().convert_variable_decl(*arg, &c_decl_options);
547+
} else {
548+
CPPDeclarationOptions cpp_decl_options;
549+
cpp_decl_options.use_static = false;
550+
cpp_decl_options.use_templates_for_arrays = true;
551+
func += self().convert_variable_decl(*arg, &cpp_decl_options);
552+
}
553+
if (ASR::is_a<ASR::TypeParameter_t>(*arg->m_type)) {
554+
has_typevar = true;
555+
bracket_open--;
556+
return "";
557+
}
558+
} else if (ASR::is_a<ASR::Function_t>(*sym)) {
559+
ASR::Function_t *fun = ASR::down_cast<ASR::Function_t>(sym);
560+
func += get_function_declaration(*fun, has_typevar, true);
543561
} else {
544-
CPPDeclarationOptions cpp_decl_options;
545-
cpp_decl_options.use_static = false;
546-
cpp_decl_options.use_templates_for_arrays = true;
547-
func += self().convert_variable_decl(*arg, &cpp_decl_options);
562+
throw CodeGenError("Unsupported function argument");
548563
}
549564
if (i < x.n_args-1) func += ", ";
550565
}
551566
func += ")";
552567
bracket_open--;
568+
if (f_type->m_abi == ASR::abiType::Source) {
569+
forward_decl_functions += func + ";\n";
570+
}
553571
if( is_c || template_for_Kokkos.empty() ) {
554572
return func;
555573
}
@@ -1735,6 +1753,10 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
17351753

17361754
void visit_Var(const ASR::Var_t &x) {
17371755
const ASR::symbol_t *s = ASRUtils::symbol_get_past_external(x.m_v);
1756+
if (ASR::is_a<ASR::Function_t>(*s)) {
1757+
src = ASRUtils::symbol_name(s);
1758+
return;
1759+
}
17381760
ASR::Variable_t* sv = ASR::down_cast<ASR::Variable_t>(s);
17391761
if( (sv->m_intent == ASRUtils::intent_in ||
17401762
sv->m_intent == ASRUtils::intent_inout) &&

tests/reference/c-c_interop1-e215531.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "c-c_interop1-e215531.stdout",
9-
"stdout_hash": "31d7d540adfef1263741e1b9709e4a2a73c6208147d735906dbf233e",
9+
"stdout_hash": "9d9dd1715b625024203f8b4de330a936a806c5e8031ae7b643c23e71",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

tests/reference/c-c_interop1-e215531.stdout

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <lfortran_intrinsics.h>
88

99

10+
11+
1012
// Implementations
1113
double f(double x);
1214

tests/reference/c-expr7-bb2692a.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "c-expr7-bb2692a.stdout",
9-
"stdout_hash": "be3bac9bcf7f1fab11741c22151820f60793d5afc670486e9d0913aa",
9+
"stdout_hash": "11026dff0f543f773698075edb3d690f8a475e8639124f1e1dcaef6a",
1010
"stderr": "c-expr7-bb2692a.stderr",
1111
"stderr_hash": "6e9790ac88db1a9ead8f64a91ba8a6605de67167037908a74b77be0c",
1212
"returncode": 0

tests/reference/c-expr7-bb2692a.stdout

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
#include <string.h>
99
#include <lfortran_intrinsics.h>
1010

11+
void test_pow();
12+
int32_t test_pow_1(int32_t a, int32_t b);
13+
void main0();
14+
void __main____global_statements();
15+
16+
1117

1218
// Implementations
1319
double __lpython_overloaded_0__pow(int32_t x, int32_t y)

tests/reference/c-expr_01-28f449f.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "c-expr_01-28f449f.stdout",
9-
"stdout_hash": "21755da8aacbb21ee5a5229c8f47f7ba40c36d942af84f848e36bd5b",
9+
"stdout_hash": "b9e6df09a73026d76ea9eefed3e2b211476b1ac55956ee8150a4a417",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

tests/reference/c-expr_01-28f449f.stdout

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
#include <string.h>
77
#include <lfortran_intrinsics.h>
88

9+
inline __attribute__((always_inline)) int32_t add(int32_t x, int32_t y);
10+
inline __attribute__((always_inline)) int32_t and_op(int32_t x, int32_t y);
11+
void main0();
12+
void __main____global_statements();
13+
14+
915

1016
// Implementations
1117
inline __attribute__((always_inline)) int32_t add(int32_t x, int32_t y)

0 commit comments

Comments
 (0)