Skip to content

Commit 61a27d5

Browse files
authored
[ASR Pass] Replace the Class obj Assignment with SubroutineCall (#2795)
Move the code modification to the ASR Pass
1 parent f3389e6 commit 61a27d5

File tree

3 files changed

+79
-42
lines changed

3 files changed

+79
-42
lines changed

src/libasr/pass/pass_utils.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,40 @@ namespace LCompilers {
12471247
}
12481248
}
12491249
}
1250+
ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name,
1251+
const Location &loc, SymbolTable* current_scope) {
1252+
ASR::Struct_t* struct_type = ASR::down_cast<ASR::Struct_t>(struct_type_sym);
1253+
std::string struct_var_name = struct_type->m_name;
1254+
std::string struct_member_name = call_name;
1255+
ASR::symbol_t* struct_member = struct_type->m_symtab->resolve_symbol(struct_member_name);
1256+
ASR::symbol_t* struct_mem_asr_owner = ASRUtils::get_asr_owner(struct_member);
1257+
if( !struct_member || !struct_mem_asr_owner ||
1258+
!ASR::is_a<ASR::Struct_t>(*struct_mem_asr_owner) ) {
1259+
throw LCompilersException(struct_member_name + " not present in " +
1260+
struct_var_name + " dataclass");
1261+
}
1262+
std::string import_name = struct_var_name + "_" + struct_member_name;
1263+
ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name);
1264+
bool import_from_struct = true;
1265+
if( import_struct_member ) {
1266+
if( ASR::is_a<ASR::ExternalSymbol_t>(*import_struct_member) ) {
1267+
ASR::ExternalSymbol_t* ext_sym = ASR::down_cast<ASR::ExternalSymbol_t>(import_struct_member);
1268+
if( ext_sym->m_external == struct_member &&
1269+
std::string(ext_sym->m_module_name) == struct_var_name ) {
1270+
import_from_struct = false;
1271+
}
1272+
}
1273+
}
1274+
if( import_from_struct ) {
1275+
import_name = current_scope->get_unique_name(import_name, false);
1276+
import_struct_member = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(al,
1277+
loc, current_scope, s2c(al, import_name),
1278+
struct_member, s2c(al, struct_var_name), nullptr, 0,
1279+
s2c(al, struct_member_name), ASR::accessType::Public));
1280+
current_scope->add_symbol(import_name, import_struct_member);
1281+
}
1282+
return import_struct_member;
1283+
}
12501284

12511285
} // namespace PassUtils
12521286

src/libasr/pass/pass_utils.h

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,9 @@ namespace LCompilers {
570570
*/
571571
};
572572

573+
ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name,
574+
const Location &loc, SymbolTable* current_scope);
575+
573576
namespace ReplacerUtils {
574577
template <typename T>
575578
void replace_StructConstructor(ASR::StructConstructor_t* x,
@@ -578,6 +581,33 @@ namespace LCompilers {
578581
bool perform_cast=false,
579582
ASR::cast_kindType cast_kind=ASR::cast_kindType::IntegerToInteger,
580583
ASR::ttype_t* casted_type=nullptr) {
584+
if ( ASR::is_a<ASR::Struct_t>(*(x->m_dt_sym)) ) {
585+
ASR::Struct_t* st = ASR::down_cast<ASR::Struct_t>(x->m_dt_sym);
586+
if ( st->n_member_functions > 0 ) {
587+
remove_original_statement = true;
588+
if ( !ASR::is_a<ASR::Var_t>(*(replacer->result_var)) ) {
589+
throw LCompilersException("Expected a var here");
590+
}
591+
ASR::Var_t* target = ASR::down_cast<ASR::Var_t>(replacer->result_var);
592+
ASR::call_arg_t first_arg;
593+
first_arg.loc = x->base.base.loc; first_arg.m_value = replacer->result_var;
594+
Vec<ASR::call_arg_t> new_args; new_args.reserve(replacer->al,x->n_args+1);
595+
new_args.push_back(replacer->al, first_arg);
596+
for( size_t i = 0; i < x->n_args; i++ ) {
597+
new_args.push_back(replacer->al, x->m_args[i]);
598+
}
599+
ASR::StructType_t* type = ASR::down_cast<ASR::StructType_t>(
600+
(ASR::down_cast<ASR::Variable_t>(target->m_v))->m_type);
601+
std::string call_name = "__init__";
602+
ASR::symbol_t* call_sym = get_struct_member(replacer->al,type->m_derived_type, call_name,
603+
x->base.base.loc, replacer->current_scope);
604+
result_vec->push_back(replacer->al, ASRUtils::STMT(
605+
ASRUtils::make_SubroutineCall_t_util(replacer->al,
606+
x->base.base.loc, call_sym, nullptr, new_args.p, new_args.size(),
607+
nullptr, nullptr, false, false)));
608+
return;
609+
}
610+
}
581611
if( x->n_args == 0 ) {
582612
if( !inside_symtab ) {
583613
remove_original_statement = true;
@@ -598,22 +628,22 @@ namespace LCompilers {
598628
}
599629

600630
std::deque<ASR::symbol_t*> constructor_arg_syms;
601-
ASR::StructType_t* dt_der = ASR::down_cast<ASR::StructType_t>(x->m_type);
602-
ASR::Struct_t* dt_dertype = ASR::down_cast<ASR::Struct_t>(
603-
ASRUtils::symbol_get_past_external(dt_der->m_derived_type));
604-
while( dt_dertype ) {
605-
for( int i = (int) dt_dertype->n_members - 1; i >= 0; i-- ) {
631+
ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>(x->m_type);
632+
ASR::Struct_t* dt_der = ASR::down_cast<ASR::Struct_t>(
633+
ASRUtils::symbol_get_past_external(dt_dertype->m_derived_type));
634+
while( dt_der ) {
635+
for( int i = (int) dt_der->n_members - 1; i >= 0; i-- ) {
606636
constructor_arg_syms.push_front(
607-
dt_dertype->m_symtab->get_symbol(
608-
dt_dertype->m_members[i]));
637+
dt_der->m_symtab->get_symbol(
638+
dt_der->m_members[i]));
609639
}
610-
if( dt_dertype->m_parent != nullptr ) {
640+
if( dt_der->m_parent != nullptr ) {
611641
ASR::symbol_t* dt_der_sym = ASRUtils::symbol_get_past_external(
612-
dt_dertype->m_parent);
642+
dt_der->m_parent);
613643
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*dt_der_sym));
614-
dt_dertype = ASR::down_cast<ASR::Struct_t>(dt_der_sym);
644+
dt_der = ASR::down_cast<ASR::Struct_t>(dt_der_sym);
615645
} else {
616-
dt_dertype = nullptr;
646+
dt_der = nullptr;
617647
}
618648
}
619649
LCOMPILERS_ASSERT(constructor_arg_syms.size() == x->n_args);

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,16 +1282,13 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
12821282
args, st, loc);
12831283
}
12841284
if ( st->n_member_functions > 0 ) {
1285-
// Empty struct constructor
12861285
// Initializers handled in init proc call
1287-
Vec<ASR::call_arg_t>empty_args;
1288-
empty_args.reserve(al, 1);
1289-
for (size_t i = 0; i < st->n_members; i++) {
1290-
empty_args.push_back(al, st->m_initializers[i]);
1286+
if ( n_kwargs>0 ) {
1287+
throw SemanticError("Keyword args are not supported", loc);
12911288
}
12921289
ASR::ttype_t* der_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc, stemp));
1293-
return ASR::make_StructConstructor_t(al, loc, stemp, empty_args.p,
1294-
empty_args.size(), der_type, nullptr);
1290+
return ASR::make_StructConstructor_t(al, loc, stemp, args.p,
1291+
args.size(), der_type, nullptr);
12951292
}
12961293

12971294
if ( args.size() > 0 && args.size() > st->n_members ) {
@@ -5316,17 +5313,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
53165313
if ( call->n_keywords>0 ) {
53175314
throw SemanticError("Kwargs not implemented yet", x.base.base.loc);
53185315
}
5319-
Vec<ASR::call_arg_t> args;
5320-
args.reserve(al, call->n_args + 1);
5321-
ASR::call_arg_t self_arg;
5322-
self_arg.loc = x.base.base.loc;
5323-
self_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, sym));
5324-
args.push_back(al, self_arg);
5325-
visit_expr_list(call->m_args, call->n_args, args);
5326-
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>((var->m_type))->m_derived_type;
5327-
std::string call_name = "__init__";
5328-
ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc);
5329-
tmp = make_call_helper(al, call_sym, current_scope, args, call_name, x.base.base.loc);
53305316
}
53315317
}
53325318
}
@@ -5611,23 +5597,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
56115597
overloaded));
56125598
if ( target->type == ASR::exprType::Var &&
56135599
tmp_value->type == ASR::exprType::StructConstructor ) {
5614-
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, 1);
5615-
ASR::call_arg_t self_arg;
5616-
self_arg.loc = x.base.base.loc;
5617-
ASR::symbol_t* st = ASR::down_cast<ASR::Var_t>(target)->m_v;
5618-
self_arg.m_value = target;
5619-
new_args.push_back(al,self_arg);
56205600
AST::Call_t* call = AST::down_cast<AST::Call_t>(x.m_value);
56215601
if ( call->n_keywords>0 ) {
56225602
throw SemanticError("Kwargs not implemented yet", x.base.base.loc);
56235603
}
5624-
visit_expr_list(call->m_args, call->n_args, new_args);
5625-
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(
5626-
ASR::down_cast<ASR::Variable_t>(st)->m_type)->m_derived_type;
5627-
std::string call_name = "__init__";
5628-
ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc);
5629-
tmp_vec.push_back(make_call_helper(al, call_sym,
5630-
current_scope, new_args, call_name, x.base.base.loc));
56315604
}
56325605
}
56335606
// to make sure that we add only those statements in tmp_vec

0 commit comments

Comments
 (0)