Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5,682 changes: 5,682 additions & 0 deletions data/infix/generated_expressions.csv

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions data/infix/small_test_expressions.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
(max(min(min(1999, min(v2 + v3, 2000) + -1), ((((v4*1008) + v5) + -10) + (min(v6 - (v4*1008), 998) + 18)) - 1) + 1, min(max(max(0, max(v3, 0)), ((v4*1008) + v5) + -10), (min(v6 - (v4*1008), 998) + 18) + (((v4*1008) + v5) + -10))) <= min(max(max(0, max(v3, 0)), ((v4*1008) + v5) + -10), (min(v6 - (v4*1008), 998) + 18) + (((v4*1008) + v5) + -10)));0
(((((((uint1)1 && (((((((v0*32) + v1) + v2) + v3) + -15) % 16) <= ((((((v0*32) + v1) + v2) + v3) + -15) % 16))) && (((((((v0*32) + v1) + v2) + v3) + -15) % 16) >= ((((((v0*32) + v1) + v2) + v3) + -15) % 16))) && (((((v4 + v5)*4) + 0) + -7) <= (((v4 + v5)*4) + -7))) && (((((v4 + v5)*4) + 3) + -7) >= ((((v4 + v5)*4) + 3) + -7))) && (v6 <= v6)) && (v6 >= v6));1
(((((((uint1)1 && ((((v0*4) + v1) + 0) <= ((v0*4) + v1))) && ((((v0*4) + v1) + 3) >= (((v0*4) + v1) + 3))) && ((v2/4) <= (v2/4))) && (((v2/4) + 32) >= ((v2/4) + 32))) && (v3 <= v3)) && (((v4 + v3) + -1) >= ((v4 + v3) - 1)));1
(max(0, min(1, (0 + (((v0 % 4) + 47)/16)) - 1) + 1) <= 0);0
(0 >= ((0 + ((max(v0, 0) + 138)/4)) - 1));0
(((((((uint1)1 && (((((0*4) + (max((v2*2) + v3, (min(((v4 + v5) + -1)/250, v2 + 3)*2) + ((max(min(v2 - (((v4 + v5) + -1)/250), 0), -3)*2) + v3))*4)) + 0) + -5) <= (((v2*8) + ((v3*4) + max(max(min(v2 - (((v4 + v5) + -1)/250), 0), -3)*8, 0))) + -5))) && (((((((min(min((((v4 + v5) + -1)/250) - v2, 3)*8, -3) + 3)/4)*4) + max(((v2*2) + v3)*4, (((min(((v4 + v5) + -1)/250, v2 + 3)*2) + v3)*4) + 3)) + 3) + -5) >= (((min(((v4 + v5) + -1)/250, v2 + 3)*8) + ((v3*4) + 3)) + -5))) && (((((v6/((v7/4) + 1)) + (v8/250))*8) + -5) <= ((((v6/((v7/4) + 1)) + (v8/250))*8) + -5))) && (((((v6/((v7/4) + 1)) + (v8/250))*8) + ((8 + -5) - 1)) >= ((((v6/((v7/4) + 1)) + (v8/250))*8) + ((8 + -5) - 1)))) && (v9 <= v9)) && (v9 >= v9));0
(((v0*16) + -36) == ((((v0*16) + -36) + ((min(v1, -1)*16) + 48)) - 1));0
(0 == ((0 + (18 - (max(v0, 5)*3))) - 1));0
((0 - ((v0 + v1)*2)) > (499 - ((v0 + v1)*2)));0
(max(min((183 - v0)/64, (0 + 2) - 1) + 1, min(max(0 - (v0/64), 0), 2 + 0)) <= min(max(0 - (v0/64), 0), 2 + 0));0
((((v0*2) + v1) + 1) < (max((v0 + 1)*2, -1) + v1));1
(((((((uint1)1 && ((0 + v0) <= (0 + v0))) && ((3 + v0) >= (((4 + 0) - 1) + v0))) && ((v1 + v2) <= (v1 + v2))) && ((v1 + v2) >= (v1 + v2))) && (v3 <= v3)) && (v3 >= v3));1
(((((v0 % ((v1 + 180)/168))*-168)/4)*4) == ((v0 % ((v1 + 180)/168))*-168));1
((((v0*-1000)/4)*4) == (v0*-1000));1
((v0 + -5) >= (((v0 - 1) + 1) + 1));0
130 changes: 67 additions & 63 deletions src/dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ pub fn minimal_set_to_prove(
}
}

pub fn generation_execution(file_path: &OsString, params: (usize, usize, u64),reorder_count: usize, batch_size: usize){
pub fn generation_execution(file_path: &OsString, params: (usize, usize, u64),reorder_count: usize, batch_size: usize, continue_from_expr: usize){
let mut expressions_vect = Vec::new();
let file = File::open(file_path).unwrap();
//let mut rdr = csv::Reader::from_reader(file);
Expand All @@ -211,15 +211,21 @@ pub fn generation_execution(file_path: &OsString, params: (usize, usize, u64),re
let mut i = 0;
for result in rdr.records() {
i += 1;
let record = result.unwrap();
let expression = &record[1];
expressions_vect.push(expression.to_string());
if i % batch_size == 0{
generate_dataset_0_1_par(&expressions_vect, -2, params, true, reorder_count, i/batch_size);
println!("{} expressions processed!", i);
expressions_vect = Vec::new();
if i > continue_from_expr{
let record = result.unwrap();
let expression = &record[1];
expressions_vect.push(expression.to_string());
if i % batch_size == 0{
generate_dataset_0_1_par(&expressions_vect, -2, params, true, reorder_count, i/batch_size);
println!("{} expressions processed!", i);
expressions_vect = Vec::new();
}
}
}
if expressions_vect.len() > 0 {
generate_dataset_0_1_par(&expressions_vect, -2, params, true, reorder_count, i/batch_size + 1);
println!("{} expressions processed!", i);
}
}

#[allow(dead_code)]
Expand Down Expand Up @@ -266,64 +272,62 @@ pub fn minimal_set_to_prove_0_1(
let end_0: Pattern<Math> = "0".parse().unwrap();
let goals = [end_0.clone(), end_1.clone()];
let mut proved_goal = "0/1".to_string();
let mut proved_once: bool = false;
let mut runner;
let mut id;
let mut matches;
let mut i: usize;
let mut counter: usize;
// let mut minimal_ruleset_len: usize;
let mut rule;
let mut ruleset = rules(ruleset_id);
let data_object;
ruleset.shuffle(&mut rng);
//println!("Ruleset size == {}", ruleset.len());
let mut ruleset_copy: Vec<egg::Rewrite<Math, ConstantFold>>;
let mut ruleset_minimal: Vec<egg::Rewrite<Math, ConstantFold>>;
let ruleset_copy_names: Vec<String>;
counter = 0;
ruleset_minimal = ruleset.clone();
while counter < reorder_count {
ruleset_copy = ruleset.clone();
ruleset_copy.shuffle(&mut rng);
i = 0;
while i < ruleset_copy.len() {
rule = ruleset_copy.remove(i);
start = expression.parse().unwrap();
// end = expression.1.parse().unwrap();
runner = Runner::default()
.with_iter_limit(params.0)
.with_node_limit(params.1)
.with_time_limit(Duration::new(params.2, 0))
.with_expr(&start);

if use_iteration_check {
runner = runner.run_check_iteration(ruleset_copy.iter(), &goals);
} else {
runner = runner.run(ruleset_copy.iter());
}
id = runner.egraph.find(*runner.roots.last().unwrap());
// matches = end.search_eclass(&runner.egraph, id);
matches = goals.iter().all(|goal| {
let mat = goal.search_eclass(&runner.egraph, id);
if !mat.is_none() {
proved_goal = goal.to_string();
let result = crate::trs::prove(expression, -2, params, true, false);
if result.result{
let mut runner;
let mut id;
let mut matches;
let mut i: usize;
let mut counter: usize;
// let mut minimal_ruleset_len: usize;
let mut rule;
let mut ruleset = rules(ruleset_id);
let data_object;
ruleset.shuffle(&mut rng);
//println!("Ruleset size == {}", ruleset.len());
let mut ruleset_copy: Vec<egg::Rewrite<Math, ConstantFold>>;
let mut ruleset_minimal: Vec<egg::Rewrite<Math, ConstantFold>>;
let ruleset_copy_names: Vec<String>;
counter = 0;
ruleset_minimal = ruleset.clone();
while counter < reorder_count {
ruleset_copy = ruleset.clone();
ruleset_copy.shuffle(&mut rng);
i = 0;
while i < ruleset_copy.len() {
rule = ruleset_copy.remove(i);
start = expression.parse().unwrap();
// end = expression.1.parse().unwrap();
runner = Runner::default()
.with_iter_limit(params.0)
.with_node_limit(params.1)
.with_time_limit(Duration::new(params.2, 0))
.with_expr(&start);

if use_iteration_check {
runner = runner.run_check_iteration(ruleset_copy.iter(), &goals);
} else {
runner = runner.run(ruleset_copy.iter());
}
id = runner.egraph.find(*runner.roots.last().unwrap());
// matches = end.search_eclass(&runner.egraph, id);
matches = goals.iter().all(|goal| {
let mat = goal.search_eclass(&runner.egraph, id);
if !mat.is_none() {
proved_goal = goal.to_string();
}
mat.is_none()
});
if matches {
ruleset_copy.insert(i, rule);
i += 1;
}
mat.is_none()
});
if matches {
ruleset_copy.insert(i, rule);
i += 1;
} else {
proved_once = true;
}
if ruleset_copy.len() < ruleset_minimal.len() {
ruleset_minimal = ruleset_copy.clone();
}
counter += 1;
}
if ruleset_copy.len() < ruleset_minimal.len() {
ruleset_minimal = ruleset_copy.clone();
}
counter += 1;
}
if proved_once {
ruleset_copy_names = ruleset_minimal
.clone()
.into_iter()
Expand Down
25 changes: 7 additions & 18 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,30 +85,19 @@ fn test_classes(

fn main() {
let _args: Vec<String> = env::args().collect();
// let expressions = vec![
// ("( <= ( - v0 11 ) ( + ( * ( / ( - v0 v1 ) 12 ) 12 ) v1 ) )","1"),
// ("( <= ( + ( / ( - v0 v1 ) 8 ) 32 ) ( max ( / ( + ( - v0 v1 ) 257 ) 8 ) 0 ) )","1"),
// ("( <= (/ a 2) (a))", "1"),
// ("( <= ( min ( + ( * ( + v0 v1 ) 161 ) ( + ( min v2 v3 ) v4 ) ) v5 ) ( + ( * ( + v0 v1 ) 161 ) ( + v2 v4 ) ) )","1"),
// ("( == (+ a b) (+ b a) )","1"),
// ("( == (min a b) (a))","1"),
// ];
// generate_dataset(expressions,(30, 10000, 5), 2, 2);
// generate_dataset_par(&expressions,(30, 10000, 5), 2, 10);
// println!("Printing rules ...");
// let arr = filteredRules(&get_first_arg().unwrap(), 1).unwrap();
// for rule in arr{
// println!("{}", rule.name());
// }
// println!("End.");

if _args.len() > 4 {
let operation = get_nth_arg(1).unwrap();
let expressions_file = get_nth_arg(2).unwrap();
let params = get_runner_params(3).unwrap();
match operation.to_str().unwrap() {
"dataset" => {
dataset::generation_execution(&expressions_file, params, 5, 500);
// cargo run --release dataset ./results/expressions_egg.csv 100000 100000 5 5 1000 0 48
let reorder_count = get_nth_arg(6).unwrap().into_string().unwrap().parse::<usize>().unwrap();
let batch_size = get_nth_arg(7).unwrap().into_string().unwrap().parse::<usize>().unwrap();
let continue_from_expr = get_nth_arg(8).unwrap().into_string().unwrap().parse::<usize>().unwrap();
let cores = get_nth_arg(9).unwrap().into_string().unwrap().parse::<usize>().unwrap();
rayon::ThreadPoolBuilder::new().num_threads(cores).build_global().unwrap();
dataset::generation_execution(&expressions_file, params, reorder_count, batch_size, continue_from_expr);
}
"prove_exprs" => {
let expression_vect = read_expressions(&expressions_file).unwrap();
Expand Down
27 changes: 15 additions & 12 deletions utils/infix-to-prefix/Expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,20 +186,23 @@ def infixToPrefix(self):

@staticmethod
def partition(arr, low, high):
i = (low - 1) # index of smaller element
pivot = arr[high] # pivot

for j in range(low, high):
# If current element is bigger than or
# equal to pivot
if Expression.priority(arr[j][0]) > Expression.priority(pivot[0]):
# increment index of smaller element
i = i + 1
arr[i], arr[j] = arr[j], arr[i]
elif Expression.priority(arr[j][0]) == Expression.priority(pivot[0]):
if arr[j][1] < pivot[1]:
try:
i = (low - 1) # index of smaller element
pivot = arr[high] # pivot

for j in range(low, high):
# If current element is bigger than or
# equal to pivot
if Expression.priority(arr[j][0]) > Expression.priority(pivot[0]):
# increment index of smaller element
i = i + 1
arr[i], arr[j] = arr[j], arr[i]
elif Expression.priority(arr[j][0]) == Expression.priority(pivot[0]):
if arr[j][1] < pivot[1]:
i = i + 1
arr[i], arr[j] = arr[j], arr[i]
except RecursionError as re:
print('Recursion problem!!!!')



Expand Down
61 changes: 36 additions & 25 deletions utils/infix-to-prefix/run_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,57 @@
from Expression import Expression
from Stack import Stack
import csv
from joblib import Parallel, delayed
import multiprocessing


def extract(path):
print(path)
def extract(path, delimiter):
num_cores = multiprocessing.cpu_count() // 2
with open(path) as csv_file:
csv_reader = csv.reader(csv_file, delimiter=',')
csv_reader = csv.reader(csv_file, delimiter=delimiter)
remove = ['int32', 'float32', 'select',
'broadcast', 'ramp', 'fold',
'Overflow', 'can_prove', 'canprove'
'op->type', 'op->type', 'Call', 'this', 'IRMatcher']
exprs = []

for i, row in enumerate(csv_reader):
next_expr = False
for tabou in remove:
if tabou in row[0]:
# print("=====", tabou)
next_expr = True
if next_expr:
# print("Skipped row :", i)
continue
row[0] = row[0].replace("(uint1)", "")
right = Expression(row[0])
expr = ' '.join(right.infixToPrefix())
expr = re.sub(
"\( \- (?P<var>[a-zA-Z_$][a-zA-Z_$0-9]*) \)", r'(* \1 -1)', expr)
print(expr)
exprs.append(expr)
exprs = Parallel(n_jobs=num_cores)(delayed(extract_one)(i, row, remove) for i, row in enumerate(csv_reader))
#for i, row in enumerate(csv_reader):
# exprs.append(extract_one(i, row, remove))
return exprs

def extract_one(i, row, remove):
try:
if len(row[0]) > 1000:
raise Exception("Expression "+ str(i) +" skipped.")
next_expr = False
for tabou in remove:
if tabou in row[0]:
# print("=====", tabou)
next_expr = True
if next_expr:
# print("Skipped row :", i)
return None
row[0] = row[0].replace("(uint1)", "")
right = Expression(row[0])
expr = ' '.join(right.infixToPrefix())
expr = re.sub(
"\( \- (?P<var>[a-zA-Z_$][a-zA-Z_$0-9]*) \)", r'(* \1 -1)', expr)
print("Expression "+ str(i) +" processed.")
return expr
except:
print("Expression "+ str(i) +" skipped.")

if __name__ == '__main__':
exprs = extract(sys.argv[1])
if len(sys.argv) > 2:
delimiter = sys.argv[2]
else:
delimiter = ','
exprs = extract(sys.argv[1], delimiter)
# exprs = [item for sublist in exprs for item in sublist]
exprs = [i for i in exprs if i]
frmt = []
for i, expr in enumerate(exprs):
frmt.append([i+1, expr])
# for i, rule in enumerate(rules):
# rul = Rule(rule[0])
# rules_trs.append([i+1, rul.toString(), *rul.infix_rule(), rule[1]])
# print(rules_trs)
with open('results/expressions_egg.csv', 'w') as f:
# using csv.writer method from CSV package
write = csv.writer(f)
Expand Down