Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/irx/builders/llvmliteir.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,25 @@ def visit(self, node: astx.UnaryOp) -> None:
self.result_stack.append(result)
return

elif node.op_code == "~":
self.visit(node.operand)
operand_val = safe_pop(self.result_stack)
# Bitwise NOT: xor with all bits set (-1)
result = self._llvm.ir_builder.xor(
operand_val, ir.Constant(operand_val.type, -1), "bitnottmp"
)
if isinstance(node.operand, astx.Identifier):
if node.operand.name in self.const_vars:
raise Exception(
f"Cannot mutate '{node.operand.name}':"
"declared as constant"
)
addr = self.named_values.get(node.operand.name)
if addr:
self._llvm.ir_builder.store(result, addr)
self.result_stack.append(result)
return

raise Exception(f"Unary operator {node.op_code} not implemented yet.")

@dispatch # type: ignore[no-redef]
Expand Down Expand Up @@ -1033,6 +1052,27 @@ def visit(self, node: astx.BinaryOp) -> None:
result = self._llvm.ir_builder.or_(llvm_lhs, llvm_rhs, "ortmp")
self.result_stack.append(result)
return

if node.op_code == "&":
result = self._llvm.ir_builder.and_(llvm_lhs, llvm_rhs, "bitandtmp")
self.result_stack.append(result)
return
elif node.op_code == "|":
result = self._llvm.ir_builder.or_(llvm_lhs, llvm_rhs, "bitortmp")
self.result_stack.append(result)
return
elif node.op_code == "^":
result = self._llvm.ir_builder.xor(llvm_lhs, llvm_rhs, "bitxortmp")
self.result_stack.append(result)
return
elif node.op_code == "<<":
result = self._llvm.ir_builder.shl(llvm_lhs, llvm_rhs, "shltmp")
self.result_stack.append(result)
return
elif node.op_code == ">>":
result = self._llvm.ir_builder.ashr(llvm_lhs, llvm_rhs, "shrtmp")
self.result_stack.append(result)
return

if node.op_code == "+":
# note: it should be according the datatype,
Expand Down
105 changes: 105 additions & 0 deletions tests/test_bitwise_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""
title: Tests for Bitwise AND (&) operator.
"""

import astx
import pytest

from irx.builders.base import Builder
from irx.builders.llvmliteir import LLVMLiteIR
from irx.system import PrintExpr
from .conftest import check_result

@pytest.mark.parametrize(
"op,lhs,rhs,expected",
[
("&", 6, 3, lambda a, b: str(a & b)),
("|", 6, 3, lambda a, b: str(a | b)),
("^", 6, 3, lambda a, b: str(a ^ b)),
("<<", 3, 2, lambda a, b: str(a << b)),
(">>", 12, 2, lambda a, b: str(a >> b)),
],
)
@pytest.mark.parametrize(
"int_type,literal_type",
[
(astx.Int32, astx.LiteralInt32),
(astx.Int16, astx.LiteralInt16),
(astx.Int8, astx.LiteralInt8),
(astx.Int64, astx.LiteralInt64),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add float here too

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid I don't think python supports bitwise operators for floats

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AryanBhirud you are right that bitwise operators don't apply to floats but i guess @yuvimittal meant adding unsigned integer types here? e.g. (astx.UInt8, astx.LiteralUInt8),(astx.UInt16, astx.LiteralUInt16), etc. That would also help catch the >> signed vs unsigned behavior, since ashr and lshr produce different results

],
)
@pytest.mark.parametrize("builder_class", [LLVMLiteIR])
def test_bitwise_binary_ops(
builder_class: type[Builder],
int_type: type,
literal_type: type,
op: str,
lhs: int,
rhs: int,
expected,
) -> None:
builder = builder_class()
module = builder.module()
left = literal_type(lhs)
right = literal_type(rhs)
expr = astx.BinaryOp(op, left, right)
decl = astx.VariableDeclaration(
name="result", type_=int_type(), value=expr, mutability=astx.MutabilityKind.mutable
)
main_proto = astx.FunctionPrototype(
name="main", args=astx.Arguments(), return_type=int_type()
)
main_block = astx.Block()
main_block.append(decl)
main_block.append(PrintExpr(astx.Identifier("result")))
main_block.append(astx.FunctionReturn(literal_type(0)))
main_fn = astx.FunctionDef(prototype=main_proto, body=main_block)
module.block.append(main_fn)
check_result("build", builder, module, expected_output=expected(lhs, rhs))

@pytest.mark.parametrize(
"val",
[15, 0, 1, 255],
)
@pytest.mark.parametrize(
"int_type,literal_type,bitwidth",
[
(astx.Int8, astx.LiteralInt8, 8),
(astx.Int16, astx.LiteralInt16, 16),
(astx.Int32, astx.LiteralInt32, 32),
(astx.Int64, astx.LiteralInt64, 64),
],
)
@pytest.mark.parametrize("builder_class", [LLVMLiteIR])
def test_bitwise_not(
builder_class: type[Builder],
int_type: type,
literal_type: type,
bitwidth: int,
val: int,
) -> None:
builder = builder_class()
module = builder.module()
expr = astx.UnaryOp("~", literal_type(val))
decl = astx.VariableDeclaration(
name="result", type_=int_type(), value=expr, mutability=astx.MutabilityKind.mutable
)
main_proto = astx.FunctionPrototype(
name="main", args=astx.Arguments(), return_type=int_type()
)
main_block = astx.Block()
main_block.append(decl)
main_block.append(PrintExpr(astx.Identifier("result")))
main_block.append(astx.FunctionReturn(literal_type(0)))
main_fn = astx.FunctionDef(prototype=main_proto, body=main_block)
module.block.append(main_fn)
# Compute the expected result as a signed integer of the correct width
mask = (1 << bitwidth) - 1
unsigned = (~val) & mask
sign_bit = 1 << (bitwidth - 1)
if unsigned & sign_bit:
expected = str(unsigned - (1 << bitwidth))
else:
expected = str(unsigned)
check_result("build", builder, module, expected_output=expected)
Loading