Skip to content

Commit e642072

Browse files
authored
Merge pull request #2056 from Smit-create/i-1981-1
Store initializer exprs for Structs in ASR
2 parents 32273ce + bdbcee7 commit e642072

26 files changed

+334
-182
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,7 @@ RUN(NAME structs_28 LABELS cpython llvm c)
592592
RUN(NAME structs_29 LABELS cpython llvm)
593593
RUN(NAME structs_30 LABELS cpython llvm c)
594594
RUN(NAME structs_31 LABELS cpython llvm c)
595+
RUN(NAME structs_32 LABELS cpython llvm c)
595596

596597
RUN(NAME symbolics_01 LABELS cpython_sym c_sym)
597598
RUN(NAME symbolics_02 LABELS cpython_sym c_sym)

integration_tests/structs_10.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33

44
@dataclass
55
class Mat:
6-
mat: f64[2, 2]
6+
mat: f64[2, 2] = empty((2, 2), dtype=float64)
77

88
@dataclass
99
class Vec:
10-
vec: f64[2]
10+
vec: f64[2] = empty(2, dtype=float64)
1111

1212
@dataclass
1313
class MatVec:
14-
mat: Mat = Mat([f64(0.0), f64(0.0)])
15-
vec: Vec = Vec([f64(0.0), f64(0.0)])
14+
mat: Mat = Mat()
15+
vec: Vec = Vec()
1616

1717
def rotate(mat_vec: MatVec) -> f64[2]:
1818
rotated_vec: f64[2] = empty(2, dtype=float64)

integration_tests/structs_27.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from lpython import dataclass, i32
1+
from lpython import dataclass, i32, u16, f32
22

33

44
@dataclass
@@ -7,6 +7,15 @@ class StringIO:
77
_0cursor : i32 = 10
88
_len : i32 = 1
99

10+
@dataclass
11+
class StringIONew:
12+
_buf : str
13+
_0cursor : i32 = i32(142)
14+
_len : i32 = i32(2439)
15+
_var1 : u16 = u16(23)
16+
_var2 : f32 = f32(30.24)
17+
18+
#print("ok")
1019

1120
def test_issue_1928():
1221
integer_asr : str = '(Integer 4 [])'
@@ -47,4 +56,22 @@ def test_issue_1928():
4756
assert test_dude4._0cursor == 31
4857

4958

59+
def test_issue_1981():
60+
integer_asr : str = '(Integer 4 [])'
61+
test_dude : StringIONew = StringIONew(integer_asr)
62+
assert test_dude._buf == integer_asr
63+
assert test_dude._len == 2439
64+
assert test_dude._0cursor == 142
65+
assert test_dude._var1 == u16(23)
66+
assert abs(test_dude._var2 - f32(30.24)) < f32(1e-5)
67+
test_dude._len = 13
68+
test_dude._0cursor = 52
69+
test_dude._var1 = u16(34)
70+
assert test_dude._buf == integer_asr
71+
assert test_dude._len == 13
72+
assert test_dude._0cursor == 52
73+
assert test_dude._var1 == u16(34)
74+
75+
76+
test_issue_1981()
5077
test_issue_1928()

integration_tests/structs_32.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from lpython import packed, dataclass, i32, InOut
2+
3+
4+
@packed
5+
@dataclass
6+
class inner_struct:
7+
a: i32
8+
9+
10+
@packed
11+
@dataclass
12+
class outer_struct:
13+
b: inner_struct = inner_struct(0)
14+
15+
16+
def update_my_inner_struct(my_inner_struct: InOut[inner_struct]) -> None:
17+
my_inner_struct.a = 99999
18+
19+
20+
def update_my_outer_struct(my_outer_struct: InOut[outer_struct]) -> None:
21+
my_outer_struct.b.a = 12345
22+
23+
24+
def main() -> None:
25+
my_outer_struct: outer_struct = outer_struct()
26+
my_inner_struct: inner_struct = my_outer_struct.b
27+
28+
assert my_outer_struct.b.a == 0
29+
30+
my_outer_struct.b.a = 12345
31+
assert my_outer_struct.b.a == 12345
32+
33+
my_outer_struct.b.a = 0
34+
assert my_outer_struct.b.a == 0
35+
36+
update_my_outer_struct(my_outer_struct)
37+
assert my_outer_struct.b.a == 12345
38+
39+
my_inner_struct.a = 1111
40+
assert my_inner_struct.a == 1111
41+
42+
update_my_inner_struct(my_inner_struct)
43+
assert my_inner_struct.a == 99999
44+
45+
main()

src/libasr/ASR.asdl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ symbol
9696
identifier original_name, access access)
9797
| StructType(symbol_table symtab, identifier name, identifier* dependencies,
9898
identifier* members, abi abi, access access, bool is_packed, bool is_abstract,
99-
expr? alignment, symbol? parent)
99+
call_arg* initializers, expr? alignment, symbol? parent)
100100
| EnumType(symbol_table symtab, identifier name, identifier* dependencies,
101101
identifier* members, abi abi, access access, enumtype enum_value_type,
102102
ttype type, symbol? parent)
103103
| UnionType(symbol_table symtab, identifier name, identifier* dependencies,
104-
identifier* members, abi abi, access access, symbol? parent)
104+
identifier* members, abi abi, access access, call_arg* initializers, symbol? parent)
105105
| Variable(symbol_table parent_symtab, identifier name, identifier* dependencies,
106106
intent intent, expr? symbolic_value, expr? value, storage_type storage,
107107
ttype type, symbol? type_declaration,

src/libasr/codegen/asr_to_c.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,12 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
146146
void allocate_array_members_of_struct(ASR::StructType_t* der_type_t, std::string& sub,
147147
std::string indent, std::string name) {
148148
for( auto itr: der_type_t->m_symtab->get_scope() ) {
149-
if( ASR::is_a<ASR::UnionType_t>(*itr.second) ||
150-
ASR::is_a<ASR::StructType_t>(*itr.second) ) {
149+
ASR::symbol_t *sym = ASRUtils::symbol_get_past_external(itr.second);
150+
if( ASR::is_a<ASR::UnionType_t>(*sym) ||
151+
ASR::is_a<ASR::StructType_t>(*sym) ) {
151152
continue ;
152153
}
153-
ASR::ttype_t* mem_type = ASRUtils::symbol_type(itr.second);
154+
ASR::ttype_t* mem_type = ASRUtils::symbol_type(sym);
154155
if( ASRUtils::is_character(*mem_type) ) {
155156
sub += indent + name + "->" + itr.first + " = NULL;\n";
156157
} else if( ASRUtils::is_array(mem_type) &&

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,6 +1305,11 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
13051305
last_expr_precedence = 2;
13061306
}
13071307

1308+
void visit_UnsignedIntegerConstant(const ASR::UnsignedIntegerConstant_t &x) {
1309+
src = std::to_string(x.m_n);
1310+
last_expr_precedence = 2;
1311+
}
1312+
13081313
void visit_RealConstant(const ASR::RealConstant_t &x) {
13091314
// TODO: remove extra spaces from the front of double_to_scientific result
13101315
src = double_to_scientific(x.m_r);

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
809809
llvm::Type* max_sized_type = nullptr;
810810
size_t max_type_size = 0;
811811
for( auto itr = scope.begin(); itr != scope.end(); itr++ ) {
812-
ASR::Variable_t* member = ASR::down_cast<ASR::Variable_t>(itr->second);
812+
ASR::Variable_t* member = ASR::down_cast<ASR::Variable_t>(ASRUtils::symbol_get_past_external(itr->second));
813813
llvm::Type* llvm_mem_type = getMemberType(member->m_type, member);
814814
size_t type_size = data_layout.getTypeAllocSize(llvm_mem_type);
815815
if( max_type_size < type_size ) {
@@ -3316,14 +3316,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
33163316
ASRUtils::symbol_get_past_external(struct_t->m_derived_type));
33173317
std::string struct_type_name = struct_type_t->m_name;
33183318
for( auto item: struct_type_t->m_symtab->get_scope() ) {
3319-
if( ASR::is_a<ASR::ClassProcedure_t>(*item.second) ||
3320-
ASR::is_a<ASR::GenericProcedure_t>(*item.second) ||
3321-
ASR::is_a<ASR::UnionType_t>(*item.second) ||
3322-
ASR::is_a<ASR::StructType_t>(*item.second) ||
3323-
ASR::is_a<ASR::CustomOperator_t>(*item.second) ) {
3319+
ASR::symbol_t *sym = ASRUtils::symbol_get_past_external(item.second);
3320+
if (name2memidx[struct_type_name].find(item.first) == name2memidx[struct_type_name].end()) {
3321+
continue;
3322+
}
3323+
if( ASR::is_a<ASR::ClassProcedure_t>(*sym) ||
3324+
ASR::is_a<ASR::GenericProcedure_t>(*sym) ||
3325+
ASR::is_a<ASR::UnionType_t>(*sym) ||
3326+
ASR::is_a<ASR::StructType_t>(*sym) ||
3327+
ASR::is_a<ASR::CustomOperator_t>(*sym) ) {
33243328
continue ;
33253329
}
3326-
ASR::ttype_t* symbol_type = ASRUtils::symbol_type(item.second);
3330+
ASR::ttype_t* symbol_type = ASRUtils::symbol_type(sym);
33273331
int idx = name2memidx[struct_type_name][item.first];
33283332
llvm::Value* ptr_member = llvm_utils->create_gep(ptr, idx);
33293333
ASR::Variable_t* v = nullptr;
@@ -8569,7 +8573,7 @@ Result<std::unique_ptr<LLVMModule>> asr_to_llvm(ASR::TranslationUnit_t &asr,
85698573
pass_manager.apply_passes(al, &asr, pass_options, diagnostics);
85708574

85718575
// Uncomment for debugging the ASR after the transformation
8572-
// std::cout << LFortran::pickle(asr, false, false, false) << std::endl;
8576+
// std::cout << LCompilers::LPython::pickle(asr, true, true, false) << std::endl;
85738577

85748578
try {
85758579
v.visit_asr((ASR::asr_t&)asr);

src/libasr/codegen/c_utils.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -843,26 +843,28 @@ class CCPPDSUtils {
843843
+ struct_type_str + "* dest)";
844844
func_decls += "inline " + signature + ";\n";
845845
generated_code += indent + signature + " {\n";
846-
for( auto item: struct_type_t->m_symtab->get_scope() ) {
847-
ASR::ttype_t* member_type_asr = ASRUtils::symbol_type(item.second);
846+
for(size_t i=0; i < struct_type_t->n_members; i++) {
847+
std::string mem_name = std::string(struct_type_t->m_members[i]);
848+
ASR::symbol_t* member = struct_type_t->m_symtab->get_symbol(mem_name);
849+
ASR::ttype_t* member_type_asr = ASRUtils::symbol_type(member);
848850
if( CUtils::is_non_primitive_DT(member_type_asr) ||
849851
ASR::is_a<ASR::Character_t>(*member_type_asr) ) {
850-
generated_code += indent + tab + get_deepcopy(member_type_asr, "&(src->" + item.first + ")",
851-
"&(dest->" + item.first + ")") + ";\n";
852+
generated_code += indent + tab + get_deepcopy(member_type_asr, "&(src->" + mem_name + ")",
853+
"&(dest->" + mem_name + ")") + ";\n";
852854
} else if( ASRUtils::is_array(member_type_asr) ) {
853855
ASR::dimension_t* m_dims = nullptr;
854856
size_t n_dims = ASRUtils::extract_dimensions_from_ttype(member_type_asr, m_dims);
855857
if( ASRUtils::is_fixed_size_array(m_dims, n_dims) ) {
856858
std::string array_size = std::to_string(ASRUtils::get_fixed_size_of_array(m_dims, n_dims));
857859
array_size += "*sizeof(" + CUtils::get_c_type_from_ttype_t(member_type_asr) + ")";
858-
generated_code += indent + tab + "memcpy(dest->" + item.first + ", src->" + item.first +
860+
generated_code += indent + tab + "memcpy(dest->" + mem_name + ", src->" + mem_name +
859861
", " + array_size + ");\n";
860862
} else {
861-
generated_code += indent + tab + get_deepcopy(member_type_asr, "src->" + item.first,
862-
"dest->" + item.first) + ";\n";
863+
generated_code += indent + tab + get_deepcopy(member_type_asr, "src->" + mem_name,
864+
"dest->" + mem_name) + ";\n";
863865
}
864866
} else {
865-
generated_code += indent + tab + "dest->" + item.first + " = " + " src->" + item.first + ";\n";
867+
generated_code += indent + tab + "dest->" + mem_name + " = " + " src->" + mem_name + ";\n";
866868
}
867869
}
868870
generated_code += indent + "}\n\n";

0 commit comments

Comments
 (0)