forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path tensor_layout.cpp
More file actions
55 lines (49 loc) · 1.56 KB
/
tensor_layout.cpp
File metadata and controls
55 lines (49 loc) · 1.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/runtime/core/tensor_layout.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>

#include <c10/util/irange.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/span.h>
namespace executorch {
namespace ET_RUNTIME_NAMESPACE {
namespace {

/**
 * Computes the number of bytes needed to store a tensor with the given sizes
 * and element type.
 *
 * @param sizes Size of each dimension. An empty span yields one element
 *     (the product over no dimensions is 1).
 * @param scalar_type Element type; its width comes from
 *     executorch::runtime::elementSize().
 * @returns The total byte count, or Error::InvalidArgument if any size is
 *     negative or if the element count or byte count does not fit in size_t.
 */
Result<size_t> calculate_nbytes(
    const Span<const int32_t>& sizes,
    const executorch::aten::ScalarType& scalar_type) {
  constexpr size_t kMaxSize = std::numeric_limits<size_t>::max();

  // Accumulate in size_t (not a signed type) so that no intermediate step can
  // trigger signed-overflow UB; wrap-around is detected explicitly instead.
  size_t numel = 1;
  bool has_zero_dim = false;
  bool overflowed = false;
  for (const auto i : c10::irange(sizes.size())) {
    if (sizes[i] < 0) {
      return Error::InvalidArgument;
    }
    const auto dim = static_cast<size_t>(sizes[i]);
    if (dim == 0) {
      // Any zero-sized dimension makes the tensor empty, even if the product
      // of the other dimensions would not fit in size_t.
      has_zero_dim = true;
    } else if (numel > kMaxSize / dim) {
      // Keep scanning so that a later negative size is still reported and a
      // later zero dimension can still make the result valid.
      overflowed = true;
    } else {
      numel *= dim;
    }
  }

  if (has_zero_dim) {
    numel = 0;
  } else if (overflowed) {
    return Error::InvalidArgument;
  }

  // Use the full namespace to disambiguate from c10::elementSize.
  const size_t element_size = executorch::runtime::elementSize(scalar_type);
  if (element_size != 0 && numel > kMaxSize / element_size) {
    return Error::InvalidArgument;
  }
  return numel * element_size;
}

} // namespace
/**
 * Validates a tensor layout description and builds a TensorLayout from it.
 *
 * @param sizes Size of each dimension; every entry must be non-negative.
 * @param dim_order Must have the same length as `sizes` and be a permutation
 *     of [0, sizes.size()): every entry in range and no entry repeated.
 * @param scalar_type Element type, used (with `sizes`) to compute the byte
 *     size via calculate_nbytes().
 * @returns The constructed layout; the error returned by calculate_nbytes()
 *     if the byte size cannot be computed; or Error::InvalidArgument if
 *     `dim_order` has the wrong length, an out-of-range entry, or a
 *     repeated entry.
 */
Result<const TensorLayout> TensorLayout::create(
    Span<const int32_t> sizes,
    Span<const uint8_t> dim_order,
    executorch::aten::ScalarType scalar_type) {
  auto nbytes = calculate_nbytes(sizes, scalar_type);
  if (!nbytes.ok()) {
    return nbytes.error();
  }
  if (dim_order.size() != sizes.size()) {
    return Error::InvalidArgument;
  }
  // dim_order entries are uint8_t, so a table covering every representable
  // value detects repeats without any heap allocation. If the rank exceeds
  // 256, some value must repeat, so such inputs are rejected here as well.
  std::array<bool, std::numeric_limits<uint8_t>::max() + 1> seen{};
  for (const auto i : c10::irange(dim_order.size())) {
    const uint8_t dim = dim_order[i];
    if (dim >= sizes.size() || seen[dim]) {
      return Error::InvalidArgument;
    }
    seen[dim] = true;
  }
  return TensorLayout(sizes, dim_order, scalar_type, nbytes.get());
}
} // namespace ET_RUNTIME_NAMESPACE
} // namespace executorch