@@ -784,9 +784,26 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
784784 ASR::call_arg_t c_arg;
785785 c_arg.loc = args[i].loc ;
786786 c_arg.m_value = args[i].m_value ;
787- cast_helper (m_args[i], c_arg.m_value , true );
788787 ASR::ttype_t * left_type = ASRUtils::expr_type (m_args[i]);
789788 ASR::ttype_t * right_type = ASRUtils::expr_type (c_arg.m_value );
789+ if ( ASR::is_a<ASR::StructType_t>(*left_type) && ASR::is_a<ASR::StructType_t>(*right_type) ) {
790+ ASR::StructType_t *l_type = ASR::down_cast<ASR::StructType_t>(left_type);
791+ ASR::StructType_t *r_type = ASR::down_cast<ASR::StructType_t>(right_type);
792+ ASR::Struct_t *l2_type = ASR::down_cast<ASR::Struct_t>(
793+ ASRUtils::symbol_get_past_external (
794+ l_type->m_derived_type ));
795+ ASR::Struct_t *r2_type = ASR::down_cast<ASR::Struct_t>(
796+ ASRUtils::symbol_get_past_external (
797+ r_type->m_derived_type ));
798+ if ( ASRUtils::is_derived_type_similar (l2_type, r2_type) ) {
799+ cast_helper (m_args[i], c_arg.m_value , true , true );
800+ check_type_equality = false ;
801+ } else {
802+ cast_helper (m_args[i], c_arg.m_value , true );
803+ }
804+ } else {
805+ cast_helper (m_args[i], c_arg.m_value , true );
806+ }
790807 if ( check_type_equality && !ASRUtils::check_equal_type (left_type, right_type) ) {
791808 std::string ltype = ASRUtils::type_to_str_python (left_type);
792809 std::string rtype = ASRUtils::type_to_str_python (right_type);
@@ -2962,9 +2979,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
29622979 std::string obj_name = x.m_args .m_args ->m_arg ;
29632980 for (size_t i = 0 ; i < x.n_body ; i++) {
29642981 std::string var_name;
2965- if (! AST::is_a<AST::AnnAssign_t>(*x.m_body [i]) ){
2966- throw SemanticError (" Only AnnAssign implemented in __init__ " ,
2967- x.m_body [i]->base .loc );
2982+ if ( !AST::is_a<AST::AnnAssign_t>(*x.m_body [i]) ){
2983+ continue ;
29682984 }
29692985 AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body [i]);
29702986 if (AST::is_a<AST::Attribute_t>(*ann_assign.m_target )){
@@ -3301,10 +3317,21 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33013317 current_scope->add_symbol (x_m_name, class_type);
33023318 }
33033319 } else {
3304- if ( x.n_bases > 0 ) {
3305- throw SemanticError (" Inheritance in classes isn't supported yet." ,
3320+ ASR::symbol_t * parent = nullptr ;
3321+ if ( x.n_bases > 1 ) {
3322+ throw SemanticError (" Multiple inheritance in classes isn't supported yet." ,
33063323 x.base .base .loc );
33073324 }
3325+ else if (x.n_bases == 1 ) {
3326+ std::string b_name = " " ;
3327+ if ( AST::is_a<AST::Name_t>(*x.m_bases [0 ]) ) {
3328+ b_name = AST::down_cast<AST::Name_t>(x.m_bases [0 ])->m_id ;
3329+ } else {
3330+ throw SemanticError (" Expected a Name here" , x.base .base .loc );
3331+ }
3332+ parent = current_scope->resolve_symbol (b_name);
3333+ LCOMPILERS_ASSERT (ASR::is_a<ASR::Struct_t>(*parent));
3334+ }
33083335 SymbolTable *parent_scope = current_scope;
33093336 if ( ASR::symbol_t * sym = current_scope->resolve_symbol (x_m_name) ) {
33103337 LCOMPILERS_ASSERT (ASR::is_a<ASR::Struct_t>(*sym));
@@ -3316,7 +3343,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33163343 f = AST::down_cast<AST::FunctionDef_t>(x.m_body [i]);
33173344 init_self_type (*f, sym, x.base .base .loc );
33183345 if ( std::string (f->m_name ) == std::string (" __init__" ) ) {
3319- this ->visit_init_body (*f);
3346+ this ->visit_init_body (*f, st-> m_parent , x. m_body [i]-> base . loc );
33203347 } else {
33213348 this ->visit_stmt (*x.m_body [i]);
33223349 }
@@ -3344,7 +3371,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33443371 member_names.p , member_names.size (), member_fn_names.p ,
33453372 member_fn_names.size (), class_abi, ASR::accessType::Public,
33463373 false , false , member_init.p , member_init.size (),
3347- nullptr , nullptr ));
3374+ nullptr , parent ));
33483375 parent_scope->add_symbol (x.m_name , class_sym);
33493376 visit_ClassMembers (x, member_names, member_fn_names,
33503377 struct_dependencies, member_init, false , class_abi, true );
@@ -3387,7 +3414,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33873414 current_scope = parent_scope;
33883415 }
33893416
3390- virtual void visit_init_body (const AST::FunctionDef_t &/* x*/ ) = 0;
3417+ virtual void visit_init_body (const AST::FunctionDef_t &/* x*/ , ASR:: symbol_t * /* parent_sym */ , const Location /* loc */ ) = 0;
33913418
33923419 void add_name (const Location &loc) {
33933420 std::string var_name = " __name__" ;
@@ -4421,7 +4448,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
44214448 // Implement visit_Global for Symbol Table visitor.
44224449 void visit_Global (const AST::Global_t &/* x*/ ) {}
44234450
4424- void visit_init_body (const AST::FunctionDef_t &/* x*/ ) {
4451+ void visit_init_body (const AST::FunctionDef_t &/* x*/ , ASR:: symbol_t * /* parent_sym */ , const Location /* loc */ ) {
44254452 // Implemented in BodyVisitor
44264453 }
44274454
@@ -5153,7 +5180,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
51535180 tmp = asr;
51545181 }
51555182
5156- void visit_init_body (const AST::FunctionDef_t &x) {
5183+ void visit_init_body (const AST::FunctionDef_t &x, ASR:: symbol_t * parent_sym, const Location loc ) {
51575184 SymbolTable *old_scope = current_scope;
51585185 ASR::symbol_t *t = current_scope->get_symbol (" __init__" );
51595186 if ( t==nullptr ) {
@@ -5163,31 +5190,77 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
51635190 throw SemanticError (" __init__ is not a function" , x.base .base .loc );
51645191 }
51655192 ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
5193+ current_scope = f->m_symtab ;
51665194 // Transform statements into correct format
5167- Vec<AST::stmt_t *> new_body;
5168- new_body.reserve (al, 1 );
5195+ Vec<AST::stmt_t *> body;
5196+ body.reserve (al, 1 );
5197+ ASR::stmt_t * super_call_stmt = nullptr ;
51695198 for (size_t i=0 ; i<x.n_body ; i++) {
5170- AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body [i]);
5171- if ( ann_assign.m_value != nullptr ) {
5172- Vec<AST::expr_t *>target;
5173- target.reserve (al, 1 );
5174- target.push_back (al, ann_assign.m_target );
5175- AST::ast_t * assgn_ast = AST::make_Assign_t (al, ann_assign.base .base .loc ,
5176- target.p , 1 , ann_assign.m_value , nullptr );
5177- AST::stmt_t * assgn = AST::down_cast<AST::stmt_t >(assgn_ast);
5178- new_body.push_back (al, assgn);
5199+ if (AST::is_a<AST::AnnAssign_t>(*x.m_body [i])) {
5200+ AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body [i]);
5201+ if ( ann_assign.m_value != nullptr ) {
5202+ Vec<AST::expr_t *>target;
5203+ target.reserve (al, 1 );
5204+ target.push_back (al, ann_assign.m_target );
5205+ AST::ast_t * assgn_ast = AST::make_Assign_t (al, ann_assign.base .base .loc ,
5206+ target.p , 1 , ann_assign.m_value , nullptr );
5207+ AST::stmt_t * assgn = AST::down_cast<AST::stmt_t >(assgn_ast);
5208+ body.push_back (al, assgn);
5209+ }
5210+ } else if (AST::is_a<AST::Expr_t>(*x.m_body [i]) &&
5211+ AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Expr_t>(x.m_body [i])->m_value ))) {
5212+ AST::Call_t* c = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Expr_t>(x.m_body [i])->m_value );
5213+
5214+ if ( !AST::is_a<AST::Attribute_t>(*(c->m_func ))
5215+ || !AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Attribute_t>(c->m_func )->m_value )) ) {
5216+ body.push_back (al, x.m_body [i]);
5217+ continue ;
5218+ }
5219+ AST::Call_t* super_call = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Attribute_t>(c->m_func )->m_value );
5220+ std::string attr = AST::down_cast<AST::Attribute_t>(c->m_func )->m_attr ;
5221+ if ( AST::is_a<AST::Name_t>(*(super_call->m_func )) &&
5222+ std::string (AST::down_cast<AST::Name_t>(super_call->m_func )->m_id )==" super" &&
5223+ attr == " __init__" ) {
5224+ if (parent_sym == nullptr ) {
5225+ throw SemanticError (" The class doesn't have a base class" ,loc);
5226+ }
5227+ Vec<ASR::call_arg_t > args;
5228+ args.reserve (al, 1 );
5229+ parse_args (*super_call,args);
5230+ ASR::call_arg_t first_arg;
5231+ first_arg.loc = loc;
5232+ ASR::symbol_t * self_sym = current_scope->get_symbol (" self" );
5233+ first_arg.m_value = ASRUtils::EXPR (ASR::make_Var_t (al,loc,self_sym));
5234+ ASR::ttype_t * target_type = ASRUtils::TYPE (ASRUtils::make_StructType_t_util (al,loc,parent_sym));
5235+ cast_helper (target_type, first_arg.m_value , x.base .base .loc , true );
5236+ Vec<ASR::call_arg_t > args_w_first; args_w_first.reserve (al,1 );
5237+ args_w_first.push_back (al, first_arg);
5238+ for ( size_t i = 0 ; i < args.size (); i++ ) {
5239+ args_w_first.push_back (al,args[i]);
5240+ }
5241+ std::string call_name = " __init__" ;
5242+ ASR::symbol_t * call_sym = get_struct_member (parent_sym,call_name,loc);
5243+ super_call_stmt = ASRUtils::STMT (
5244+ ASR::make_SubroutineCall_t (al, loc, call_sym, call_sym, args_w_first.p ,
5245+ args_w_first.size (), nullptr ));
5246+ }
5247+ } else {
5248+ body.push_back (al, x.m_body [i]);
51795249 }
51805250 }
51815251 current_scope = f->m_symtab ;
5182- Vec<ASR::stmt_t *> body;
5183- body.reserve (al, x.n_body );
5252+ Vec<ASR::stmt_t *> body_asr;
5253+ body_asr.reserve (al, x.n_body );
5254+ if ( super_call_stmt ) {
5255+ body_asr.push_back (al, super_call_stmt);
5256+ }
51845257 Vec<ASR::symbol_t *> rts;
51855258 rts.reserve (al, 4 );
51865259 dependencies.clear (al);
5187- transform_stmts (body, new_body .n , new_body .p );
5260+ transform_stmts (body_asr, body .n , body .p );
51885261 for (const auto &rt: rt_vec) { rts.push_back (al, rt); }
5189- f->m_body = body .p ;
5190- f->n_body = body .size ();
5262+ f->m_body = body_asr .p ;
5263+ f->n_body = body_asr .size ();
51915264 ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
51925265 f->m_function_signature );
51935266 func_type->m_restrictions = rts.p ;
@@ -6239,10 +6312,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
62396312 for ( size_t i = 0 ; i < der_type->n_members && !member_found; i++ ) {
62406313 member_found = std::string (der_type->m_members [i]) == member_name;
62416314 }
6242- if ( !member_found ) {
6315+ if ( !member_found && !der_type-> m_parent ) {
62436316 throw SemanticError (" No member " + member_name +
62446317 " found in " + std::string (der_type->m_name ),
62456318 loc);
6319+ } else if ( !member_found && der_type->m_parent ) {
6320+ ASR::ttype_t * parent_type = ASRUtils::TYPE (ASRUtils::make_StructType_t_util (al, loc,der_type->m_parent ));
6321+ visit_AttributeUtil (parent_type,attr_char,t,loc);
6322+ return ;
62466323 }
62476324 ASR::expr_t *val = ASR::down_cast<ASR::expr_t >(ASR::make_Var_t (al, loc, t));
62486325 ASR::symbol_t * member_sym = der_type->m_symtab ->resolve_symbol (member_name);
@@ -8064,7 +8141,8 @@ we will have to use something else.
80648141 // TODO: Correct Class and ClassType
80658142 // call to struct member function
80668143 // modifying args to pass the object as self
8067- ASR::symbol_t * der = ASR::down_cast<ASR::StructType_t>(var->m_type )->m_derived_type ;
8144+ ASR::symbol_t * der_sym = ASR::down_cast<ASR::StructType_t>(var->m_type )->m_derived_type ;
8145+ ASR::Struct_t* der = ASR::down_cast<ASR::Struct_t>(der_sym);
80688146 Vec<ASR::call_arg_t > new_args; new_args.reserve (al, args.n + 1 );
80698147 ASR::call_arg_t self_arg;
80708148 self_arg.loc = args[0 ].loc ;
@@ -8073,7 +8151,20 @@ we will have to use something else.
80738151 for (size_t i=0 ; i<args.n ; i++) {
80748152 new_args.push_back (al, args[i]);
80758153 }
8076- st = get_struct_member (der, call_name, loc);
8154+ if ( der->m_symtab ->get_symbol (call_name) ) {
8155+ st = get_struct_member (der_sym, call_name, loc);
8156+ } else if ( der->m_parent ) {
8157+ ASR::Struct_t* parent = ASR::down_cast<ASR::Struct_t>(der->m_parent );
8158+ if ( !parent->m_symtab ->get_symbol (call_name) ) {
8159+ throw SemanticError (" Method not found in the class " + std::string (der->m_name ) +
8160+ " or it's parents" ,loc);
8161+ } else {
8162+ st = get_struct_member (der->m_parent , call_name, loc);
8163+ }
8164+ } else {
8165+ throw SemanticError (" Method not found in the class " +std::string (der->m_name )+
8166+ " or it's parents" ,loc);
8167+ }
80778168 tmp = make_call_helper (al, st, current_scope, new_args, call_name, loc);
80788169 return ;
80798170 } else {
0 commit comments