Skip to content

Commit 78abfa6

Browse files
Fix: Clean reset and correct alias patch
1 parent b4e6589 commit 78abfa6

File tree

1 file changed

+129
-56
lines changed

1 file changed

+129
-56
lines changed

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 129 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4872,49 +4872,50 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
48724872
}
48734873
}
48744874

4875-
// --- 1. RESTORED & FIXED: visit_ImportFrom ---
48764875
void visit_ImportFrom(const AST::ImportFrom_t &x) {
48774876
if (!x.m_module) {
48784877
throw SemanticError("Not implemented: The import statement must currently specify the module name", x.base.base.loc);
48794878
}
48804879
std::string msym = x.m_module; // Module name
48814880

48824881
// Get the module, for now assuming it is not loaded, so we load it:
4883-
ASR::symbol_t *t = nullptr;
4884-
std::string rl_path = get_runtime_library_dir();
4885-
std::vector<std::string> paths;
4886-
for (auto &path:import_paths) {
4887-
paths.push_back(path);
4888-
}
4889-
paths.push_back(rl_path);
4890-
paths.push_back(parent_dir);
4891-
4892-
// FIX: Initialize bools to false to avoid undefined behavior
4893-
bool lpython = false, enum_py = false, copy = false, sympy = false;
4894-
4895-
set_module_symbol(msym, paths);
4896-
t = (ASR::symbol_t*)(load_module(al, global_scope,
4897-
msym, x.base.base.loc, diag, lm, false, paths, lpython, enum_py, copy, sympy,
4898-
[&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); },
4899-
allow_implicit_casting));
4900-
4901-
if (lpython || enum_py || copy || sympy) {
4902-
tmp = nullptr;
4903-
return;
4904-
}
4882+
ASR::symbol_t *t = nullptr; // current_scope->parent->resolve_symbol(msym);
49054883
if (!t) {
4906-
throw SemanticError("The module '" + msym + "' cannot be loaded",
4907-
x.base.base.loc);
4908-
}
4909-
if( msym == "__init__" ) {
4910-
for( auto item: ASRUtils::symbol_symtab(t)->get_scope() ) {
4911-
if( ASR::is_a<ASR::ExternalSymbol_t>(*item.second) ) {
4912-
current_module_dependencies.push_back(al,
4913-
ASR::down_cast<ASR::ExternalSymbol_t>(item.second)->m_module_name);
4884+
std::string rl_path = get_runtime_library_dir();
4885+
std::vector<std::string> paths;
4886+
for (auto &path:import_paths) {
4887+
paths.push_back(path);
4888+
}
4889+
paths.push_back(rl_path);
4890+
paths.push_back(parent_dir);
4891+
4892+
bool lpython, enum_py, copy, sympy;
4893+
set_module_symbol(msym, paths);
4894+
t = (ASR::symbol_t*)(load_module(al, global_scope,
4895+
msym, x.base.base.loc, diag, lm, false, paths, lpython, enum_py, copy, sympy,
4896+
[&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); },
4897+
allow_implicit_casting));
4898+
if (lpython || enum_py || copy || sympy) {
4899+
// TODO: For now we skip lpython import completely. Later on we should note what symbols
4900+
// got imported from it, and give an error message if an annotation is used without
4901+
// importing it.
4902+
tmp = nullptr;
4903+
return;
4904+
}
4905+
if (!t) {
4906+
throw SemanticError("The module '" + msym + "' cannot be loaded",
4907+
x.base.base.loc);
4908+
}
4909+
if( msym == "__init__" ) {
4910+
for( auto item: ASRUtils::symbol_symtab(t)->get_scope() ) {
4911+
if( ASR::is_a<ASR::ExternalSymbol_t>(*item.second) ) {
4912+
current_module_dependencies.push_back(al,
4913+
ASR::down_cast<ASR::ExternalSymbol_t>(item.second)->m_module_name);
4914+
}
49144915
}
4916+
} else {
4917+
current_module_dependencies.push_back(al, s2c(al, msym));
49154918
}
4916-
} else {
4917-
current_module_dependencies.push_back(al, s2c(al, msym));
49184919
}
49194920

49204921
ASR::Module_t *m = ASR::down_cast<ASR::Module_t>(t);
@@ -4929,29 +4930,26 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
49294930
if (x.m_names[j].m_asname) {
49304931
new_sym_name = ASRUtils::get_mangled_name(m, x.m_names[j].m_asname);
49314932
}
4932-
4933-
// FIX: Renamed 't' to 'imported_sym' to avoid shadowing the outer module pointer 't'
4934-
ASR::symbol_t *imported_sym = import_from_module(al, m, current_scope, msym,
4935-
remote_sym, new_sym_name, x.m_names[i].loc, true);
4936-
4933+
ASR::symbol_t *t = import_from_module(al, m, current_scope, msym,
4934+
remote_sym, new_sym_name, x.m_names[i].loc, true);
49374935
if (current_scope->get_scope().find(new_sym_name) != current_scope->get_scope().end()) {
49384936
ASR::symbol_t *old_sym = current_scope->get_scope().find(new_sym_name)->second;
49394937
diag.add(diag::Diagnostic(
49404938
"The symbol '" + new_sym_name + "' imported from " + std::string(m->m_name) +" will shadow the existing symbol '" + new_sym_name + "'",
49414939
diag::Level::Warning, diag::Stage::Semantic, {
49424940
diag::Label("old symbol", {old_sym->base.loc}),
4943-
diag::Label("new symbol", {imported_sym->base.loc}),
4941+
diag::Label("new symbol", {t->base.loc}),
49444942
})
49454943
);
4946-
current_scope->overwrite_symbol(new_sym_name, imported_sym);
4944+
current_scope->overwrite_symbol(new_sym_name, t);
49474945
} else {
4948-
current_scope->add_symbol(new_sym_name, imported_sym);
4946+
current_scope->add_symbol(new_sym_name, t);
49494947
}
49504948
}
4949+
49514950
tmp = nullptr;
49524951
}
49534952

4954-
// --- 2. FIXED: visit_Import (Safe Bools + No Else) ---
49554953
void visit_Import(const AST::Import_t &x) {
49564954
ASR::symbol_t *t = nullptr;
49574955
std::string rl_path = get_runtime_library_dir();
@@ -4966,7 +4964,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
49664964
std::string mod_sym = x.m_names[i].m_name;
49674965
char* alias = x.m_names[i].m_asname;
49684966

4969-
// FIX: Initialize bools
4967+
// Initialize bools
49704968
bool lpython = false, enum_py = false, copy = false, sympy = false;
49714969

49724970
set_module_symbol(mod_sym, paths);
@@ -4981,7 +4979,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
49814979
continue;
49824980
}
49834981
if (!t) {
4984-
throw SemanticError("The module '" + mod_sym + "' cannot be loaded",
4982+
throw SemanticError("The module " + mod_sym + " cannot be loaded",
49854983
x.base.base.loc);
49864984
}
49874985

@@ -5000,23 +4998,98 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
50004998
if (alias) {
50014999
std::string alias_str = std::string(alias);
50025000
if (current_scope->get_symbol(alias_str) == nullptr) {
5003-
// Note: Using the signature that matches your local libasr version
50045001
ASR::asr_t *ext_sym = ASR::make_ExternalSymbol_t(
5005-
al,
5006-
x.base.base.loc,
5007-
current_scope,
5008-
s2c(al, alias_str),
5009-
t,
5010-
s2c(al, mod_sym),
5011-
nullptr,
5012-
0,
5013-
s2c(al, mod_sym),
5014-
ASR::accessType::Public
5002+
al, x.base.base.loc, current_scope,
5003+
s2c(al, alias_str), t, s2c(al, mod_sym),
5004+
nullptr, 0, s2c(al, mod_sym), ASR::accessType::Public
50155005
);
50165006
current_scope->add_symbol(alias_str, ASR::down_cast<ASR::symbol_t>(ext_sym));
50175007
}
50185008
}
5019-
// NO ELSE BLOCK: Standard imports are handled by load_module logic implicitly.
5009+
}
5010+
}
5011+
5012+
void visit_AugAssign(const AST::AugAssign_t &/*x*/) {
5013+
// We skip this in the SymbolTable visitor, but visit it in the BodyVisitor
5014+
}
5015+
5016+
void visit_AnnAssign(const AST::AnnAssign_t &/*x*/) {
5017+
// We skip this in the SymbolTable visitor, but visit it in the BodyVisitor
5018+
}
5019+
5020+
void visit_Assign(const AST::Assign_t &x) {
5021+
/**
5022+
* Type variables have to be put in the symbol table before checking function definitions.
5023+
* This describes a different treatment for Assign in the form of `T = TypeVar('T', ...)`
5024+
*/
5025+
if (x.n_targets == 1 && AST::is_a<AST::Call_t>(*x.m_value)) {
5026+
AST::Call_t *rh = AST::down_cast<AST::Call_t>(x.m_value);
5027+
if (AST::is_a<AST::Name_t>(*rh->m_func)) {
5028+
AST::Name_t *tv = AST::down_cast<AST::Name_t>(rh->m_func);
5029+
std::string f_name = tv->m_id;
5030+
if (strcmp(s2c(al, f_name), "TypeVar") == 0 &&
5031+
rh->n_args > 0 && AST::is_a<AST::ConstantStr_t>(*rh->m_args[0])) {
5032+
if (AST::is_a<AST::Name_t>(*x.m_targets[0])) {
5033+
std::string tvar_name = AST::down_cast<AST::Name_t>(x.m_targets[0])->m_id;
5034+
// Check if the type variable name is a reserved type keyword
5035+
const char* type_list[15]
5036+
= { "list", "set", "dict", "tuple",
5037+
"i8", "i16", "i32", "i64", "f32",
5038+
"f64", "c32", "c64", "str", "bool", "i1"};
5039+
for (int i = 0; i < 15; i++) {
5040+
if (strcmp(s2c(al, tvar_name), type_list[i]) == 0) {
5041+
throw SemanticError(tvar_name + " is a reserved type, consider a different type variable name",
5042+
x.base.base.loc);
5043+
}
5044+
}
5045+
// Check if the type variable is already defined
5046+
if (current_scope->get_scope().find(tvar_name) !=
5047+
current_scope->get_scope().end()) {
5048+
ASR::symbol_t *orig_decl = current_scope->get_symbol(tvar_name);
5049+
throw SemanticError(diag::Diagnostic(
5050+
"Variable " + tvar_name + " is already declared in the same scope",
5051+
diag::Level::Error, diag::Stage::Semantic, {
5052+
diag::Label("original declaration", {orig_decl->base.loc}, false),
5053+
diag::Label("redeclaration", {x.base.base.loc}),
5054+
}));
5055+
}
5056+
5057+
// Build ttype
5058+
Vec<ASR::dimension_t> dims;
5059+
dims.reserve(al, 4);
5060+
5061+
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_TypeParameter_t(al, x.base.base.loc,
5062+
s2c(al, tvar_name)));
5063+
type = ASRUtils::make_Array_t_util(al, x.base.base.loc, type, dims.p, dims.size());
5064+
5065+
ASR::expr_t *value = nullptr;
5066+
ASR::expr_t *init_expr = nullptr;
5067+
ASR::intentType s_intent = ASRUtils::intent_local;
5068+
ASR::storage_typeType storage_type = ASR::storage_typeType::Default;
5069+
ASR::abiType current_procedure_abi_type = ASR::abiType::Source;
5070+
ASR::accessType s_access = ASR::accessType::Public;
5071+
ASR::presenceType s_presence = ASR::presenceType::Required;
5072+
bool value_attr = false;
5073+
5074+
SetChar variable_dependencies_vec;
5075+
variable_dependencies_vec.reserve(al, 1);
5076+
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, type, init_expr, value);
5077+
// Build the variable and add it to the scope
5078+
ASR::asr_t *v = ASR::make_Variable_t(al, x.base.base.loc, current_scope,
5079+
s2c(al, tvar_name), variable_dependencies_vec.p, variable_dependencies_vec.size(),
5080+
s_intent, init_expr, value, storage_type, type, nullptr, current_procedure_abi_type,
5081+
s_access, s_presence, value_attr, false, false, nullptr, false, false);
5082+
current_scope->add_symbol(tvar_name, ASR::down_cast<ASR::symbol_t>(v));
5083+
5084+
tmp = nullptr;
5085+
5086+
return;
5087+
} else {
5088+
// This error might need to be further elaborated
5089+
throw SemanticError("Type variable must be a variable", x.base.base.loc);
5090+
}
5091+
}
5092+
}
50205093
}
50215094
}
50225095

0 commit comments

Comments
 (0)