From a0d140cb820ddc9f1fcfcd982b2411e3cf305e89 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 20 Nov 2024 22:55:22 -0800 Subject: [PATCH 01/14] Add initial translator code --- enderpy/Cargo.toml | 1 + enderpy/src/cli.rs | 2 + enderpy/src/main.rs | 30 +++ translator/Cargo.toml | 9 + translator/src/lib.rs | 1 + translator/src/translator.rs | 333 +++++++++++++++++++++++++++++++++ typechecker/src/ast_visitor.rs | 4 +- typechecker/src/lib.rs | 6 +- 8 files changed, 381 insertions(+), 5 deletions(-) create mode 100644 translator/Cargo.toml create mode 100644 translator/src/lib.rs create mode 100644 translator/src/translator.rs diff --git a/enderpy/Cargo.toml b/enderpy/Cargo.toml index a4090018..effdc67a 100644 --- a/enderpy/Cargo.toml +++ b/enderpy/Cargo.toml @@ -10,5 +10,6 @@ edition = "2021" [dependencies] enderpy_python_parser = { path = "../parser" , version = "0.1.0" } enderpy_python_type_checker = { path = "../typechecker" , version = "0.1.0" } +corepy_python_translator = { path = "../translator", version = "0.1.0" } clap = { version = "4.5.17", features = ["derive"] } miette.workspace = true diff --git a/enderpy/src/cli.rs b/enderpy/src/cli.rs index 02267048..55aa53be 100644 --- a/enderpy/src/cli.rs +++ b/enderpy/src/cli.rs @@ -23,6 +23,8 @@ pub enum Commands { }, /// Type check Check { path: PathBuf }, + /// Translate to C++ + Translate { path: PathBuf }, /// Symbol table Symbols { path: PathBuf }, diff --git a/enderpy/src/main.rs b/enderpy/src/main.rs index 11fe7b3c..f2d9c59a 100644 --- a/enderpy/src/main.rs +++ b/enderpy/src/main.rs @@ -2,12 +2,14 @@ use std::{ fs::{self, File}, io::{self, Read}, path::{Path, PathBuf}, + sync::Arc, }; use clap::Parser as ClapParser; use cli::{Cli, Commands}; use enderpy_python_parser::{get_row_col_position, parser::parser::Parser, Lexer}; use enderpy_python_type_checker::{build::BuildManager, find_project_root, settings::Settings}; +use corepy_python_translator::translator::CppTranslator; use miette::{bail, IntoDiagnostic, Result}; mod cli; @@ -18,6 +20,7 @@ fn main() -> Result<()> { Commands::Tokenize {} => tokenize(), Commands::Parse { file } => parse(file), Commands::Check { path } => check(path), + Commands::Translate { path } => translate(path), Commands::Watch => watch(), Commands::Symbols { path } => symbols(path), } @@ -134,6 +137,33 @@ fn check(path: &Path) -> Result<()> { Ok(()) } +fn translate(path: &Path) -> Result<()> { + if path.is_dir() { + bail!("Path must be a file"); + } + let root = find_project_root(path); + let python_executable = Some(get_python_executable()?); + let typeshed_path = get_typeshed_path()?; + let settings = Settings { + typeshed_path, + python_executable, + }; + let build_manager = BuildManager::new(settings); + build_manager.build(root); + build_manager.build_one(root, path); + let id = build_manager.paths.get(path).unwrap(); + let file = build_manager.files.get(&id).unwrap(); + let checker = Arc::new(build_manager.type_check(path, &file)); + let mut translator = CppTranslator::new(checker.clone(), &file); + translator.translate(); + println!("{:?}", file.tree); + println!("===="); + println!("{}", translator.output); + println!("===="); + print!("{}", checker.clone().dump_types()); + Ok(()) +} + fn watch() -> Result<()> { todo!() } diff --git a/translator/Cargo.toml b/translator/Cargo.toml new file mode 100644 index 00000000..920036c6 --- /dev/null +++ b/translator/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "corepy_python_translator" +version = "0.1.0" +edition = "2021" + +[dependencies] +enderpy_python_parser = { path = "../parser", version = "0.1.0" } +enderpy_python_type_checker = {path = "../typechecker", version = "0.1.0" } +log = { version = "0.4.17" } \ No newline at end of file diff --git a/translator/src/lib.rs b/translator/src/lib.rs new file mode 100644 index 00000000..a5aac9df --- /dev/null +++ b/translator/src/lib.rs @@ -0,0 +1 @@ +pub mod translator; \ No newline at end of file diff --git a/translator/src/translator.rs b/translator/src/translator.rs new file mode 100644 index 00000000..27b9958d --- /dev/null +++ b/translator/src/translator.rs @@ -0,0 +1,333 @@ +use std::sync::Arc; +use enderpy_python_parser::ast::{self, *}; +use std::error::Error; +use std::fmt::Write; +use log::warn; + +use enderpy_python_type_checker::{ast_visitor::TraversalVisitor, file::EnderpyFile, checker::TypeChecker, types::PythonType}; + +#[derive(Clone, Debug)] +pub struct CppTranslator<'a> { + pub output: String, + indent_level: usize, + checker: Arc>, + file: &'a EnderpyFile, +} + +impl<'a> CppTranslator<'a> { + pub fn new(checker: Arc>, file: &'a EnderpyFile) -> Self { + CppTranslator { + output: "".to_string(), + indent_level: 0, + checker: checker, + file: file, + } + } + + pub fn translate(&mut self) { + for stmt in self.file.tree.body.iter() { + self.visit_stmt(stmt); + } + } + + pub fn write_indent(&mut self) { + write!(self.output, "{}", " ".repeat(self.indent_level)); + } +} + +impl<'a> TraversalVisitor for CppTranslator<'a> { + fn visit_stmt(&mut self, s: &ast::Statement) { + self.write_indent(); + match s { + Statement::ExpressionStatement(e) => self.visit_expr(e), + Statement::Import(i) => self.visit_import(i), + Statement::ImportFrom(i) => self.visit_import_from(i), + Statement::AssignStatement(a) => self.visit_assign(a), + Statement::AnnAssignStatement(a) => self.visit_ann_assign(a), + Statement::AugAssignStatement(a) => self.visit_aug_assign(a), + Statement::Assert(a) => self.visit_assert(a), + Statement::Pass(p) => self.visit_pass(p), + Statement::Delete(d) => self.visit_delete(d), + Statement::ReturnStmt(r) => self.visit_return(r), + Statement::Raise(r) => self.visit_raise(r), + Statement::BreakStmt(b) => self.visit_break(b), + Statement::ContinueStmt(c) => self.visit_continue(c), + Statement::Global(g) => self.visit_global(g), + Statement::Nonlocal(n) => self.visit_nonlocal(n), + Statement::IfStatement(i) => self.visit_if(i), + Statement::WhileStatement(w) => self.visit_while(w), + Statement::ForStatement(f) => self.visit_for(f), + Statement::WithStatement(w) => self.visit_with(w), + Statement::TryStatement(t) => self.visit_try(t), + Statement::TryStarStatement(t) => self.visit_try_star(t), + Statement::FunctionDef(f) => self.visit_function_def(f), + Statement::ClassDef(c) => self.visit_class_def(c), + Statement::MatchStmt(m) => self.visit_match(m), + Statement::AsyncForStatement(f) => self.visit_async_for(f), + Statement::AsyncWithStatement(w) => self.visit_async_with(w), + Statement::AsyncFunctionDef(f) => self.visit_async_function_def(f), + Statement::TypeAlias(a) => self.visit_type_alias(a), + } + } + + fn visit_expr(&mut self, e: &Expression) { + match e { + Expression::Constant(c) => self.visit_constant(c), + Expression::List(l) => self.visit_list(l), + Expression::Tuple(t) => self.visit_tuple(t), + Expression::Dict(d) => self.visit_dict(d), + Expression::Set(s) => self.visit_set(s), + Expression::Name(n) => self.visit_name(n), + Expression::BoolOp(b) => self.visit_bool_op(b), + Expression::UnaryOp(u) => self.visit_unary_op(u), + Expression::BinOp(b) => self.visit_bin_op(b), + Expression::NamedExpr(n) => self.visit_named_expr(n), + Expression::Yield(y) => self.visit_yield(y), + Expression::YieldFrom(y) => self.visit_yield_from(y), + Expression::Starred(s) => self.visit_starred(s), + Expression::Generator(g) => self.visit_generator(g), + Expression::ListComp(l) => self.visit_list_comp(l), + Expression::SetComp(s) => self.visit_set_comp(s), + Expression::DictComp(d) => self.visit_dict_comp(d), + Expression::Attribute(a) => self.visit_attribute(a), + Expression::Subscript(s) => self.visit_subscript(s), + Expression::Slice(s) => self.visit_slice(s), + Expression::Call(c) => self.visit_call(c), + Expression::Await(a) => self.visit_await(a), + Expression::Compare(c) => self.visit_compare(c), + Expression::Lambda(l) => self.visit_lambda(l), + Expression::IfExp(i) => self.visit_if_exp(i), + Expression::JoinedStr(j) => self.visit_joined_str(j), + Expression::FormattedValue(f) => self.visit_formatted_value(f), + } + } + + fn visit_constant(&mut self, constant: &Constant) { + match constant.value { + ConstantValue::None => write!(self.output, "None"), + ConstantValue::Ellipsis => write!(self.output, "..."), + ConstantValue::Bool(_) => write!(self.output, "bool"), + ConstantValue::Str(_) => write!(self.output, "{}", constant.get_value(&self.file.source).to_string()), + ConstantValue::Bytes => write!(self.output, "bytes"), + ConstantValue::Tuple => write!(self.output, "tuple"), + ConstantValue::Int => write!(self.output, "int"), + ConstantValue::Float => write!(self.output, "float"), + ConstantValue::Complex => write!(self.output, "complex"), + /* + Constant::Tuple(elements) => { + let tuple_elements: Vec = elements + .iter() + .map(|elem| self.translate_constant(elem)) + .collect::, _>>()?; + Ok(format!("({})", tuple_elements.join(", "))) + }, + */ + }; + } + + fn visit_import(&mut self, import: &Import) { + for name in import.names.iter() { + if name.name == "torch" { + writeln!(self.output, "#include "); + } + } + } + + fn visit_assign(&mut self, a: &Assign) { + for target in &a.targets { + // let type = self.checker.types. + match target { + Expression::Name(n) => { + println!("XXX {}", n.node.start); + // This loop should only iterate once + for t in self.checker.types.find(n.node.start, n.node.end) { + write!(self.output, "{} ", python_type_to_cpp(&t.val)); + } + self.visit_name(n); + }, + _ => { + self.visit_expr(target); + } + } + } + write!(self.output, " = "); + self.visit_expr(&a.value); + } + + fn visit_name(&mut self, name: &Name) { + write!(self.output, "{}", name.id); + } + + fn visit_call(&mut self, c: &Call) { + self.visit_expr(&c.func); + write!(self.output, "("); + for arg in &c.args { + self.visit_expr(arg); + } + for keyword in &c.keywords { + self.visit_expr(&keyword.value); + } + write!(self.output, ")"); + } + + fn visit_attribute(&mut self, attribute: &Attribute) { + self.visit_expr(&attribute.value); + write!(self.output, "::{}", attribute.attr); + } +} + +fn python_type_to_cpp(python_type: &PythonType) -> String { + match python_type { + PythonType::Class(c) => { + c.details.name.clone() + }, + _ => String::from("") + } +} + +/* +impl CppTranslator { + fn new() -> Self { + CppTranslator::default() + } + + fn translate_ast(&mut self, ast: &ast::Mod) -> Result> { + match ast { + ast::Mod::Module(ast::ModModule { body, .. }) => { + for stmt in body.iter() { + self.translate_stmt(stmt)?; + } + }, + ast::Mod::Interactive(_) => { + }, + ast::Mod::FunctionType(_) => { + }, + ast::Mod::Expression(_) => { + }, + } + + Ok(self.output.clone()) + } + + fn translate_constant(&mut self, constant: &ast::Constant) -> Result> { + match constant { + ast::Constant::Int(n) => Ok(n.to_string()), + ast::Constant::Float(f) => Ok(f.to_string()), + ast::Constant::Complex { real, imag } => Ok(format!("{}+{}j", real, imag)), + ast::Constant::Str(s) => Ok(format!("\"{}\"", s)), + ast::Constant::Bool(b) => Ok(b.to_string()), + ast::Constant::None => Ok("None".to_string()), + ast::Constant::Tuple(elements) => { + let tuple_elements: Vec = elements + .iter() + .map(|elem| self.translate_constant(elem)) + .collect::, _>>()?; + Ok(format!("({})", tuple_elements.join(", "))) + }, + ast::Constant::Ellipsis => Ok("...".to_string()), + ast::Constant::Bytes(bytes) => Ok(format!("b\"{}\"", String::from_utf8_lossy(bytes))), + } + } + + fn translate_expr(&mut self, expr: &ast::Expr) -> Result> { + match expr { + ast::Expr::BoolOp(_) => {}, + ast::Expr::NamedExpr(_) => {}, + ast::Expr::BinOp(_) => {}, + ast::Expr::UnaryOp(_) => {}, + ast::Expr::Lambda(_) => {}, + ast::Expr::IfExp(_) => {}, + ast::Expr::Dict(_) => {}, + ast::Expr::Set(_) => {}, + ast::Expr::ListComp(_) => {}, + ast::Expr::SetComp(_) => {}, + ast::Expr::DictComp(_) => {}, + ast::Expr::GeneratorExp(_) => {}, + ast::Expr::Await(_) => {}, + ast::Expr::Yield(_) => {}, + ast::Expr::YieldFrom(_) => {}, + ast::Expr::Compare(_) => {}, + ast::Expr::Call(_) => {}, + ast::Expr::FormattedValue(_) => {}, + ast::Expr::JoinedStr(_) => {}, + ast::Expr::Constant(c) => { + let s = self.translate_constant(&c.value)?; + write!(self.output, "{}", s); + }, + ast::Expr::Attribute(_) => {}, + ast::Expr::Subscript(_) => {}, + ast::Expr::Starred(_) => {}, + ast::Expr::Name(name) => { + write!(self.output, "{}", name.id); + }, + ast::Expr::List(_) => {}, + ast::Expr::Tuple(_) => {}, + ast::Expr::Slice(_) => {}, + } + Ok(self.output.clone()) + } + + fn translate_stmt(&mut self, stmt: &ast::Stmt) -> Result> { + match stmt { + ast::Stmt::Assign(assign) => { + self.write_indent()?; + self.translate_expr(&assign.targets[0])?; + write!(self.output, " = "); + self.translate_expr(&assign.value)?; + write!(self.output, "\n"); + }, + ast::Stmt::FunctionDef(function) => { + writeln!(self.output, "void {}() {{", function.name.to_string()); + self.indent_level += 1; + for stmt in function.body.iter() { + self.translate_stmt(stmt)?; + } + self.indent_level -= 1; + writeln!(self.output, "}}"); + }, + ast::Stmt::AsyncFunctionDef(_) => {}, + ast::Stmt::ClassDef(_) => {}, + ast::Stmt::Return(_) => {}, + ast::Stmt::Delete(_) => {}, + ast::Stmt::TypeAlias(_) => {}, + ast::Stmt::AugAssign(_) => {}, + ast::Stmt::AnnAssign(_) => {}, + ast::Stmt::For(_) => {}, + ast::Stmt::AsyncFor(_) => {}, + ast::Stmt::While(_) => {}, + ast::Stmt::If(_) => {}, + ast::Stmt::With(_) => {}, + ast::Stmt::AsyncWith(_) => {}, + ast::Stmt::Match(_) => {}, + ast::Stmt::Raise(_) => {}, + ast::Stmt::Try(_) => {}, + ast::Stmt::TryStar(_) => {}, + ast::Stmt::Assert(_) => {}, + ast::Stmt::Import(imp) => { + if imp.names[0].name.to_string() == "torch" { + writeln!(self.output, "#include "); + } + }, + ast::Stmt::ImportFrom(_) => {}, + ast::Stmt::Global(_) => {}, + ast::Stmt::Nonlocal(_) => {}, + ast::Stmt::Expr(_) => {}, + ast::Stmt::Pass(_) => {}, + ast::Stmt::Break(_) => {}, + ast::Stmt::Continue(_) => {}, + } + Ok(self.output.clone()) + } + + fn write_indent(&mut self) -> Result<(), Box> { + write!(self.output, "{}", " ".repeat(self.indent_level))?; + Ok(()) + } +} + +pub fn python_to_cpp(python_ast: &ast::Mod) -> Result> { + let mut translator = CppTranslator::new(); + translator.translate_ast(python_ast) +} + +*/ diff --git a/typechecker/src/ast_visitor.rs b/typechecker/src/ast_visitor.rs index e35258ff..f7101fc2 100644 --- a/typechecker/src/ast_visitor.rs +++ b/typechecker/src/ast_visitor.rs @@ -323,7 +323,7 @@ pub trait TraversalVisitor { } fn visit_assign(&mut self, _a: &Assign) { - todo!() + // todo!() } fn visit_ann_assign(&mut self, _a: &AnnAssign) { @@ -339,7 +339,7 @@ pub trait TraversalVisitor { } fn visit_pass(&mut self, _p: &Pass) { - todo!() + // todo!() } fn visit_delete(&mut self, _d: &Delete) { diff --git a/typechecker/src/lib.rs b/typechecker/src/lib.rs index c6396fef..ae78a3d3 100644 --- a/typechecker/src/lib.rs +++ b/typechecker/src/lib.rs @@ -1,7 +1,7 @@ use std::path::Path; -mod ast_visitor; -mod file; +pub mod ast_visitor; +pub mod file; mod ruff_python_import_resolver; mod symbol_table; @@ -11,7 +11,7 @@ pub mod diagnostic; pub mod semantic_analyzer; pub mod settings; pub mod type_evaluator; -mod types; +pub mod types; pub(crate) mod builtins { pub const LIST_TYPE: &str = "list"; From a49df8cc75c7075d31c5d8fee10b0dabd33e07c2 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 22 Nov 2024 00:15:48 -0800 Subject: [PATCH 02/14] implement argument type checking --- parser/src/ast.rs | 5 +- translator/src/translator.rs | 226 +++++++++++------------------------ typechecker/src/checker.rs | 9 ++ 3 files changed, 82 insertions(+), 158 deletions(-) diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 9feb1129..848d19bb 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -425,7 +425,10 @@ impl Constant { } else { Cow::Borrowed("false") } - } + }, + ConstantValue::Int => Cow::Borrowed( + &source[self.node.start as usize..self.node.end as usize], + ), _ => todo!("Call the parser and get the value"), } } diff --git a/translator/src/translator.rs b/translator/src/translator.rs index 27b9958d..1d63e94e 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -1,10 +1,11 @@ use std::sync::Arc; use enderpy_python_parser::ast::{self, *}; +use enderpy_python_parser::parser::parser::intern_lookup; use std::error::Error; use std::fmt::Write; use log::warn; -use enderpy_python_type_checker::{ast_visitor::TraversalVisitor, file::EnderpyFile, checker::TypeChecker, types::PythonType}; +use enderpy_python_type_checker::{types, ast_visitor::TraversalVisitor, file::EnderpyFile, checker::TypeChecker, types::PythonType}; #[derive(Clone, Debug)] pub struct CppTranslator<'a> { @@ -33,6 +34,15 @@ impl<'a> CppTranslator<'a> { pub fn write_indent(&mut self) { write!(self.output, "{}", " ".repeat(self.indent_level)); } + + fn check_type(&self, node: &Node, typ: &PythonType) { + assert!( + self.checker.get_type(node) == *typ, + "type error at {}, expected {} got {}", + self.file.get_position(node.start, node.end), + typ, self.checker.get_type(node) + ); + } } impl<'a> TraversalVisitor for CppTranslator<'a> { @@ -42,7 +52,10 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { Statement::ExpressionStatement(e) => self.visit_expr(e), Statement::Import(i) => self.visit_import(i), Statement::ImportFrom(i) => self.visit_import_from(i), - Statement::AssignStatement(a) => self.visit_assign(a), + Statement::AssignStatement(a) => { + self.visit_assign(a); + writeln!(self.output, ";"); + }, Statement::AnnAssignStatement(a) => self.visit_ann_assign(a), Statement::AugAssignStatement(a) => self.visit_aug_assign(a), Statement::Assert(a) => self.visit_assert(a), @@ -107,10 +120,10 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { ConstantValue::None => write!(self.output, "None"), ConstantValue::Ellipsis => write!(self.output, "..."), ConstantValue::Bool(_) => write!(self.output, "bool"), - ConstantValue::Str(_) => write!(self.output, "{}", constant.get_value(&self.file.source).to_string()), + ConstantValue::Str(_) => write!(self.output, "\"{}\"", constant.get_value(&self.file.source).to_string()), ConstantValue::Bytes => write!(self.output, "bytes"), ConstantValue::Tuple => write!(self.output, "tuple"), - ConstantValue::Int => write!(self.output, "int"), + ConstantValue::Int => write!(self.output, "{}", constant.get_value(&self.file.source).to_string()), ConstantValue::Float => write!(self.output, "float"), ConstantValue::Complex => write!(self.output, "complex"), /* @@ -159,14 +172,43 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } fn visit_call(&mut self, c: &Call) { + let typ = self.checker.get_type(&c.func.get_node()); self.visit_expr(&c.func); write!(self.output, "("); - for arg in &c.args { - self.visit_expr(arg); - } - for keyword in &c.keywords { - self.visit_expr(&keyword.value); + match typ { + PythonType::Callable(callable) => { + let mut num_pos_args = 0; + // First check all the positional args + for (i, arg) in callable.signature.iter().enumerate() { + match arg { + types::CallableArgs::Args(t) => { + self.check_type(&c.args[i].get_node(), t); + num_pos_args = i; + }, + types::CallableArgs::Positional(t) => { + break; + }, + _ => {} + } + } + // Then check all the star args if there are any + if num_pos_args < c.args.len() { + write!(self.output, "{{"); + for (i, arg) in c.args[num_pos_args..].iter().enumerate() { + self.check_type(&arg.get_node(), callable.signature[num_pos_args].get_type()); + if i != 0 { + write!(self.output, ", "); + } + self.visit_expr(arg); + } + write!(self.output, "}}"); + } + }, + _ => {} } + // for keyword in &c.keywords { + // self.visit_expr(&keyword.value); + // } write!(self.output, ")"); } @@ -174,6 +216,20 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { self.visit_expr(&attribute.value); write!(self.output, "::{}", attribute.attr); } + + fn visit_function_def(&mut self, f: &Arc) { + write!(self.output, "void {}(", intern_lookup(f.name)); + for arg in f.args.args.iter() { + write!(self.output, "{} {}", python_type_to_cpp(&self.checker.get_type(&arg.node)), arg.arg); + } + writeln!(self.output, ") {{"); + self.indent_level += 1; + for stmt in &f.body { + self.visit_stmt(stmt); + } + self.indent_level -= 1; + writeln!(self.output, "}}"); + } } fn python_type_to_cpp(python_type: &PythonType) -> String { @@ -181,153 +237,9 @@ fn python_type_to_cpp(python_type: &PythonType) -> String { PythonType::Class(c) => { c.details.name.clone() }, - _ => String::from("") - } -} - -/* -impl CppTranslator { - fn new() -> Self { - CppTranslator::default() - } - - fn translate_ast(&mut self, ast: &ast::Mod) -> Result> { - match ast { - ast::Mod::Module(ast::ModModule { body, .. }) => { - for stmt in body.iter() { - self.translate_stmt(stmt)?; - } - }, - ast::Mod::Interactive(_) => { - }, - ast::Mod::FunctionType(_) => { - }, - ast::Mod::Expression(_) => { - }, - } - - Ok(self.output.clone()) - } - - fn translate_constant(&mut self, constant: &ast::Constant) -> Result> { - match constant { - ast::Constant::Int(n) => Ok(n.to_string()), - ast::Constant::Float(f) => Ok(f.to_string()), - ast::Constant::Complex { real, imag } => Ok(format!("{}+{}j", real, imag)), - ast::Constant::Str(s) => Ok(format!("\"{}\"", s)), - ast::Constant::Bool(b) => Ok(b.to_string()), - ast::Constant::None => Ok("None".to_string()), - ast::Constant::Tuple(elements) => { - let tuple_elements: Vec = elements - .iter() - .map(|elem| self.translate_constant(elem)) - .collect::, _>>()?; - Ok(format!("({})", tuple_elements.join(", "))) - }, - ast::Constant::Ellipsis => Ok("...".to_string()), - ast::Constant::Bytes(bytes) => Ok(format!("b\"{}\"", String::from_utf8_lossy(bytes))), - } - } - - fn translate_expr(&mut self, expr: &ast::Expr) -> Result> { - match expr { - ast::Expr::BoolOp(_) => {}, - ast::Expr::NamedExpr(_) => {}, - ast::Expr::BinOp(_) => {}, - ast::Expr::UnaryOp(_) => {}, - ast::Expr::Lambda(_) => {}, - ast::Expr::IfExp(_) => {}, - ast::Expr::Dict(_) => {}, - ast::Expr::Set(_) => {}, - ast::Expr::ListComp(_) => {}, - ast::Expr::SetComp(_) => {}, - ast::Expr::DictComp(_) => {}, - ast::Expr::GeneratorExp(_) => {}, - ast::Expr::Await(_) => {}, - ast::Expr::Yield(_) => {}, - ast::Expr::YieldFrom(_) => {}, - ast::Expr::Compare(_) => {}, - ast::Expr::Call(_) => {}, - ast::Expr::FormattedValue(_) => {}, - ast::Expr::JoinedStr(_) => {}, - ast::Expr::Constant(c) => { - let s = self.translate_constant(&c.value)?; - write!(self.output, "{}", s); - }, - ast::Expr::Attribute(_) => {}, - ast::Expr::Subscript(_) => {}, - ast::Expr::Starred(_) => {}, - ast::Expr::Name(name) => { - write!(self.output, "{}", name.id); - }, - ast::Expr::List(_) => {}, - ast::Expr::Tuple(_) => {}, - ast::Expr::Slice(_) => {}, - } - Ok(self.output.clone()) - } - - fn translate_stmt(&mut self, stmt: &ast::Stmt) -> Result> { - match stmt { - ast::Stmt::Assign(assign) => { - self.write_indent()?; - self.translate_expr(&assign.targets[0])?; - write!(self.output, " = "); - self.translate_expr(&assign.value)?; - write!(self.output, "\n"); - }, - ast::Stmt::FunctionDef(function) => { - writeln!(self.output, "void {}() {{", function.name.to_string()); - self.indent_level += 1; - for stmt in function.body.iter() { - self.translate_stmt(stmt)?; - } - self.indent_level -= 1; - writeln!(self.output, "}}"); - }, - ast::Stmt::AsyncFunctionDef(_) => {}, - ast::Stmt::ClassDef(_) => {}, - ast::Stmt::Return(_) => {}, - ast::Stmt::Delete(_) => {}, - ast::Stmt::TypeAlias(_) => {}, - ast::Stmt::AugAssign(_) => {}, - ast::Stmt::AnnAssign(_) => {}, - ast::Stmt::For(_) => {}, - ast::Stmt::AsyncFor(_) => {}, - ast::Stmt::While(_) => {}, - ast::Stmt::If(_) => {}, - ast::Stmt::With(_) => {}, - ast::Stmt::AsyncWith(_) => {}, - ast::Stmt::Match(_) => {}, - ast::Stmt::Raise(_) => {}, - ast::Stmt::Try(_) => {}, - ast::Stmt::TryStar(_) => {}, - ast::Stmt::Assert(_) => {}, - ast::Stmt::Import(imp) => { - if imp.names[0].name.to_string() == "torch" { - writeln!(self.output, "#include "); - } - }, - ast::Stmt::ImportFrom(_) => {}, - ast::Stmt::Global(_) => {}, - ast::Stmt::Nonlocal(_) => {}, - ast::Stmt::Expr(_) => {}, - ast::Stmt::Pass(_) => {}, - ast::Stmt::Break(_) => {}, - ast::Stmt::Continue(_) => {}, - } - Ok(self.output.clone()) - } - - fn write_indent(&mut self) -> Result<(), Box> { - write!(self.output, "{}", " ".repeat(self.indent_level))?; - Ok(()) + PythonType::Instance(i) => { + i.class_type.details.name.clone() + }, + _ => String::from(format!("", python_type)) } } - -pub fn python_to_cpp(python_ast: &ast::Mod) -> Result> { - let mut translator = CppTranslator::new(); - translator.translate_ast(python_ast) -} - -*/ diff --git a/typechecker/src/checker.rs b/typechecker/src/checker.rs index c58c8a84..41fefa8f 100644 --- a/typechecker/src/checker.rs +++ b/typechecker/src/checker.rs @@ -136,6 +136,15 @@ impl<'a> TypeChecker<'a> { str } + + pub fn get_type(&self, node: &ast::Node) -> PythonType { + for r in self.types.find(node.start, node.end) { + if r.start == node.start && r.stop == node.end { + return r.val.clone(); + } + } + return PythonType::Unknown; + } } #[allow(unused)] impl<'a> TraversalVisitor for TypeChecker<'a> { From e8ae46244922e4bbb410c027884f2c866b04f066 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 26 Nov 2024 03:17:43 -0800 Subject: [PATCH 03/14] beginning of for loops --- translator/src/translator.rs | 44 +++++++++++++++++++++++++++++-- typechecker/src/checker.rs | 6 +++-- typechecker/src/type_evaluator.rs | 34 ++++++++++++++++++++++-- 3 files changed, 78 insertions(+), 6 deletions(-) diff --git a/translator/src/translator.rs b/translator/src/translator.rs index 1d63e94e..61e35df9 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -151,7 +151,6 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { // let type = self.checker.types. match target { Expression::Name(n) => { - println!("XXX {}", n.node.start); // This loop should only iterate once for t in self.checker.types.find(n.node.start, n.node.end) { write!(self.output, "{} ", python_type_to_cpp(&t.val)); @@ -171,6 +170,12 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { write!(self.output, "{}", name.id); } + fn visit_bin_op(&mut self, b: &BinOp) { + self.visit_expr(&b.left); + write!(self.output, " {} ", &b.op); + self.visit_expr(&b.right); + } + fn visit_call(&mut self, c: &Call) { let typ = self.checker.get_type(&c.func.get_node()); self.visit_expr(&c.func); @@ -204,7 +209,9 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { write!(self.output, "}}"); } }, - _ => {} + _ => { + println!("Shouldn't hit this code path"); + } } // for keyword in &c.keywords { // self.visit_expr(&keyword.value); @@ -230,6 +237,39 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { self.indent_level -= 1; writeln!(self.output, "}}"); } + + fn visit_for(&mut self, f: &For) { + let mut bound = None; + match &f.iter { + Expression::Call(c) => { + match &c.func { + Expression::Name(n) => { + if n.id == "range" { + bound = Some(c.args[0].clone()); + } + } + _ => {} + } + }, + _ => {} + } + write!(self.output, "for(int "); + self.visit_expr(&f.target); + write!(self.output, " = 0; "); + self.visit_expr(&f.target); + write!(self.output, " < "); + self.visit_expr(&bound.unwrap()); + write!(self.output, "; ++"); + self.visit_expr(&f.target); + writeln!(self.output, ") {{"); + self.indent_level += 1; + for stmt in &f.body { + self.visit_stmt(stmt); + } + self.indent_level -= 1; + self.write_indent(); + writeln!(self.output, "}}"); + } } fn python_type_to_cpp(python_type: &PythonType) -> String { diff --git a/typechecker/src/checker.rs b/typechecker/src/checker.rs index 41fefa8f..b633d9f4 100644 --- a/typechecker/src/checker.rs +++ b/typechecker/src/checker.rs @@ -465,8 +465,10 @@ impl<'a> TraversalVisitor for TypeChecker<'a> { } fn visit_bin_op(&mut self, b: &BinOp) { - let l_type = self.infer_expr_type(&b.left); - let r_type = self.infer_expr_type(&b.right); + // let l_type = self.infer_expr_type(&b.left); + self.visit_expr(&b.left); + // let r_type = self.infer_expr_type(&b.right); + self.visit_expr(&b.right); } fn visit_named_expr(&mut self, _n: &NamedExpression) { diff --git a/typechecker/src/type_evaluator.rs b/typechecker/src/type_evaluator.rs index 3adda5bf..988eee86 100755 --- a/typechecker/src/type_evaluator.rs +++ b/typechecker/src/type_evaluator.rs @@ -869,9 +869,39 @@ impl<'a> TypeEvaluator<'a> { &iter_method_type.type_parameters, &iter_method_type.specialized, ) - } + }, + PythonType::Class(class_type) => { + let iter_method = match self.lookup_on_class( + &symbol_table, + &class_type, + "__iter__", + ) { + Some(PythonType::Callable(c)) => c, + Some(other) => panic!("iter method was not callable: {}", other), + None => panic!("next method not found"), + }; + let Some(iter_method_type) = &iter_method.return_type.class() + else { + panic!("iter method return type is not class"); + }; + let next_method = match self.lookup_on_class( + &symbol_table, + &iter_method_type, + "__next__", + ) { + Some(PythonType::Callable(c)) => c, + Some(other) => panic!("next method was not callable: {}", other), + None => panic!("next method not found"), + }; + self.resolve_generics( + &next_method.return_type, + &iter_method_type.type_parameters, + &iter_method_type.specialized, + ) + // PythonType::Unknown + }, _ => { - error!("iterating over a {} is not defined", iter_type); + error!("iterating over a {:?} is not defined", iter_type); PythonType::Unknown } } From 07cccb190914e7397261453744fa0f53be561b56 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 26 Nov 2024 07:50:22 -0800 Subject: [PATCH 04/14] add symbol table lookup for assignments --- translator/src/translator.rs | 32 ++++++++++++++++++++++++++++---- typechecker/src/checker.rs | 6 +++++- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/translator/src/translator.rs b/translator/src/translator.rs index 61e35df9..167dbf6b 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -13,6 +13,8 @@ pub struct CppTranslator<'a> { indent_level: usize, checker: Arc>, file: &'a EnderpyFile, + current_scope: u32, + prev_scope: u32, } impl<'a> CppTranslator<'a> { @@ -22,6 +24,8 @@ impl<'a> CppTranslator<'a> { indent_level: 0, checker: checker, file: file, + current_scope: 0, + prev_scope: 0, } } @@ -43,6 +47,16 @@ impl<'a> CppTranslator<'a> { typ, self.checker.get_type(node) ); } + + fn enter_scope(&mut self, pos: u32) { + let symbol_table = self.checker.get_symbol_table(); + self.prev_scope = self.current_scope; + self.current_scope = symbol_table.get_scope(pos); + } + + fn leave_scope(&mut self) { + self.current_scope = self.prev_scope; + } } impl<'a> TraversalVisitor for CppTranslator<'a> { @@ -147,14 +161,22 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } fn visit_assign(&mut self, a: &Assign) { + let symbol_table = self.checker.get_symbol_table(); for target in &a.targets { // let type = self.checker.types. match target { Expression::Name(n) => { - // This loop should only iterate once - for t in self.checker.types.find(n.node.start, n.node.end) { - write!(self.output, "{} ", python_type_to_cpp(&t.val)); - } + let node = symbol_table.lookup_in_scope(&n.id, self.current_scope); + match node { + Some(node) => { + let path = node.declarations[0].declaration_path(); + if path.node == n.node { + let typ = self.checker.get_type(&n.node); + write!(self.output, "{} ", python_type_to_cpp(&typ)); + } + }, + None => {}, + }; self.visit_name(n); }, _ => { @@ -225,6 +247,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } fn visit_function_def(&mut self, f: &Arc) { + self.enter_scope(f.node.start); write!(self.output, "void {}(", intern_lookup(f.name)); for arg in f.args.args.iter() { write!(self.output, "{} {}", python_type_to_cpp(&self.checker.get_type(&arg.node)), arg.arg); @@ -236,6 +259,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } self.indent_level -= 1; writeln!(self.output, "}}"); + self.leave_scope(); } fn visit_for(&mut self, f: &For) { diff --git a/typechecker/src/checker.rs b/typechecker/src/checker.rs index b633d9f4..e26e9616 100644 --- a/typechecker/src/checker.rs +++ b/typechecker/src/checker.rs @@ -7,7 +7,7 @@ use enderpy_python_parser::parser::parser::intern_lookup; use super::{type_evaluator::TypeEvaluator, types::PythonType}; use crate::build::BuildManager; -use crate::symbol_table::Id; +use crate::symbol_table::{Id, SymbolTable}; use crate::types::ModuleRef; use crate::{ast_visitor::TraversalVisitor, diagnostic::CharacterSpan}; use rust_lapper::{Interval, Lapper}; @@ -145,6 +145,10 @@ impl<'a> TypeChecker<'a> { } return PythonType::Unknown; } + + pub fn get_symbol_table(&self) -> Arc { + return self.build_manager.get_symbol_table_by_id(&self.id); + } } #[allow(unused)] impl<'a> TraversalVisitor for TypeChecker<'a> { From 5bb013b3f03f9d032829e666ca18b27efacc0dbe Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 2 Dec 2024 00:02:36 -0800 Subject: [PATCH 05/14] make sure types are looked up in the module symbol table --- parser/src/ast.rs | 3 ++ translator/src/translator.rs | 84 ++++++++++++++++++++++++------- typechecker/src/checker.rs | 5 +- typechecker/src/lib.rs | 15 +++++- typechecker/src/type_evaluator.rs | 70 ++++++++++++++++++++++---- 5 files changed, 145 insertions(+), 32 deletions(-) diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 848d19bb..1ce0ef0a 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -429,6 +429,9 @@ impl Constant { ConstantValue::Int => Cow::Borrowed( &source[self.node.start as usize..self.node.end as usize], ), + ConstantValue::Float => Cow::Borrowed( + &source[self.node.start as usize..self.node.end as usize], + ), _ => todo!("Call the parser and get the value"), } } diff --git a/translator/src/translator.rs b/translator/src/translator.rs index 167dbf6b..8e1826f7 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -6,6 +6,7 @@ use std::fmt::Write; use log::warn; use enderpy_python_type_checker::{types, ast_visitor::TraversalVisitor, file::EnderpyFile, checker::TypeChecker, types::PythonType}; +use enderpy_python_type_checker::{get_module_name, symbol_table}; #[derive(Clone, Debug)] pub struct CppTranslator<'a> { @@ -49,7 +50,7 @@ impl<'a> CppTranslator<'a> { } fn enter_scope(&mut self, pos: u32) { - let symbol_table = self.checker.get_symbol_table(); + let symbol_table = self.checker.get_symbol_table(None); self.prev_scope = self.current_scope; self.current_scope = symbol_table.get_scope(pos); } @@ -57,6 +58,36 @@ impl<'a> CppTranslator<'a> { fn leave_scope(&mut self) { self.current_scope = self.prev_scope; } + + fn python_type_to_cpp(&self, python_type: &PythonType) -> String { + let details; + match python_type { + PythonType::Class(c) => { + details = &c.details; + }, + PythonType::Instance(i) => { + details = &i.class_type.details; + }, + _ => { + return String::from(format!("", python_type)); + } + }; + // If the current symbol table already contains details.name, + // we do not need to qualify it, otherwise qualify it with the module name, + // unless it is a builtin type in which case we do not qualify it. + let symbol_table = self.checker.get_symbol_table(None); + match symbol_table.lookup_in_scope(&details.name, self.current_scope) { + Some(_) => details.name.to_string(), + None => { + let symbol_table = self.checker.get_symbol_table(Some(details.declaration_path.symbol_table_id)); + if symbol_table.file_path.as_path().ends_with("builtins.pyi") { + details.name.to_string() + } else { + format!("{}::{}", get_module_name(symbol_table.file_path.as_path()), &details.name) + } + } + } + } } impl<'a> TraversalVisitor for CppTranslator<'a> { @@ -138,7 +169,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { ConstantValue::Bytes => write!(self.output, "bytes"), ConstantValue::Tuple => write!(self.output, "tuple"), ConstantValue::Int => write!(self.output, "{}", constant.get_value(&self.file.source).to_string()), - ConstantValue::Float => write!(self.output, "float"), + ConstantValue::Float => write!(self.output, "{}", constant.get_value(&self.file.source).to_string()), ConstantValue::Complex => write!(self.output, "complex"), /* Constant::Tuple(elements) => { @@ -161,7 +192,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } fn visit_assign(&mut self, a: &Assign) { - let symbol_table = self.checker.get_symbol_table(); + let symbol_table = self.checker.get_symbol_table(None); for target in &a.targets { // let type = self.checker.types. match target { @@ -172,7 +203,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { let path = node.declarations[0].declaration_path(); if path.node == n.node { let typ = self.checker.get_type(&n.node); - write!(self.output, "{} ", python_type_to_cpp(&typ)); + write!(self.output, "{} ", self.python_type_to_cpp(&typ)); } }, None => {}, @@ -210,6 +241,10 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { match arg { types::CallableArgs::Args(t) => { self.check_type(&c.args[i].get_node(), t); + if i != 0 { + write!(self.output, ", "); + } + self.visit_expr(&c.args[i]); num_pos_args = i; }, types::CallableArgs::Positional(t) => { @@ -243,14 +278,38 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { fn visit_attribute(&mut self, attribute: &Attribute) { self.visit_expr(&attribute.value); - write!(self.output, "::{}", attribute.attr); + match &attribute.value { + Expression::Name(n) => { + let symbol_table = self.checker.get_symbol_table(None); + match symbol_table.lookup_in_scope(&n.id, self.current_scope) { + Some(entry) => { + match entry.last_declaration() { + symbol_table::Declaration::Alias(a) => { + write!(self.output, "::{}", attribute.attr); + return + }, + _ => {} + } + }, + None => {}, + } + }, + _ => {} + } + write!(self.output, ".{}", attribute.attr); + } + + fn visit_return(&mut self, _r: &Return) { } fn visit_function_def(&mut self, f: &Arc) { self.enter_scope(f.node.start); write!(self.output, "void {}(", intern_lookup(f.name)); - for arg in f.args.args.iter() { - write!(self.output, "{} {}", python_type_to_cpp(&self.checker.get_type(&arg.node)), arg.arg); + for (i, arg) in f.args.args.iter().enumerate() { + if i != 0 { + write!(self.output, ", "); + } + write!(self.output, "{} {}", self.python_type_to_cpp(&self.checker.get_type(&arg.node)), arg.arg); } writeln!(self.output, ") {{"); self.indent_level += 1; @@ -296,14 +355,3 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } } -fn python_type_to_cpp(python_type: &PythonType) -> String { - match python_type { - PythonType::Class(c) => { - c.details.name.clone() - }, - PythonType::Instance(i) => { - i.class_type.details.name.clone() - }, - _ => String::from(format!("", python_type)) - } -} diff --git a/typechecker/src/checker.rs b/typechecker/src/checker.rs index e26e9616..f37aee9f 100644 --- a/typechecker/src/checker.rs +++ b/typechecker/src/checker.rs @@ -146,8 +146,9 @@ impl<'a> TypeChecker<'a> { return PythonType::Unknown; } - pub fn get_symbol_table(&self) -> Arc { - return self.build_manager.get_symbol_table_by_id(&self.id); + pub fn get_symbol_table(&self, id: Option) -> Arc { + let id = id.unwrap_or(self.id); + return self.build_manager.get_symbol_table_by_id(&id); } } #[allow(unused)] diff --git a/typechecker/src/lib.rs b/typechecker/src/lib.rs index ae78a3d3..79ecc988 100644 --- a/typechecker/src/lib.rs +++ b/typechecker/src/lib.rs @@ -3,7 +3,7 @@ use std::path::Path; pub mod ast_visitor; pub mod file; mod ruff_python_import_resolver; -mod symbol_table; +pub mod symbol_table; pub mod build; pub mod checker; @@ -40,5 +40,16 @@ pub fn find_project_root(path: &Path) -> &Path { } pub fn get_module_name(path: &Path) -> String { - path.to_str().unwrap().replace(['/', '\\'], ".") + // First we strip .pyi and / or __init__.pyi from the end + let mut s = path.to_str().unwrap(); + s = match s.strip_suffix("/__init__.pyi") { + Some(new) => new, + None => s + }; + s = match s.strip_suffix(".pyi") { + Some(new) => new, + None => s + }; + // And then we replace the slashes with . + s.replace(['/', '\\'], ".") } diff --git a/typechecker/src/type_evaluator.rs b/typechecker/src/type_evaluator.rs index 988eee86..50d0c748 100755 --- a/typechecker/src/type_evaluator.rs +++ b/typechecker/src/type_evaluator.rs @@ -22,6 +22,7 @@ use super::{ }, }; use crate::{ + get_module_name, build::BuildManager, semantic_analyzer::get_member_access_info, symbol_table::{self, Class, Declaration, DeclarationPath, Id, SymbolTable, SymbolTableNode}, @@ -676,7 +677,18 @@ impl<'a> TypeEvaluator<'a> { // TODO: check if other binary operators are allowed _ => todo!(), } - } + }, + Expression::Attribute(a) => { + match &a.value { + Expression::Name(n) => { + let Some(typ) = self.lookup_on_module(symbol_table, scope_id, &n.id, &a.attr) else { + return PythonType::Unknown; + }; + return typ; + }, + _ => todo!(), + }; + }, _ => PythonType::Unknown, }; @@ -1034,15 +1046,20 @@ impl<'a> TypeEvaluator<'a> { PythonType::Unknown } None => { - let Some(ref resolved_import) = a.import_result else { - trace!("import result not found"); - return PythonType::Unknown; - }; - - let module_id = resolved_import.resolved_ids.first().unwrap(); - return PythonType::Module(ModuleRef { - module_id: *module_id, - }); + match &a.import_node { + Some(i) => { + let module_name = &i.names[0].name; + let Some(module_symbol_table) = self.get_symbol_table_for_module(&a, module_name) else { + return PythonType::Unknown; + }; + return PythonType::Module(ModuleRef { + module_id: module_symbol_table.id, + }); + }, + None => { + return PythonType::Unknown; + } + } } } } @@ -1547,6 +1564,39 @@ impl<'a> TypeEvaluator<'a> { symbol.map(|node| self.get_symbol_type(node, symbol_table, None)) } + /// Find a type inside a Python module + fn lookup_on_module( + &self, + symbol_table: &SymbolTable, + scope_id: u32, + module_name: &str, + attr: &str, + ) -> Option { + // See if the module is in the symbol table + let symbol_table_entry = symbol_table.lookup_in_scope(module_name, scope_id)?; + match symbol_table_entry.last_declaration() { + Declaration::Alias(a) => { + let module_symbol_table = self.get_symbol_table_for_module(&a, module_name)?; + return Some(self.get_name_type(attr, None, &module_symbol_table, 0)); + } + _ => {} + }; + None + } + + fn get_symbol_table_for_module(&self, alias: &symbol_table::Alias, module_name: &str) -> Option> { + let Some(ref resolved_import) = alias.import_result else { + return None; + }; + for id in resolved_import.resolved_ids.iter() { + let module_symbol_table = self.get_symbol_table(id); + if module_name == get_module_name(module_symbol_table.file_path.as_path()) { + return Some(module_symbol_table); + } + } + return None; + } + fn get_function_signature( &self, arguments: &ast::Arguments, From 81a3e79ee2e12e46bc93c61a838e6e113a81a999 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 7 Dec 2024 15:40:20 -0800 Subject: [PATCH 06/14] Add classes and member variables --- translator/src/translator.rs | 111 +++++++++++++++++++++++++++--- typechecker/src/checker.rs | 2 +- typechecker/src/type_evaluator.rs | 38 +++++++--- 3 files changed, 133 insertions(+), 18 deletions(-) diff --git a/translator/src/translator.rs b/translator/src/translator.rs index 8e1826f7..72355fa7 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::collections::HashMap; use enderpy_python_parser::ast::{self, *}; use enderpy_python_parser::parser::parser::intern_lookup; use std::error::Error; @@ -16,6 +17,9 @@ pub struct CppTranslator<'a> { file: &'a EnderpyFile, current_scope: u32, prev_scope: u32, + // Member variables of the current class + class_members: HashMap, + in_constructor: bool, } impl<'a> CppTranslator<'a> { @@ -27,6 +31,8 @@ impl<'a> CppTranslator<'a> { file: file, current_scope: 0, prev_scope: 0, + class_members: HashMap::new(), + in_constructor: false, } } @@ -106,7 +112,10 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { Statement::Assert(a) => self.visit_assert(a), Statement::Pass(p) => self.visit_pass(p), Statement::Delete(d) => self.visit_delete(d), - Statement::ReturnStmt(r) => self.visit_return(r), + Statement::ReturnStmt(r) => { + self.visit_return(r); + writeln!(self.output, ";"); + }, Statement::Raise(r) => self.visit_raise(r), Statement::BreakStmt(b) => self.visit_break(b), Statement::ContinueStmt(c) => self.visit_continue(c), @@ -210,6 +219,14 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { }; self.visit_name(n); }, + Expression::Attribute(attr) => { + if let Expression::Name(n) = &attr.value { + if n.id == "self" { + self.class_members.insert(attr.attr.clone(), self.python_type_to_cpp(&self.checker.get_type(&a.value.get_node()))); + } + } + self.visit_expr(target); + } _ => { self.visit_expr(target); } @@ -230,16 +247,48 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } fn visit_call(&mut self, c: &Call) { - let typ = self.checker.get_type(&c.func.get_node()); + let mut typ = self.checker.get_type(&c.func.get_node()); self.visit_expr(&c.func); write!(self.output, "("); + // In case c.func is a class instance, we need to use the __call__ method + // of that instance instead -- we fix this here. + if let PythonType::Instance(i) = &typ { + let symbol_table = self.checker.get_symbol_table(None); + typ = self.checker.type_evaluator.lookup_on_class(&symbol_table, &i.class_type, "__call__").expect("instance type not callable").clone(); + let PythonType::Callable(old_callable) = typ else { + panic!("XXX"); + }; + let callable_type = types::CallableType::new( + old_callable.name, + old_callable.signature[1..].to_vec(), + old_callable.return_type, + old_callable.is_async, + ); + typ = PythonType::Callable(Box::new(callable_type)); + } + // In case c.func is a class, we need to use the type signature of the + // __init__ method. + if let PythonType::Class(c) = &typ { + let symbol_table = self.checker.get_symbol_table(None); + typ = self.checker.type_evaluator.lookup_on_class(&symbol_table, &c, "__init__").expect("class currently needs an __init__ method").clone(); + let PythonType::Callable(old_callable) = typ else { + panic!("XXX"); + }; + let callable_type = types::CallableType::new( + old_callable.name, + old_callable.signature[1..].to_vec(), + old_callable.return_type, + old_callable.is_async, + ); + typ = PythonType::Callable(Box::new(callable_type)); + } match typ { PythonType::Callable(callable) => { let mut num_pos_args = 0; // First check all the positional args for (i, arg) in callable.signature.iter().enumerate() { match arg { - types::CallableArgs::Args(t) => { + types::CallableArgs::Positional(t) => { self.check_type(&c.args[i].get_node(), t); if i != 0 { write!(self.output, ", "); @@ -247,17 +296,17 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { self.visit_expr(&c.args[i]); num_pos_args = i; }, - types::CallableArgs::Positional(t) => { + types::CallableArgs::Args(t) => { break; }, _ => {} } } // Then check all the star args if there are any - if num_pos_args < c.args.len() { + if num_pos_args + 1 < c.args.len() { write!(self.output, "{{"); for (i, arg) in c.args[num_pos_args..].iter().enumerate() { - self.check_type(&arg.get_node(), callable.signature[num_pos_args].get_type()); + self.check_type(&arg.get_node(), callable.signature[num_pos_args+i].get_type()); if i != 0 { write!(self.output, ", "); } @@ -299,12 +348,31 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { write!(self.output, ".{}", attribute.attr); } - fn visit_return(&mut self, _r: &Return) { + fn visit_return(&mut self, r: &Return) { + write!(self.output, "return "); + if let Some(value) = &r.value { + self.visit_expr(value); + } } fn visit_function_def(&mut self, f: &Arc) { self.enter_scope(f.node.start); - write!(self.output, "void {}(", intern_lookup(f.name)); + let mut name = intern_lookup(f.name).to_string(); + if name == "__init__" { + // In this case, the function is a constructor and in + // C++ needs to be named the same as the class. We achieve + // this by naming it after the type of the "self" argument + // of __init__. + name = self.python_type_to_cpp(&self.checker.get_type(&f.args.args[0].node)); + self.class_members = HashMap::new(); + self.in_constructor = true; + } + if let Some(ret) = &f.returns { + let return_type = self.python_type_to_cpp(&self.checker.get_type(&ret.get_node())); + write!(self.output, "{} {}(", return_type, name); + } else { + write!(self.output, "void {}(", name); + } for (i, arg) in f.args.args.iter().enumerate() { if i != 0 { write!(self.output, ", "); @@ -313,10 +381,37 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } writeln!(self.output, ") {{"); self.indent_level += 1; + // If this is an instance method, introduce "self" + self.write_indent(); + writeln!(self.output, "auto& self = *this;"); for stmt in &f.body { self.visit_stmt(stmt); } self.indent_level -= 1; + self.write_indent(); + writeln!(self.output, "}}"); + self.leave_scope(); + } + + fn visit_class_def(&mut self, c: &Arc) { + let name = intern_lookup(c.name); + writeln!(self.output, "class {} {{", name); + self.enter_scope(c.node.start); + self.indent_level += 1; + for stmt in &c.body { + self.visit_stmt(stmt); + } + self.indent_level -= 1; + // print class member variables + self.write_indent(); + writeln!(self.output, "private:"); + // TODO: Want to move this out, not clone it + for (key, value) in self.class_members.clone() { + self.write_indent(); + writeln!(self.output, " {} {};", value, key); + } + self.class_members = HashMap::new(); + self.write_indent(); writeln!(self.output, "}}"); self.leave_scope(); } diff --git a/typechecker/src/checker.rs b/typechecker/src/checker.rs index f37aee9f..e17e861f 100644 --- a/typechecker/src/checker.rs +++ b/typechecker/src/checker.rs @@ -16,7 +16,7 @@ use rust_lapper::{Interval, Lapper}; pub struct TypeChecker<'a> { pub types: Lapper, id: Id, - type_evaluator: TypeEvaluator<'a>, + pub type_evaluator: TypeEvaluator<'a>, build_manager: &'a BuildManager, current_scope: u32, prev_scope: u32, diff --git a/typechecker/src/type_evaluator.rs b/typechecker/src/type_evaluator.rs index 50d0c748..1a5a0123 100755 --- a/typechecker/src/type_evaluator.rs +++ b/typechecker/src/type_evaluator.rs @@ -60,6 +60,13 @@ bitflags::bitflags! { } } +fn class_type_to_instance_type(class_type: PythonType) -> PythonType { + let PythonType::Class(c) = class_type else { + return PythonType::Unknown; + }; + PythonType::Instance(types::InstanceType::new(c.clone(), [].to_vec())) +} + /// Struct for evaluating the type of an expression impl<'a> TypeEvaluator<'a> { pub fn new(build_manager: &'a BuildManager) -> Self { @@ -85,12 +92,12 @@ impl<'a> TypeEvaluator<'a> { let typ = match &c.value { // Constants are not literals unless they are explicitly // typing.readthedocs.io/en/latest/spec/literal.html#backwards-compatibility - ast::ConstantValue::Int => self.get_builtin_type("int"), - ast::ConstantValue::Float => self.get_builtin_type("float"), - ast::ConstantValue::Str(_) => self.get_builtin_type("str"), - ast::ConstantValue::Bool(_) => self.get_builtin_type("bool"), + ast::ConstantValue::Int => self.get_builtin_type("int").map(class_type_to_instance_type), + ast::ConstantValue::Float => self.get_builtin_type("float").map(class_type_to_instance_type), + ast::ConstantValue::Str(_) => self.get_builtin_type("str").map(class_type_to_instance_type), + ast::ConstantValue::Bool(_) => self.get_builtin_type("bool").map(class_type_to_instance_type), ast::ConstantValue::None => Some(PythonType::None), - ast::ConstantValue::Bytes => self.get_builtin_type("bytes"), + ast::ConstantValue::Bytes => self.get_builtin_type("bytes").map(class_type_to_instance_type), ast::ConstantValue::Ellipsis => Some(PythonType::Any), // TODO: implement ast::ConstantValue::Tuple => Some(PythonType::Unknown), @@ -123,8 +130,21 @@ impl<'a> TypeEvaluator<'a> { scope_id, ); Ok(return_type) + } else if let PythonType::Instance(i) = &called_type { + // This executes the __call__ method of the instance + let Some(PythonType::Callable(c)) = self.lookup_on_class(symbol_table, &i.class_type, "__call__") else { + bail!("If you call an instance, it must have a __call__ method"); + }; + let return_type = self.get_return_type_of_callable( + &c, + &call.args, + symbol_table, + scope_id, + ); + Ok(return_type) } else if let PythonType::Class(c) = &called_type { - Ok(called_type) + // This instantiates the class + Ok(PythonType::Instance(types::InstanceType::new(c.clone(), [].to_vec()))) } else if let PythonType::TypeVar(t) = &called_type { let Some(first_arg) = call.args.first() else { bail!("TypeVar must be called with a name"); @@ -552,7 +572,7 @@ impl<'a> TypeEvaluator<'a> { let expr_type = match type_annotation { Expression::Name(name) => { // TODO: Reject this type if the name refers to a variable. - self.get_name_type(&name.id, Some(name.node.start), symbol_table, scope_id) + return class_type_to_instance_type(self.get_name_type(&name.id, Some(name.node.start), symbol_table, scope_id)); } Expression::Constant(ref c) => match c.value { ast::ConstantValue::None => PythonType::None, @@ -684,7 +704,7 @@ impl<'a> TypeEvaluator<'a> { let Some(typ) = self.lookup_on_module(symbol_table, scope_id, &n.id, &a.attr) else { return PythonType::Unknown; }; - return typ; + return class_type_to_instance_type(typ); }, _ => todo!(), }; @@ -1532,7 +1552,7 @@ impl<'a> TypeEvaluator<'a> { ret_type } - fn lookup_on_class( + pub fn lookup_on_class( &self, symbol_table: &SymbolTable, c: &ClassType, From 51fa6d50f57476a5ef145c63110688bb5110ddb5 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 7 Dec 2024 16:09:10 -0800 Subject: [PATCH 07/14] small cleanups --- translator/src/translator.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/translator/src/translator.rs b/translator/src/translator.rs index 72355fa7..d61ee594 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -2,9 +2,7 @@ use std::sync::Arc; use std::collections::HashMap; use enderpy_python_parser::ast::{self, *}; use enderpy_python_parser::parser::parser::intern_lookup; -use std::error::Error; use std::fmt::Write; -use log::warn; use enderpy_python_type_checker::{types, ast_visitor::TraversalVisitor, file::EnderpyFile, checker::TypeChecker, types::PythonType}; use enderpy_python_type_checker::{get_module_name, symbol_table}; @@ -296,7 +294,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { self.visit_expr(&c.args[i]); num_pos_args = i; }, - types::CallableArgs::Args(t) => { + types::CallableArgs::Args(_t) => { break; }, _ => {} @@ -333,7 +331,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { match symbol_table.lookup_in_scope(&n.id, self.current_scope) { Some(entry) => { match entry.last_declaration() { - symbol_table::Declaration::Alias(a) => { + symbol_table::Declaration::Alias(_a) => { write!(self.output, "::{}", attribute.attr); return }, From c12f64204cb01ad6b3e38920ff5ea7065395f56c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 7 Dec 2024 16:11:43 -0800 Subject: [PATCH 08/14] cleanup --- typechecker/src/ast_visitor.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/typechecker/src/ast_visitor.rs b/typechecker/src/ast_visitor.rs index f7101fc2..e35258ff 100644 --- a/typechecker/src/ast_visitor.rs +++ b/typechecker/src/ast_visitor.rs @@ -323,7 +323,7 @@ pub trait TraversalVisitor { } fn visit_assign(&mut self, _a: &Assign) { - // todo!() + todo!() } fn visit_ann_assign(&mut self, _a: &AnnAssign) { @@ -339,7 +339,7 @@ pub trait TraversalVisitor { } fn visit_pass(&mut self, _p: &Pass) { - // todo!() + todo!() } fn visit_delete(&mut self, _d: &Delete) { From eed7da85b0062789e8a9451ccd8683017fc62b20 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 7 Dec 2024 17:37:14 -0800 Subject: [PATCH 09/14] small fixes to make the example compile --- translator/src/translator.rs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/translator/src/translator.rs b/translator/src/translator.rs index d61ee594..779bc392 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -292,19 +292,18 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { write!(self.output, ", "); } self.visit_expr(&c.args[i]); - num_pos_args = i; + num_pos_args = num_pos_args + 1; }, - types::CallableArgs::Args(_t) => { + _ => { break; - }, - _ => {} + } } } // Then check all the star args if there are any - if num_pos_args + 1 < c.args.len() { + if c.args.len() > num_pos_args { write!(self.output, "{{"); for (i, arg) in c.args[num_pos_args..].iter().enumerate() { - self.check_type(&arg.get_node(), callable.signature[num_pos_args+i].get_type()); + self.check_type(&arg.get_node(), callable.signature[num_pos_args].get_type()); if i != 0 { write!(self.output, ", "); } @@ -369,7 +368,11 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { let return_type = self.python_type_to_cpp(&self.checker.get_type(&ret.get_node())); write!(self.output, "{} {}(", return_type, name); } else { - write!(self.output, "void {}(", name); + if self.in_constructor { + write!(self.output, "{}(", name); + } else { + write!(self.output, "void {}(", name); + } } for (i, arg) in f.args.args.iter().enumerate() { if i != 0 { @@ -388,12 +391,15 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { self.indent_level -= 1; self.write_indent(); writeln!(self.output, "}}"); + self.in_constructor = false; self.leave_scope(); } fn visit_class_def(&mut self, c: &Arc) { let name = intern_lookup(c.name); writeln!(self.output, "class {} {{", name); + self.write_indent(); + writeln!(self.output, "public:"); self.enter_scope(c.node.start); self.indent_level += 1; for stmt in &c.body { @@ -410,7 +416,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } self.class_members = HashMap::new(); self.write_indent(); - writeln!(self.output, "}}"); + writeln!(self.output, "}};"); self.leave_scope(); } From eba7504951a6849070627fc71118df239a1c7b06 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 7 Dec 2024 18:10:28 -0800 Subject: [PATCH 10/14] fix self arg handling --- translator/src/translator.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/translator/src/translator.rs b/translator/src/translator.rs index 779bc392..b1130911 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -374,7 +374,13 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { write!(self.output, "void {}(", name); } } - for (i, arg) in f.args.args.iter().enumerate() { + // Filter out "self" arg (first arg of a Python method), + // since in C++ the "this" arg is implicit. + // TODO: This will also filter out random args called "self" -- + // instead we should check if we are in a class definition and then + // only filter the first argument called "self". + let args = f.args.args.iter().filter(|arg| arg.arg != "self"); + for (i, arg) in args.enumerate() { if i != 0 { write!(self.output, ", "); } From db6890790bfca58c300e7ccf1b73c2dae87e243c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 8 Dec 2024 00:54:40 -0800 Subject: [PATCH 11/14] remove write! calls --- translator/src/translator.rs | 96 +++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/translator/src/translator.rs b/translator/src/translator.rs index b1130911..eef340a3 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use std::collections::HashMap; + use enderpy_python_parser::ast::{self, *}; use enderpy_python_parser::parser::parser::intern_lookup; -use std::fmt::Write; use enderpy_python_type_checker::{types, ast_visitor::TraversalVisitor, file::EnderpyFile, checker::TypeChecker, types::PythonType}; use enderpy_python_type_checker::{get_module_name, symbol_table}; @@ -41,7 +41,11 @@ impl<'a> CppTranslator<'a> { } pub fn write_indent(&mut self) { - write!(self.output, "{}", " ".repeat(self.indent_level)); + self.emit(" ".repeat(self.indent_level)); + } + + fn emit>(&mut self, s: S) { + self.output += s.as_ref(); } fn check_type(&self, node: &Node, typ: &PythonType) { @@ -103,7 +107,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { Statement::ImportFrom(i) => self.visit_import_from(i), Statement::AssignStatement(a) => { self.visit_assign(a); - writeln!(self.output, ";"); + self.emit(";\n"); }, Statement::AnnAssignStatement(a) => self.visit_ann_assign(a), Statement::AugAssignStatement(a) => self.visit_aug_assign(a), @@ -112,7 +116,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { Statement::Delete(d) => self.visit_delete(d), Statement::ReturnStmt(r) => { self.visit_return(r); - writeln!(self.output, ";"); + self.emit(";\n"); }, Statement::Raise(r) => self.visit_raise(r), Statement::BreakStmt(b) => self.visit_break(b), @@ -169,15 +173,15 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { fn visit_constant(&mut self, constant: &Constant) { match constant.value { - ConstantValue::None => write!(self.output, "None"), - ConstantValue::Ellipsis => write!(self.output, "..."), - ConstantValue::Bool(_) => write!(self.output, "bool"), - ConstantValue::Str(_) => write!(self.output, "\"{}\"", constant.get_value(&self.file.source).to_string()), - ConstantValue::Bytes => write!(self.output, "bytes"), - ConstantValue::Tuple => write!(self.output, "tuple"), - ConstantValue::Int => write!(self.output, "{}", constant.get_value(&self.file.source).to_string()), - ConstantValue::Float => write!(self.output, "{}", constant.get_value(&self.file.source).to_string()), - ConstantValue::Complex => write!(self.output, "complex"), + ConstantValue::None => self.emit("None"), + ConstantValue::Ellipsis => self.emit("..."), + ConstantValue::Bool(_) => self.emit("bool"), + ConstantValue::Str(_) => self.emit(constant.get_value(&self.file.source)), + ConstantValue::Bytes => self.emit("bytes"), + ConstantValue::Tuple => self.emit("tuple"), + ConstantValue::Int => self.emit(constant.get_value(&self.file.source)), + ConstantValue::Float => self.emit(constant.get_value(&self.file.source)), + ConstantValue::Complex => self.emit("complex"), /* Constant::Tuple(elements) => { let tuple_elements: Vec = elements @@ -193,7 +197,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { fn visit_import(&mut self, import: &Import) { for name in import.names.iter() { if name.name == "torch" { - writeln!(self.output, "#include "); + self.emit("#include \n"); } } } @@ -210,7 +214,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { let path = node.declarations[0].declaration_path(); if path.node == n.node { let typ = self.checker.get_type(&n.node); - write!(self.output, "{} ", self.python_type_to_cpp(&typ)); + self.emit(self.python_type_to_cpp(&typ)); } }, None => {}, @@ -230,24 +234,24 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } } } - write!(self.output, " = "); + self.emit(" = "); self.visit_expr(&a.value); } fn visit_name(&mut self, name: &Name) { - write!(self.output, "{}", name.id); + self.emit(name.id.clone()); } fn visit_bin_op(&mut self, b: &BinOp) { self.visit_expr(&b.left); - write!(self.output, " {} ", &b.op); + self.emit(b.op.to_string()); self.visit_expr(&b.right); } fn visit_call(&mut self, c: &Call) { let mut typ = self.checker.get_type(&c.func.get_node()); self.visit_expr(&c.func); - write!(self.output, "("); + self.emit("("); // In case c.func is a class instance, we need to use the __call__ method // of that instance instead -- we fix this here. if let PythonType::Instance(i) = &typ { @@ -289,7 +293,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { types::CallableArgs::Positional(t) => { self.check_type(&c.args[i].get_node(), t); if i != 0 { - write!(self.output, ", "); + self.emit(", "); } self.visit_expr(&c.args[i]); num_pos_args = num_pos_args + 1; @@ -301,15 +305,15 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } // Then check all the star args if there are any if c.args.len() > num_pos_args { - write!(self.output, "{{"); + self.emit("{"); for (i, arg) in c.args[num_pos_args..].iter().enumerate() { self.check_type(&arg.get_node(), callable.signature[num_pos_args].get_type()); if i != 0 { - write!(self.output, ", "); + self.emit(", "); } self.visit_expr(arg); } - write!(self.output, "}}"); + self.emit("}"); } }, _ => { @@ -319,7 +323,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { // for keyword in &c.keywords { // self.visit_expr(&keyword.value); // } - write!(self.output, ")"); + self.emit(")"); } fn visit_attribute(&mut self, attribute: &Attribute) { @@ -331,7 +335,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { Some(entry) => { match entry.last_declaration() { symbol_table::Declaration::Alias(_a) => { - write!(self.output, "::{}", attribute.attr); + self.emit(format!("::{}", &attribute.attr)); return }, _ => {} @@ -342,11 +346,11 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { }, _ => {} } - write!(self.output, ".{}", attribute.attr); + self.emit(format!(".{}", attribute.attr)); } fn visit_return(&mut self, r: &Return) { - write!(self.output, "return "); + self.emit("return "); if let Some(value) = &r.value { self.visit_expr(value); } @@ -366,12 +370,12 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } if let Some(ret) = &f.returns { let return_type = self.python_type_to_cpp(&self.checker.get_type(&ret.get_node())); - write!(self.output, "{} {}(", return_type, name); + self.emit(format!("{} {}(", return_type, name)); } else { if self.in_constructor { - write!(self.output, "{}(", name); + self.emit(format!("{}(", name)); } else { - write!(self.output, "void {}(", name); + self.emit(format!("void {}(", name)); } } // Filter out "self" arg (first arg of a Python method), @@ -382,30 +386,30 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { let args = f.args.args.iter().filter(|arg| arg.arg != "self"); for (i, arg) in args.enumerate() { if i != 0 { - write!(self.output, ", "); + self.emit(", "); } - write!(self.output, "{} {}", self.python_type_to_cpp(&self.checker.get_type(&arg.node)), arg.arg); + self.emit(format!("{} {}", self.python_type_to_cpp(&self.checker.get_type(&arg.node)), arg.arg)); } - writeln!(self.output, ") {{"); + self.emit(") {\n"); self.indent_level += 1; // If this is an instance method, introduce "self" self.write_indent(); - writeln!(self.output, "auto& self = *this;"); + self.emit("auto& self = *this;\n"); for stmt in &f.body { self.visit_stmt(stmt); } self.indent_level -= 1; self.write_indent(); - writeln!(self.output, "}}"); + self.emit("}\n"); self.in_constructor = false; self.leave_scope(); } fn visit_class_def(&mut self, c: &Arc) { let name = intern_lookup(c.name); - writeln!(self.output, "class {} {{", name); + self.emit(format!("class {} {{\n", name)); self.write_indent(); - writeln!(self.output, "public:"); + self.emit("public:\n"); self.enter_scope(c.node.start); self.indent_level += 1; for stmt in &c.body { @@ -414,15 +418,15 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { self.indent_level -= 1; // print class member variables self.write_indent(); - writeln!(self.output, "private:"); + self.emit("private:\n"); // TODO: Want to move this out, not clone it for (key, value) in self.class_members.clone() { self.write_indent(); - writeln!(self.output, " {} {};", value, key); + self.emit(format!(" {} {};\n", value, key)); } self.class_members = HashMap::new(); self.write_indent(); - writeln!(self.output, "}};"); + self.emit("};\n"); self.leave_scope(); } @@ -441,22 +445,22 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { }, _ => {} } - write!(self.output, "for(int "); + self.emit("for(int "); self.visit_expr(&f.target); - write!(self.output, " = 0; "); + self.emit(" = 0; "); self.visit_expr(&f.target); - write!(self.output, " < "); + self.emit(" < "); self.visit_expr(&bound.unwrap()); - write!(self.output, "; ++"); + self.emit("; ++"); self.visit_expr(&f.target); - writeln!(self.output, ") {{"); + self.emit(") {\n"); self.indent_level += 1; for stmt in &f.body { self.visit_stmt(stmt); } self.indent_level -= 1; self.write_indent(); - writeln!(self.output, "}}"); + self.emit("}\n"); } } From e4fe344063f2faf40aebf12b1bbaed38e783ab50 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 8 Dec 2024 01:21:19 -0800 Subject: [PATCH 12/14] cleanup --- translator/src/translator.rs | 45 ++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/translator/src/translator.rs b/translator/src/translator.rs index eef340a3..9950d989 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -40,12 +40,22 @@ impl<'a> CppTranslator<'a> { } } - pub fn write_indent(&mut self) { + fn emit>(&mut self, s: S) { + self.output += s.as_ref(); + } + + fn emit_indent(&mut self) { self.emit(" ".repeat(self.indent_level)); } - fn emit>(&mut self, s: S) { - self.output += s.as_ref(); + fn emit_type(&mut self, node: &ast::Node) { + let cpp_type = self.get_cpp_type(node); + self.emit(cpp_type); + } + + fn get_cpp_type(&mut self, node: &ast::Node) -> String { + let typ = self.checker.get_type(node); + return self.python_type_to_cpp(&typ); } fn check_type(&self, node: &Node, typ: &PythonType) { @@ -100,7 +110,7 @@ impl<'a> CppTranslator<'a> { impl<'a> TraversalVisitor for CppTranslator<'a> { fn visit_stmt(&mut self, s: &ast::Statement) { - self.write_indent(); + self.emit_indent(); match s { Statement::ExpressionStatement(e) => self.visit_expr(e), Statement::Import(i) => self.visit_import(i), @@ -213,8 +223,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { Some(node) => { let path = node.declarations[0].declaration_path(); if path.node == n.node { - let typ = self.checker.get_type(&n.node); - self.emit(self.python_type_to_cpp(&typ)); + self.emit_type(&n.node); } }, None => {}, @@ -224,7 +233,8 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { Expression::Attribute(attr) => { if let Expression::Name(n) = &attr.value { if n.id == "self" { - self.class_members.insert(attr.attr.clone(), self.python_type_to_cpp(&self.checker.get_type(&a.value.get_node()))); + let cpp_type = self.get_cpp_type(&a.value.get_node()); + self.class_members.insert(attr.attr.clone(), cpp_type); } } self.visit_expr(target); @@ -364,12 +374,12 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { // C++ needs to be named the same as the class. We achieve // this by naming it after the type of the "self" argument // of __init__. - name = self.python_type_to_cpp(&self.checker.get_type(&f.args.args[0].node)); + name = self.get_cpp_type(&f.args.args[0].node); self.class_members = HashMap::new(); self.in_constructor = true; } if let Some(ret) = &f.returns { - let return_type = self.python_type_to_cpp(&self.checker.get_type(&ret.get_node())); + let return_type = self.get_cpp_type(&ret.get_node()); self.emit(format!("{} {}(", return_type, name)); } else { if self.in_constructor { @@ -388,18 +398,19 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { if i != 0 { self.emit(", "); } - self.emit(format!("{} {}", self.python_type_to_cpp(&self.checker.get_type(&arg.node)), arg.arg)); + self.emit_type(&arg.node); + self.emit(format!(" {}", arg.arg)); } self.emit(") {\n"); self.indent_level += 1; // If this is an instance method, introduce "self" - self.write_indent(); + self.emit_indent(); self.emit("auto& self = *this;\n"); for stmt in &f.body { self.visit_stmt(stmt); } self.indent_level -= 1; - self.write_indent(); + self.emit_indent(); self.emit("}\n"); self.in_constructor = false; self.leave_scope(); @@ -408,7 +419,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { fn visit_class_def(&mut self, c: &Arc) { let name = intern_lookup(c.name); self.emit(format!("class {} {{\n", name)); - self.write_indent(); + self.emit_indent(); self.emit("public:\n"); self.enter_scope(c.node.start); self.indent_level += 1; @@ -417,15 +428,15 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { } self.indent_level -= 1; // print class member variables - self.write_indent(); + self.emit_indent(); self.emit("private:\n"); // TODO: Want to move this out, not clone it for (key, value) in self.class_members.clone() { - self.write_indent(); + self.emit_indent(); self.emit(format!(" {} {};\n", value, key)); } self.class_members = HashMap::new(); - self.write_indent(); + self.emit_indent(); self.emit("};\n"); self.leave_scope(); } @@ -459,7 +470,7 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { self.visit_stmt(stmt); } self.indent_level -= 1; - self.write_indent(); + self.emit_indent(); self.emit("}\n"); } } From 0b6d0a9d68d184515d85be2e49e7718e6e5e5123 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 8 Dec 2024 01:32:18 -0800 Subject: [PATCH 13/14] cleanups --- translator/src/translator.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/translator/src/translator.rs b/translator/src/translator.rs index 9950d989..87efe458 100644 --- a/translator/src/translator.rs +++ b/translator/src/translator.rs @@ -15,8 +15,10 @@ pub struct CppTranslator<'a> { file: &'a EnderpyFile, current_scope: u32, prev_scope: u32, - // Member variables of the current class + // Member variables of the current class, map from name to type class_members: HashMap, + // Whether we are currently inside a __init__ method + // and therefore need to record member variables in_constructor: bool, } @@ -49,11 +51,10 @@ impl<'a> CppTranslator<'a> { } fn emit_type(&mut self, node: &ast::Node) { - let cpp_type = self.get_cpp_type(node); - self.emit(cpp_type); + self.emit(self.get_cpp_type(node)); } - fn get_cpp_type(&mut self, node: &ast::Node) -> String { + fn get_cpp_type(&self, node: &ast::Node) -> String { let typ = self.checker.get_type(node); return self.python_type_to_cpp(&typ); } @@ -215,13 +216,13 @@ impl<'a> TraversalVisitor for CppTranslator<'a> { fn visit_assign(&mut self, a: &Assign) { let symbol_table = self.checker.get_symbol_table(None); for target in &a.targets { - // let type = self.checker.types. match target { Expression::Name(n) => { let node = symbol_table.lookup_in_scope(&n.id, self.current_scope); match node { Some(node) => { let path = node.declarations[0].declaration_path(); + // If this is the place where the name was defined, also emit its type if path.node == n.node { self.emit_type(&n.node); } From 45566a0c6e52397279453b59e616ccb6da3b67ab Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 8 Dec 2024 14:59:43 -0800 Subject: [PATCH 14/14] formatting --- translator/Cargo.toml | 2 +- translator/src/lib.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/translator/Cargo.toml b/translator/Cargo.toml index 920036c6..8f211b0e 100644 --- a/translator/Cargo.toml +++ b/translator/Cargo.toml @@ -6,4 +6,4 @@ edition = "2021" [dependencies] enderpy_python_parser = { path = "../parser", version = "0.1.0" } enderpy_python_type_checker = {path = "../typechecker", version = "0.1.0" } -log = { version = "0.4.17" } \ No newline at end of file +log = { version = "0.4.17" } diff --git a/translator/src/lib.rs b/translator/src/lib.rs index a5aac9df..a73a1679 100644 --- a/translator/src/lib.rs +++ b/translator/src/lib.rs @@ -1 +1 @@ -pub mod translator; \ No newline at end of file +pub mod translator;