Skip to content

Commit cf1160b

Browse files
authored
Update array indexing to use strides and offsets (#1765)
1 parent 42e0f59 commit cf1160b

File tree

6 files changed

+97
-43
lines changed

6 files changed

+97
-43
lines changed

src/libasr/asdl_cpp.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2458,10 +2458,15 @@ def make_visitor(self, name, fields):
24582458
LCOMPILERS_ASSERT(e->m_external);
24592459
LCOMPILERS_ASSERT(!ASR::is_a<ASR::ExternalSymbol_t>(*e->m_external));
24602460
s = e->m_external;
2461-
} else if (s->type == ASR::symbolType::Function) {
2461+
}
2462+
if (s->type == ASR::symbolType::Function) {
24622463
return ASR::down_cast<ASR::Function_t>(s)->m_function_signature;
2464+
} else if( s->type == ASR::symbolType::Variable ) {
2465+
return ASR::down_cast<ASR::Variable_t>(s)->m_type;
2466+
} else {
2467+
LCOMPILERS_ASSERT_MSG(false, std::to_string(s->type));
24632468
}
2464-
return ASR::down_cast<ASR::Variable_t>(s)->m_type;
2469+
return nullptr;
24652470
}""" \
24662471
% (name, name), 2, new_line=False)
24672472
elif name == "OverloadedBinOp":

src/libasr/asr_verify.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,12 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
681681
void handle_ArrayItemSection(const T &x) {
682682
visit_expr(*x.m_v);
683683
for (size_t i=0; i<x.n_args; i++) {
684+
if( x.m_args[i].m_step != nullptr ) {
685+
require_with_loc(x.m_args[i].m_left != nullptr &&
686+
x.m_args[i].m_right != nullptr,
687+
"Sliced dimension should always have lower and "
688+
"upper bounds present.", x.base.base.loc);
689+
}
684690
visit_array_index(x.m_args[i]);
685691
}
686692
require(x.m_type != nullptr,

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4357,6 +4357,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
43574357
ASR::ttype_t* fptr_type = ASRUtils::expr_type(fptr);
43584358
llvm::Type* llvm_fptr_type = get_type_from_ttype_t_util(ASRUtils::get_contained_type(fptr_type));
43594359
llvm::Value* fptr_array = builder->CreateAlloca(llvm_fptr_type);
4360+
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)),
4361+
arr_descr->get_offset(fptr_array, false));
43604362
ASR::dimension_t* fptr_dims;
43614363
int fptr_rank = ASRUtils::extract_dimensions_from_ttype(
43624364
ASRUtils::expr_type(fptr),
@@ -4377,16 +4379,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
43774379
}
43784380
llvm_cptr = builder->CreateBitCast(llvm_cptr, llvm_fptr_data_type->getPointerTo());
43794381
builder->CreateStore(llvm_cptr, fptr_data);
4382+
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
43804383
for( int i = 0; i < fptr_rank; i++ ) {
43814384
llvm::Value* curr_dim = llvm::ConstantInt::get(context, llvm::APInt(32, i));
43824385
llvm::Value* desi = arr_descr->get_pointer_to_dimension_descriptor(fptr_des, curr_dim);
4386+
llvm::Value* desi_stride = arr_descr->get_stride(desi, false);
43834387
llvm::Value* desi_lb = arr_descr->get_lower_bound(desi, false);
43844388
llvm::Value* desi_size = arr_descr->get_dimension_size(fptr_des, curr_dim, false);
4389+
builder->CreateStore(prod, desi_stride);
43854390
llvm::Value* i32_one = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
43864391
llvm::Value* new_lb = i32_one;
43874392
llvm::Value* new_ub = shape_data ? CreateLoad(llvm_utils->create_ptr_gep(shape_data, i)) : i32_one;
43884393
builder->CreateStore(new_lb, desi_lb);
4389-
builder->CreateStore(builder->CreateAdd(builder->CreateSub(new_ub, new_lb), i32_one), desi_size);
4394+
llvm::Value* new_size = builder->CreateAdd(builder->CreateSub(new_ub, new_lb), i32_one);
4395+
builder->CreateStore(new_size, desi_size);
4396+
prod = builder->CreateMul(prod, new_size);
43904397
}
43914398
} else {
43924399
int64_t ptr_loads_copy = ptr_loads;

src/libasr/codegen/llvm_array_utils.cpp

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -287,24 +287,20 @@ namespace LCompilers {
287287
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, n_dims)), get_rank(arr, true));
288288
builder->CreateStore(dim_des_first, dim_des_val);
289289
dim_des_val = LLVM::CreateLoad(*builder, dim_des_val);
290+
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
290291
for( int r = 0; r < n_dims; r++ ) {
291292
llvm::Value* dim_val = llvm_utils->create_ptr_gep(dim_des_val, r);
292293
llvm::Value* s_val = llvm_utils->create_gep(dim_val, 0);
293294
llvm::Value* l_val = llvm_utils->create_gep(dim_val, 1);
294295
llvm::Value* dim_size_ptr = llvm_utils->create_gep(dim_val, 2);
295-
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 1)), s_val);
296+
builder->CreateStore(prod, s_val);
296297
builder->CreateStore(llvm_dims[r].first, l_val);
297298
llvm::Value* dim_size = llvm_dims[r].second;
299+
prod = builder->CreateMul(prod, dim_size);
298300
builder->CreateStore(dim_size, dim_size_ptr);
299301
}
300302

301303
llvm::Value* llvm_size = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
302-
llvm::Value* const_1 = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
303-
llvm::Value* prod = const_1;
304-
for( int r = 0; r < n_dims; r++ ) {
305-
llvm::Value* dim_size = llvm_dims[r].second;
306-
prod = builder->CreateMul(prod, dim_size);
307-
}
308304
builder->CreateStore(prod, llvm_size);
309305
llvm::Value* first_ptr = get_pointer_to_data(arr);
310306
llvm::Value* arr_first = builder->CreateAlloca(llvm_data_type,
@@ -340,12 +336,12 @@ namespace LCompilers {
340336
llvm::Value* arr, llvm::Type* llvm_data_type, int n_dims,
341337
std::vector<std::pair<llvm::Value*, llvm::Value*>>& llvm_dims,
342338
llvm::Module* module) {
343-
llvm::Value* num_elements = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
344339
llvm::Value* offset_val = llvm_utils->create_gep(arr, 1);
345340
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 0)),
346341
offset_val);
347342
set_is_allocated_flag(arr, 1);
348343
llvm::Value* dim_des_val = LLVM::CreateLoad(*builder, llvm_utils->create_gep(arr, 2));
344+
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
349345
for( int r = 0; r < n_dims; r++ ) {
350346
llvm::Value* dim_val = llvm_utils->create_ptr_gep(dim_des_val, r);
351347
llvm::Value* s_val = llvm_utils->create_gep(dim_val, 0);
@@ -354,17 +350,17 @@ namespace LCompilers {
354350
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 1)), s_val);
355351
builder->CreateStore(llvm_dims[r].first, l_val);
356352
llvm::Value* dim_size = llvm_dims[r].second;
357-
num_elements = builder->CreateMul(num_elements, dim_size);
358353
builder->CreateStore(dim_size, dim_size_ptr);
354+
prod = builder->CreateMul(prod, dim_size);
359355
}
360356
llvm::Value* ptr2firstptr = get_pointer_to_data(arr);
361357
llvm::AllocaInst *arg_size = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
362358
llvm::DataLayout data_layout(module);
363359
llvm::Type* ptr_type = llvm_data_type->getPointerTo();
364360
uint64_t size = data_layout.getTypeAllocSize(llvm_data_type);
365361
llvm::Value* llvm_size = llvm::ConstantInt::get(context, llvm::APInt(32, size));
366-
num_elements = builder->CreateMul(num_elements, llvm_size);
367-
builder->CreateStore(num_elements, arg_size);
362+
prod = builder->CreateMul(prod, llvm_size);
363+
builder->CreateStore(prod, arg_size);
368364
llvm::Value* ptr_as_char_ptr = lfortran_malloc(context, *module, *builder, LLVM::CreateLoad(*builder, arg_size));
369365
llvm::Value* first_ptr = builder->CreateBitCast(ptr_as_char_ptr, ptr_type);
370366
builder->CreateStore(first_ptr, ptr2firstptr);
@@ -390,8 +386,12 @@ namespace LCompilers {
390386
return llvm_utils->create_gep(arr, 0);
391387
}
392388

393-
llvm::Value* SimpleCMODescriptor::get_offset(llvm::Value* arr) {
394-
return LLVM::CreateLoad(*builder, llvm_utils->create_gep(arr, 1));
389+
llvm::Value* SimpleCMODescriptor::get_offset(llvm::Value* arr, bool load) {
390+
llvm::Value* offset = llvm_utils->create_gep(arr, 1);
391+
if( !load ) {
392+
return offset;
393+
}
394+
return LLVM::CreateLoad(*builder, offset);
395395
}
396396

397397
llvm::Value* SimpleCMODescriptor::get_lower_bound(llvm::Value* dim_des, bool load) {
@@ -409,8 +409,12 @@ namespace LCompilers {
409409
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
410410
}
411411

412-
llvm::Value* SimpleCMODescriptor::get_stride(llvm::Value*) {
413-
return nullptr;
412+
llvm::Value* SimpleCMODescriptor::get_stride(llvm::Value* dim_des, bool load) {
413+
llvm::Value* stride = llvm_utils->create_gep(dim_des, 0);
414+
if( !load ) {
415+
return stride;
416+
}
417+
return LLVM::CreateLoad(*builder, stride);
414418
}
415419

416420
// TODO: Uncomment and implement later
@@ -421,7 +425,6 @@ namespace LCompilers {
421425
llvm::Value* arr, std::vector<llvm::Value*>& m_args,
422426
int n_args, bool check_for_bounds) {
423427
llvm::Value* dim_des_arr_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(arr, 2));
424-
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
425428
llvm::Value* idx = llvm::ConstantInt::get(context, llvm::APInt(32, 0));
426429
for( int r = 0; r < n_args; r++ ) {
427430
llvm::Value* curr_llvm_idx = m_args[r];
@@ -431,11 +434,11 @@ namespace LCompilers {
431434
if( check_for_bounds ) {
432435
// check_single_element(curr_llvm_idx, arr); TODO: To be implemented
433436
}
434-
idx = builder->CreateAdd(idx, builder->CreateMul(prod, curr_llvm_idx));
435-
llvm::Value* dim_size = LLVM::CreateLoad(*builder, llvm_utils->create_gep(dim_des_ptr, 2));
436-
prod = builder->CreateMul(prod, dim_size);
437+
llvm::Value* stride = LLVM::CreateLoad(*builder, llvm_utils->create_gep(dim_des_ptr, 0));
438+
idx = builder->CreateAdd(idx, builder->CreateMul(stride, curr_llvm_idx));
437439
}
438-
return idx;
440+
llvm::Value* offset_val = LLVM::CreateLoad(*builder, llvm_utils->create_gep(arr, 1));
441+
return builder->CreateAdd(idx, offset_val);
439442
}
440443

441444
llvm::Value* SimpleCMODescriptor::cmo_convertor_single_element_data_only(
@@ -564,12 +567,15 @@ namespace LCompilers {
564567
num_elements);
565568

566569
if( this->is_array(asr_shape_type) ) {
570+
builder->CreateStore(LLVM::CreateLoad(*builder, llvm_utils->create_gep(array, 1)),
571+
llvm_utils->create_gep(reshaped, 1));
567572
llvm::Value* n_dims = this->get_array_size(shape, nullptr, 4);
568573
llvm::Value* shape_data = LLVM::CreateLoad(*builder, this->get_pointer_to_data(shape));
569574
llvm::Value* dim_des_val = llvm_utils->create_gep(reshaped, 2);
570575
llvm::Value* dim_des_first = builder->CreateAlloca(dim_des, n_dims);
571576
builder->CreateStore(n_dims, this->get_rank(reshaped, true));
572577
builder->CreateStore(dim_des_first, dim_des_val);
578+
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
573579
dim_des_val = LLVM::CreateLoad(*builder, dim_des_val);
574580
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
575581
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
@@ -588,8 +594,9 @@ namespace LCompilers {
588594
llvm::Value* dim_val = llvm_utils->create_ptr_gep(dim_des_val, r_val);
589595
llvm::Value* s_val = llvm_utils->create_gep(dim_val, 0);
590596
llvm::Value* dim_size_ptr = llvm_utils->create_gep(dim_val, 2);
591-
builder->CreateStore(llvm::ConstantInt::get(context, llvm::APInt(32, 1)), s_val);
597+
builder->CreateStore(prod, s_val);
592598
llvm::Value* dim_size = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(shape_data, r_val));
599+
prod = builder->CreateMul(prod, dim_size);
593600
builder->CreateStore(dim_size, dim_size_ptr);
594601
r_val = builder->CreateAdd(r_val, llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
595602
builder->CreateStore(r_val, r);
@@ -623,14 +630,14 @@ namespace LCompilers {
623630
LLVM::CreateLoad(*builder, ptr2firstptr), llvm::MaybeAlign(),
624631
num_elements);
625632

626-
llvm::Value* src_offset_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(src, 1));
627-
builder->CreateStore(src_offset_ptr, llvm_utils->create_gep(dest, 1));
628633
llvm::Value* src_dim_des_val = this->get_pointer_to_dimension_descriptor_array(src, true);
629634
llvm::Value* n_dims = this->get_rank(src, false);
630635
llvm::Value* dest_dim_des_val = nullptr;
631636
if( !create_dim_des_array ) {
632637
dest_dim_des_val = this->get_pointer_to_dimension_descriptor_array(dest, true);
633638
} else {
639+
llvm::Value* src_offset_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(src, 1));
640+
builder->CreateStore(src_offset_ptr, llvm_utils->create_gep(dest, 1));
634641
llvm::Value* dest_dim_des_ptr = this->get_pointer_to_dimension_descriptor_array(dest, false);
635642
dest_dim_des_val = builder->CreateAlloca(dim_des, n_dims);
636643
builder->CreateStore(dest_dim_des_val, dest_dim_des_ptr);
@@ -650,24 +657,28 @@ namespace LCompilers {
650657
llvm_utils->start_new_block(loopbody);
651658
llvm::Value* r_val = LLVM::CreateLoad(*builder, r);
652659
llvm::Value* src_dim_val = llvm_utils->create_ptr_gep(src_dim_des_val, r_val);
653-
llvm::Value* src_s_val = llvm_utils->create_gep(src_dim_val, 0);
654660
llvm::Value* src_l_val = nullptr;
661+
llvm::Value* src_s_val = nullptr;
655662
if( create_dim_des_array ) {
663+
src_s_val = llvm_utils->create_gep(src_dim_val, 0);
656664
src_l_val = llvm_utils->create_gep(src_dim_val, 1);
657665
}
658666
llvm::Value* src_dim_size_ptr = llvm_utils->create_gep(src_dim_val, 2);
659667
llvm::Value* dest_dim_val = llvm_utils->create_ptr_gep(dest_dim_des_val, r_val);
660-
llvm::Value* dest_s_val = llvm_utils->create_gep(dest_dim_val, 0);
668+
llvm::Value* dest_s_val = nullptr;
661669
llvm::Value* dest_l_val = nullptr;
670+
llvm::Value* dest_dim_size_ptr = nullptr;
662671
if( create_dim_des_array ) {
672+
dest_s_val = llvm_utils->create_gep(dest_dim_val, 0);
663673
dest_l_val = llvm_utils->create_gep(dest_dim_val, 1);
674+
dest_dim_size_ptr = llvm_utils->create_gep(dest_dim_val, 2);
664675
}
665-
llvm::Value* dest_dim_size_ptr = llvm_utils->create_gep(dest_dim_val, 2);
666-
builder->CreateStore(LLVM::CreateLoad(*builder, src_s_val), dest_s_val);
676+
667677
if( create_dim_des_array ) {
668678
builder->CreateStore(LLVM::CreateLoad(*builder, src_l_val), dest_l_val);
679+
builder->CreateStore(LLVM::CreateLoad(*builder, src_s_val), dest_s_val);
680+
builder->CreateStore(LLVM::CreateLoad(*builder, src_dim_size_ptr), dest_dim_size_ptr);
669681
}
670-
builder->CreateStore(LLVM::CreateLoad(*builder, src_dim_size_ptr), dest_dim_size_ptr);
671682
r_val = builder->CreateAdd(r_val, llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
672683
builder->CreateStore(r_val, r);
673684
builder->CreateBr(loophead);

src/libasr/codegen/llvm_array_utils.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ namespace LCompilers {
183183
* implemented by current class).
184184
*/
185185
virtual
186-
llvm::Value* get_offset(llvm::Value* dim_des) = 0;
186+
llvm::Value* get_offset(llvm::Value* dim_des, bool load=true) = 0;
187187

188188
/*
189189
* Returns lower bound in the input
@@ -207,7 +207,7 @@ namespace LCompilers {
207207
* implemented by current class.
208208
*/
209209
virtual
210-
llvm::Value* get_stride(llvm::Value* dim_des) = 0;
210+
llvm::Value* get_stride(llvm::Value* dim_des, bool load=true) = 0;
211211

212212
/*
213213
* Returns dimension size in the input
@@ -370,7 +370,7 @@ namespace LCompilers {
370370
void set_rank(llvm::Value* arr, llvm::Value* rank);
371371

372372
virtual
373-
llvm::Value* get_offset(llvm::Value* dim_des);
373+
llvm::Value* get_offset(llvm::Value* dim_des, bool load=true);
374374

375375
virtual
376376
llvm::Value* get_lower_bound(llvm::Value* dim_des, bool load=true);
@@ -390,7 +390,7 @@ namespace LCompilers {
390390
llvm::Value* dim);
391391

392392
virtual
393-
llvm::Value* get_stride(llvm::Value* dim_des);
393+
llvm::Value* get_stride(llvm::Value* dim_des, bool load=true);
394394

395395
virtual
396396
llvm::Value* get_single_element(llvm::Value* array,

0 commit comments

Comments
 (0)