Skip to content

Commit 0b9f575

Browse files
authored
Merge pull request #2488 from anutosh491/symbolic_compare
Fixing symbolic compare for test_gruntz.py
2 parents 6285062 + 59599c4 commit 0b9f575

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

integration_tests/test_gruntz.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,28 @@
22
from sympy import Symbol
33

44
def mmrv(e: S, x: S) -> list[S]:
5-
l: list[S] = []
65
if not e.has(x):
7-
return l
6+
list0: list[S] = []
7+
return list0
8+
elif e == x:
9+
list1: list[S] = [x]
10+
return list1
811
else:
912
raise
1013

11-
def test_mrv1():
14+
def test_mrv():
15+
# Case 1
1216
x: S = Symbol("x")
1317
y: S = Symbol("y")
14-
ans: list[S] = mmrv(y, x)
15-
assert len(ans) == 0
18+
ans1: list[S] = mmrv(y, x)
19+
print(ans1)
20+
assert len(ans1) == 0
1621

17-
test_mrv1()
22+
# Case 2
23+
ans2: list[S] = mmrv(x, x)
24+
ele1: S = ans2[0]
25+
print(ele1)
26+
assert ele1 == x
27+
assert len(ans2) == 1
28+
29+
test_mrv()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
690690
xx.m_test = new_logical_not;
691691
}
692692
}
693+
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*xx.m_test)) {
694+
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(xx.m_test);
695+
ASR::expr_t* function_call = nullptr;
696+
if (s->m_op == ASR::cmpopType::Eq) {
697+
function_call = basic_compare(xx.base.base.loc, "basic_eq", s->m_left, s->m_right);
698+
} else {
699+
function_call = basic_compare(xx.base.base.loc, "basic_neq", s->m_left, s->m_right);
700+
}
701+
xx.m_test = function_call;
693702
}
694703
}
695704

0 commit comments

Comments
 (0)