Skip to content

Commit eaf402c

Browse files
authored
Merge pull request #2633 from kmr-srbh/set-discard-method
Implement `set.discard(elem)`
2 parents 793997d + 3a9424e commit eaf402c

File tree

9 files changed

+172
-40
lines changed

9 files changed

+172
-40
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,7 @@ RUN(NAME test_dict_nested1 LABELS cpython llvm)
583583
RUN(NAME test_set_len LABELS cpython llvm)
584584
RUN(NAME test_set_add LABELS cpython llvm)
585585
RUN(NAME test_set_remove LABELS cpython llvm)
586+
RUN(NAME test_set_discard LABELS cpython llvm)
586587
RUN(NAME test_global_set LABELS cpython llvm)
587588
RUN(NAME test_for_loop LABELS cpython llvm c)
588589
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from lpython import i32
2+
3+
def test_set_discard():
4+
s1: set[i32]
5+
s2: set[tuple[i32, tuple[i32, i32], str]]
6+
s3: set[str]
7+
st1: str
8+
i: i32
9+
j: i32
10+
k: i32
11+
12+
for k in range(2):
13+
s1 = {0}
14+
s2 = {(0, (1, 2), "a")}
15+
for i in range(20):
16+
j = i % 10
17+
s1.add(j)
18+
s2.add((j, (j + 1, j + 2), "a"))
19+
20+
for i in range(10):
21+
s1.discard(i)
22+
s2.discard((i, (i + 1, i + 2), "a"))
23+
assert len(s1) == 10 - 1 - i
24+
assert len(s1) == len(s2)
25+
26+
st1 = "a"
27+
s3 = {st1}
28+
for i in range(20):
29+
s3.add(st1)
30+
if i < 10:
31+
if i > 0:
32+
st1 += "a"
33+
34+
st1 = "a"
35+
for i in range(10):
36+
s3.discard(st1)
37+
assert len(s3) == 10 - 1 - i
38+
if i < 10:
39+
st1 += "a"
40+
41+
for i in range(20):
42+
s1.add(i)
43+
if i % 2 == 0:
44+
s1.discard(i)
45+
assert len(s1) == (i + 1) // 2
46+
47+
48+
test_set_discard()

src/libasr/ASR.asdl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ stmt
6969
| BlockCall(int label, symbol m)
7070
| SetInsert(expr a, expr ele)
7171
| SetRemove(expr a, expr ele)
72+
| SetDiscard(expr a, expr ele)
7273
| ListInsert(expr a, expr pos, expr ele)
7374
| ListRemove(expr a, expr ele)
7475
| ListClear(expr a)

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,7 +1905,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
19051905
llvm_utils->set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx);
19061906
}
19071907

1908-
void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
1908+
void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele, bool throw_key_error) {
19091909
ASR::Set_t* set_type = ASR::down_cast<ASR::Set_t>(
19101910
ASRUtils::expr_type(m_arg));
19111911
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
@@ -1919,7 +1919,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
19191919
ptr_loads = ptr_loads_copy;
19201920
llvm::Value *el = tmp;
19211921
llvm_utils->set_set_api(set_type);
1922-
llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type);
1922+
llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type, throw_key_error);
19231923
}
19241924

19251925
void visit_IntrinsicElementalFunction(const ASR::IntrinsicElementalFunction_t& x) {
@@ -1986,7 +1986,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
19861986
break;
19871987
}
19881988
case ASRUtils::IntrinsicElementalFunctions::SetRemove: {
1989-
generate_SetRemove(x.m_args[0], x.m_args[1]);
1989+
generate_SetRemove(x.m_args[0], x.m_args[1], true);
1990+
break;
1991+
}
1992+
case ASRUtils::IntrinsicElementalFunctions::SetDiscard: {
1993+
generate_SetRemove(x.m_args[0], x.m_args[1], false);
19901994
break;
19911995
}
19921996
case ASRUtils::IntrinsicElementalFunctions::Exp: {

src/libasr/codegen/llvm_utils.cpp

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6415,7 +6415,7 @@ namespace LCompilers {
64156415

64166416
void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check(
64176417
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
6418-
llvm::Module& module, ASR::ttype_t* el_asr_type) {
6418+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
64196419

64206420
/**
64216421
* C++ equivalent:
@@ -6467,14 +6467,16 @@ namespace LCompilers {
64676467
llvm_utils->create_if_else(is_el_matching, [=]() {
64686468
LLVM::CreateStore(*builder, el_hash, pos_ptr);
64696469
}, [&]() {
6470-
std::string message = "The set does not contain the specified element";
6471-
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6472-
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6473-
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6474-
int exit_code_int = 1;
6475-
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6476-
llvm::APInt(32, exit_code_int));
6477-
exit(context, module, *builder, exit_code);
6470+
if (throw_key_error) {
6471+
std::string message = "The set does not contain the specified element";
6472+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6473+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6474+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6475+
int exit_code_int = 1;
6476+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6477+
llvm::APInt(32, exit_code_int));
6478+
exit(context, module, *builder, exit_code);
6479+
}
64786480
});
64796481
}
64806482
builder->CreateBr(mergeBB);
@@ -6491,20 +6493,22 @@ namespace LCompilers {
64916493
LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type);
64926494

64936495
llvm_utils->create_if_else(is_el_matching, []() {}, [&]() {
6494-
std::string message = "The set does not contain the specified element";
6495-
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6496-
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6497-
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6498-
int exit_code_int = 1;
6499-
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6500-
llvm::APInt(32, exit_code_int));
6501-
exit(context, module, *builder, exit_code);
6496+
if (throw_key_error) {
6497+
std::string message = "The set does not contain the specified element";
6498+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6499+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6500+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6501+
int exit_code_int = 1;
6502+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6503+
llvm::APInt(32, exit_code_int));
6504+
exit(context, module, *builder, exit_code);
6505+
}
65026506
});
65036507
}
65046508

65056509
void LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check(
65066510
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
6507-
llvm::Module& module, ASR::ttype_t* el_asr_type) {
6511+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
65086512
/**
65096513
* C++ equivalent:
65106514
*
@@ -6532,20 +6536,22 @@ namespace LCompilers {
65326536
);
65336537

65346538
llvm_utils->create_if_else(does_el_exist, []() {}, [&]() {
6535-
std::string message = "The set does not contain the specified element";
6536-
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6537-
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6538-
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6539-
int exit_code_int = 1;
6540-
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6541-
llvm::APInt(32, exit_code_int));
6542-
exit(context, module, *builder, exit_code);
6539+
if (throw_key_error) {
6540+
std::string message = "The set does not contain the specified element";
6541+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6542+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6543+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6544+
int exit_code_int = 1;
6545+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6546+
llvm::APInt(32, exit_code_int));
6547+
exit(context, module, *builder, exit_code);
6548+
}
65436549
});
65446550
}
65456551

65466552
void LLVMSetLinearProbing::remove_item(
65476553
llvm::Value* set, llvm::Value* el,
6548-
llvm::Module& module, ASR::ttype_t* el_asr_type) {
6554+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
65496555
/**
65506556
* C++ equivalent:
65516557
*
@@ -6555,7 +6561,7 @@ namespace LCompilers {
65556561
*/
65566562
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
65576563
llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, module);
6558-
this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type);
6564+
this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type, throw_key_error);
65596565
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
65606566
llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
65616567
llvm::Value* el_mask_i = llvm_utils->create_ptr_gep(el_mask, pos);
@@ -6571,7 +6577,7 @@ namespace LCompilers {
65716577

65726578
void LLVMSetSeparateChaining::remove_item(
65736579
llvm::Value* set, llvm::Value* el,
6574-
llvm::Module& module, ASR::ttype_t* el_asr_type) {
6580+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
65756581
/**
65766582
* C++ equivalent:
65776583
*
@@ -6593,7 +6599,7 @@ namespace LCompilers {
65936599

65946600
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
65956601
llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, module);
6596-
this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type);
6602+
this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type, throw_key_error);
65976603
llvm::Value* prev = LLVM::CreateLoad(*builder, chain_itr_prev);
65986604
llvm::Value* found = LLVM::CreateLoad(*builder, chain_itr);
65996605

src/libasr/codegen/llvm_utils.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -967,12 +967,12 @@ namespace LCompilers {
967967
virtual
968968
void resolve_collision_for_read_with_bound_check(
969969
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
970-
llvm::Module& module, ASR::ttype_t* el_asr_type) = 0;
970+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0;
971971

972972
virtual
973973
void remove_item(
974974
llvm::Value* set, llvm::Value* el,
975-
llvm::Module& module, ASR::ttype_t* el_asr_type) = 0;
975+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0;
976976

977977
virtual
978978
void set_deepcopy(
@@ -1038,11 +1038,11 @@ namespace LCompilers {
10381038

10391039
void resolve_collision_for_read_with_bound_check(
10401040
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
1041-
llvm::Module& module, ASR::ttype_t* el_asr_type);
1041+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);
10421042

10431043
void remove_item(
10441044
llvm::Value* set, llvm::Value* el,
1045-
llvm::Module& module, ASR::ttype_t* el_asr_type);
1045+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);
10461046

10471047
void set_deepcopy(
10481048
llvm::Value* src, llvm::Value* dest,
@@ -1119,11 +1119,11 @@ namespace LCompilers {
11191119

11201120
void resolve_collision_for_read_with_bound_check(
11211121
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
1122-
llvm::Module& module, ASR::ttype_t* el_asr_type);
1122+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);
11231123

11241124
void remove_item(
11251125
llvm::Value* set, llvm::Value* el,
1126-
llvm::Module& module, ASR::ttype_t* el_asr_type);
1126+
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);
11271127

11281128
void set_deepcopy(
11291129
llvm::Value* src, llvm::Value* dest,

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ inline std::string get_intrinsic_name(int x) {
110110
INTRINSIC_NAME_CASE(DictValues)
111111
INTRINSIC_NAME_CASE(SetAdd)
112112
INTRINSIC_NAME_CASE(SetRemove)
113+
INTRINSIC_NAME_CASE(SetDiscard)
113114
INTRINSIC_NAME_CASE(Max)
114115
INTRINSIC_NAME_CASE(Min)
115116
INTRINSIC_NAME_CASE(Sign)
@@ -343,6 +344,8 @@ namespace IntrinsicElementalFunctionRegistry {
343344
{nullptr, &SetAdd::verify_args}},
344345
{static_cast<int64_t>(IntrinsicElementalFunctions::SetRemove),
345346
{nullptr, &SetRemove::verify_args}},
347+
{static_cast<int64_t>(IntrinsicElementalFunctions::SetDiscard),
348+
{nullptr, &SetDiscard::verify_args}},
346349
{static_cast<int64_t>(IntrinsicElementalFunctions::Max),
347350
{&Max::instantiate_Max, &Max::verify_args}},
348351
{static_cast<int64_t>(IntrinsicElementalFunctions::Min),
@@ -630,6 +633,8 @@ namespace IntrinsicElementalFunctionRegistry {
630633
"set.add"},
631634
{static_cast<int64_t>(IntrinsicElementalFunctions::SetRemove),
632635
"set.remove"},
636+
{static_cast<int64_t>(IntrinsicElementalFunctions::SetDiscard),
637+
"set.discard"},
633638
{static_cast<int64_t>(IntrinsicElementalFunctions::Max),
634639
"max"},
635640
{static_cast<int64_t>(IntrinsicElementalFunctions::Min),
@@ -823,6 +828,7 @@ namespace IntrinsicElementalFunctionRegistry {
823828
{"dict.values", {&DictValues::create_DictValues, &DictValues::eval_dict_values}},
824829
{"set.add", {&SetAdd::create_SetAdd, &SetAdd::eval_set_add}},
825830
{"set.remove", {&SetRemove::create_SetRemove, &SetRemove::eval_set_remove}},
831+
{"set.discard", {&SetDiscard::create_SetDiscard, &SetDiscard::eval_set_discard}},
826832
{"max0", {&Max::create_Max, &Max::eval_Max}},
827833
{"adjustl", {&Adjustl::create_Adjustl, &Adjustl::eval_Adjustl}},
828834
{"adjustr", {&Adjustr::create_Adjustr, &Adjustr::eval_Adjustr}},

src/libasr/pass/intrinsic_functions.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ enum class IntrinsicElementalFunctions : int64_t {
108108
DictValues,
109109
SetAdd,
110110
SetRemove,
111+
SetDiscard,
111112
Max,
112113
Min,
113114
Radix,
@@ -4916,6 +4917,57 @@ static inline ASR::asr_t* create_SetRemove(Allocator& al, const Location& loc,
49164917

49174918
} // namespace SetRemove
49184919

4920+
namespace SetDiscard {
4921+
4922+
static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, diag::Diagnostics& diagnostics) {
4923+
ASRUtils::require_impl(x.n_args == 2, "Call to set.discard must have exactly one argument",
4924+
x.base.base.loc, diagnostics);
4925+
ASRUtils::require_impl(ASR::is_a<ASR::Set_t>(*ASRUtils::expr_type(x.m_args[0])),
4926+
"First argument to set.discard must be of set type",
4927+
x.base.base.loc, diagnostics);
4928+
ASRUtils::require_impl(ASRUtils::check_equal_type(ASRUtils::expr_type(x.m_args[1]),
4929+
ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]))),
4930+
"Second argument to set.discard must be of same type as set's element type",
4931+
x.base.base.loc, diagnostics);
4932+
ASRUtils::require_impl(x.m_type == nullptr,
4933+
"Return type of set.discard must be empty",
4934+
x.base.base.loc, diagnostics);
4935+
}
4936+
4937+
static inline ASR::expr_t *eval_set_discard(Allocator &/*al*/,
4938+
const Location &/*loc*/, ASR::ttype_t *, Vec<ASR::expr_t*>& /*args*/, diag::Diagnostics& /*diag*/) {
4939+
// TODO: To be implemented for SetConstant expression
4940+
return nullptr;
4941+
}
4942+
4943+
static inline ASR::asr_t* create_SetDiscard(Allocator& al, const Location& loc,
4944+
Vec<ASR::expr_t*>& args,
4945+
diag::Diagnostics& diag) {
4946+
if (args.size() != 2) {
4947+
append_error(diag, "Call to set.discard must have exactly one argument", loc);
4948+
return nullptr;
4949+
}
4950+
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(args[1]),
4951+
ASRUtils::get_contained_type(ASRUtils::expr_type(args[0])))) {
4952+
append_error(diag, "Argument to set.discard must be of same type as set's "
4953+
"element type", loc);
4954+
return nullptr;
4955+
}
4956+
4957+
Vec<ASR::expr_t*> arg_values;
4958+
arg_values.reserve(al, args.size());
4959+
for( size_t i = 0; i < args.size(); i++ ) {
4960+
arg_values.push_back(al, ASRUtils::expr_value(args[i]));
4961+
}
4962+
ASR::expr_t* compile_time_value = eval_set_discard(al, loc, nullptr, arg_values, diag);
4963+
return ASR::make_Expr_t(al, loc,
4964+
ASRUtils::EXPR(ASR::make_IntrinsicElementalFunction_t(al, loc,
4965+
static_cast<int64_t>(IntrinsicElementalFunctions::SetDiscard),
4966+
args.p, args.size(), 0, nullptr, compile_time_value)));
4967+
}
4968+
4969+
} // namespace SetRemove
4970+
49194971
namespace Max {
49204972

49214973
static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, diag::Diagnostics& diagnostics) {

0 commit comments

Comments
 (0)