Skip to content

Commit 06bb6ff

Browse files
author
ssjia
committed
Update on "[ET-VK] Introduce BufferMetadata GLSL struct to abstract tensor layout"
Differential Revision: [D80800082](https://our.internmc.facebook.com/intern/diff/D80800082) [ghstack-poisoned]
2 parents 6d5cc9f + 6173519 commit 06bb6ff

File tree

30 files changed

+572
-252
lines changed

30 files changed

+572
-252
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -295,16 +295,17 @@ int32_t create_hashed_layout(
295295
}
296296

297297
size_t calculate_max_ubo_nbytes(
298-
const size_t nbytes_per_ubo,
298+
const size_t min_nbytes_per_ubo,
299299
const utils::StorageType storage_type) {
300-
// For texture backed tensors, the metadata fields needed are:
301-
// sizes, logical limits
302-
size_t max_metadata_field_count = 2u;
300+
size_t ivec4_ubo_nbytes = utils::align_up(size_t(16), min_nbytes_per_ubo);
301+
size_t uvec3_ubo_nbytes = utils::align_up(size_t(12), min_nbytes_per_ubo);
302+
size_t int32_ubo_nbytes = utils::align_up(size_t(4), min_nbytes_per_ubo);
303303
if (storage_type == utils::kBuffer) {
304304
// sizes, strides, dim order, numel
305-
max_metadata_field_count = 4u;
305+
return 3 * ivec4_ubo_nbytes + int32_ubo_nbytes;
306306
}
307-
return max_metadata_field_count * nbytes_per_ubo;
307+
// sizes, logical limits
308+
return ivec4_ubo_nbytes + uvec3_ubo_nbytes;
308309
}
309310

310311
//
@@ -835,7 +836,7 @@ const vkapi::BufferBindInfo vTensor::numel_ubo() {
835836
}
836837

837838
const vkapi::BufferBindInfo vTensor::buffer_meta_ubo() {
838-
size_t ubo_nbytes = std::max(sizeof(BufferMetadata), min_nbytes_per_ubo_);
839+
size_t ubo_nbytes = sizeof(BufferMetadata);
839840
if (!buffer_meta_.buffer()) {
840841
BufferMetadata data(sizes_, dim_order_, strides_, numel_);
841842
buffer_meta_ = ParamsBuffer(storage_->context_, data);

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ class vTensor final {
548548
if (!uniforms_.buffer()) {
549549
uniforms_ = ParamsBuffer(storage_->context_, max_ubo_nbytes_, true);
550550
}
551-
size_t ubo_nbytes = std::max(sizeof(T), min_nbytes_per_ubo_);
551+
size_t ubo_nbytes = utils::align_up(sizeof(data), min_nbytes_per_ubo_);
552552
if (*param_buffer_offset == kUniformOffsetUnset) {
553553
VK_CHECK_COND(
554554
(uniforms_size_ + ubo_nbytes) <= max_ubo_nbytes_,

examples/models/llava/runner/llava_runner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class ET_EXPERIMENTAL LlavaRunner {
4242
const float temperature = 0.8f)
4343
: temperature_(temperature),
4444
module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
45-
io_manager_(std::make_unique<IOManager>()),
45+
io_manager_(std::make_unique<IOManager>(*module_)),
4646
tokenizer_path_(tokenizer_path) {
4747
ET_LOG(
4848
Info,

examples/models/qwen3/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ python -m extension.llm.export.export_llm \
4545
### Example run
4646
With ExecuTorch pybindings:
4747
```
48-
python -m examples.models.llama.runner.native
48+
python -m examples.models.llama.runner.native \
4949
--model qwen3_0_6b \
5050
--pte qwen3_0_6b.pte \
5151
--tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json \
@@ -59,9 +59,9 @@ python -m examples.models.llama.runner.native
5959

6060
With ExecuTorch's sample c++ runner (see the Llama README's [Step 3: Run on your computer to validate](../llama/README.md#step-3-run-on-your-computer-to-validate) to build the runner):
6161
```
62-
cmake-out/examples/models/llama/llama_main
63-
--model_path qwen3_0_6b.pte
64-
--tokenizer_path ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json
62+
cmake-out/examples/models/llama/llama_main \
63+
--model_path qwen3_0_6b.pte \
64+
--tokenizer_path ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json \
6565
--prompt="Who is the president of the US?"
6666
```
6767

extension/android/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ non_fbcode_target(_kind = fb_android_library,
1010
"executorch_android/src/main/java/org/pytorch/executorch/DType.java",
1111
"executorch_android/src/main/java/org/pytorch/executorch/EValue.java",
1212
"executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java",
13+
"executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java",
1314
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1415
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1516
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",

extension/android/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ executorch_target_link_options_shared_lib(executorch)
7171

7272
add_library(
7373
executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp
74+
jni/jni_helper.cpp
7475
)
7576

7677
set(link_libraries)
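extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java

Lines changed: 125 additions & 0 deletions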
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import java.util.Collections;
12+
import java.util.HashMap;
13+
import java.util.Map;
14+
15+
/**
 * Unchecked exception carrying a numeric ExecuTorch error code together with a formatted,
 * human-readable message.
 *
 * <p>Use {@link #makeExecutorchException(int, String)} to obtain the exception type that best
 * matches a given error code.
 */
public class ExecutorchRuntimeException extends RuntimeException {
  // Error code constants - keep in sync with runtime/core/error.h

  // System errors
  public static final int OK = 0x00;
  public static final int INTERNAL = 0x01;
  public static final int INVALID_STATE = 0x02;
  public static final int END_OF_METHOD = 0x03;

  // Logical errors
  public static final int NOT_SUPPORTED = 0x10;
  public static final int NOT_IMPLEMENTED = 0x11;
  public static final int INVALID_ARGUMENT = 0x12;
  public static final int INVALID_TYPE = 0x13;
  public static final int OPERATOR_MISSING = 0x14;
  public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15;
  public static final int REGISTRATION_ALREADY_REGISTERED = 0x16;

  // Resource errors
  public static final int NOT_FOUND = 0x20;
  public static final int MEMORY_ALLOCATION_FAILED = 0x21;
  public static final int ACCESS_FAILED = 0x22;
  public static final int INVALID_PROGRAM = 0x23;
  public static final int INVALID_EXTERNAL_DATA = 0x24;
  public static final int OUT_OF_RESOURCES = 0x25;

  // Delegate errors
  public static final int DELEGATE_INVALID_COMPATIBILITY = 0x30;
  public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31;
  public static final int DELEGATE_INVALID_HANDLE = 0x32;

  /** Read-only table mapping each known error code to its short description. */
  private static final Map<Integer, String> ERROR_CODE_MESSAGES = buildErrorCodeMessages();

  /** Populates the code-to-description table and wraps it so it cannot be modified. */
  private static Map<Integer, String> buildErrorCodeMessages() {
    Map<Integer, String> messages = new HashMap<>();

    // System errors
    messages.put(OK, "Operation successful");
    messages.put(INTERNAL, "Internal error");
    messages.put(INVALID_STATE, "Invalid state");
    messages.put(END_OF_METHOD, "End of method reached");

    // Logical errors
    messages.put(NOT_SUPPORTED, "Operation not supported");
    messages.put(NOT_IMPLEMENTED, "Operation not implemented");
    messages.put(INVALID_ARGUMENT, "Invalid argument");
    messages.put(INVALID_TYPE, "Invalid type");
    messages.put(OPERATOR_MISSING, "Operator missing");
    messages.put(REGISTRATION_EXCEEDING_MAX_KERNELS, "Exceeded max kernels");
    messages.put(REGISTRATION_ALREADY_REGISTERED, "Kernel already registered");

    // Resource errors
    messages.put(NOT_FOUND, "Resource not found");
    messages.put(MEMORY_ALLOCATION_FAILED, "Memory allocation failed");
    messages.put(ACCESS_FAILED, "Access failed");
    messages.put(INVALID_PROGRAM, "Invalid program");
    messages.put(INVALID_EXTERNAL_DATA, "Invalid external data");
    messages.put(OUT_OF_RESOURCES, "Out of resources");

    // Delegate errors
    messages.put(DELEGATE_INVALID_COMPATIBILITY, "Delegate invalid compatibility");
    messages.put(DELEGATE_MEMORY_ALLOCATION_FAILED, "Delegate memory allocation failed");
    messages.put(DELEGATE_INVALID_HANDLE, "Delegate invalid handle");

    return Collections.unmodifiableMap(messages);
  }

  /** Builds the message text shared by every exception type declared in this file. */
  static class ErrorHelper {
    /**
     * Renders {@code "[Executorch Error 0x<hex>] <description>: <details>"}.
     *
     * @param errorCode numeric error code; codes missing from the table get a generic
     *     "Unknown error code" description
     * @param details caller-supplied context appended after the description
     * @return the formatted message
     */
    static String formatMessage(int errorCode, String details) {
      final String hexCode = Integer.toHexString(errorCode);
      final String known = ERROR_CODE_MESSAGES.get(errorCode);
      final String description = (known != null) ? known : "Unknown error code 0x" + hexCode;
      return new StringBuilder("[Executorch Error 0x")
          .append(hexCode)
          .append("] ")
          .append(description)
          .append(": ")
          .append(details)
          .toString();
    }
  }

  private final int errorCode;

  /**
   * @param errorCode numeric error code to attach to this exception
   * @param details additional context included in the exception message
   */
  public ExecutorchRuntimeException(int errorCode, String details) {
    super(ErrorHelper.formatMessage(errorCode, details));
    this.errorCode = errorCode;
  }

  /** Returns the numeric error code this exception was created with. */
  public int getErrorCode() {
    return errorCode;
  }

  /**
   * Invalid-argument variant that extends {@link IllegalArgumentException}, so it is NOT an
   * instance of {@link ExecutorchRuntimeException}.
   */
  public static class ExecutorchInvalidArgumentException extends IllegalArgumentException {
    private final int errorCode = INVALID_ARGUMENT;

    /** @param details additional context included in the exception message */
    public ExecutorchInvalidArgumentException(String details) {
      super(ErrorHelper.formatMessage(INVALID_ARGUMENT, details));
    }

    /** Always returns {@link ExecutorchRuntimeException#INVALID_ARGUMENT}. */
    public int getErrorCode() {
      return errorCode;
    }
  }

  /**
   * Factory that picks the exception subclass for an error code.
   *
   * @param errorCode numeric error code
   * @param details additional context included in the exception message
   * @return an {@link ExecutorchInvalidArgumentException} for {@link #INVALID_ARGUMENT},
   *     otherwise an {@link ExecutorchRuntimeException} carrying {@code errorCode}
   */
  public static RuntimeException makeExecutorchException(int errorCode, String details) {
    if (errorCode == INVALID_ARGUMENT) {
      return new ExecutorchInvalidArgumentException(details);
    }
    return new ExecutorchRuntimeException(errorCode, details);
  }
}

extension/android/jni/BUCK

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@ load(":build_defs.bzl", "ET_JNI_COMPILER_FLAGS")
77

88
oncall("executorch")
99

10+
# Define the common JNI source files
11+
shared_srcs = [
12+
"jni_layer.cpp",
13+
"jni_layer_runtime.cpp",
14+
"jni_helper.cpp",
15+
"log.cpp",
16+
]
17+
1018
non_fbcode_target(_kind = executorch_generated_lib,
1119
name = "generated_op_lib_optimized",
1220
custom_ops_aten_kernel_deps = [
@@ -28,7 +36,7 @@ non_fbcode_target(_kind = executorch_generated_lib,
2836

2937
non_fbcode_target(_kind = fb_android_cxx_library,
3038
name = "executorch_jni",
31-
srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"],
39+
srcs = shared_srcs,
3240
allow_jni_merging = False,
3341
compiler_flags = ET_JNI_COMPILER_FLAGS,
3442
soname = "libexecutorch.$(ext)",
@@ -49,7 +57,7 @@ non_fbcode_target(_kind = fb_android_cxx_library,
4957

5058
non_fbcode_target(_kind = fb_android_cxx_library,
5159
name = "executorch_jni_full",
52-
srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"],
60+
srcs = shared_srcs,
5361
allow_jni_merging = False,
5462
compiler_flags = ET_JNI_COMPILER_FLAGS,
5563
soname = "libexecutorch.$(ext)",
@@ -71,7 +79,7 @@ non_fbcode_target(_kind = fb_android_cxx_library,
7179

7280
non_fbcode_target(_kind = fb_android_cxx_library,
7381
name = "executorch_training_jni",
74-
srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp", "jni_layer_training.cpp"],
82+
srcs = shared_srcs + ["jni_layer_training.cpp"],
7583
allow_jni_merging = False,
7684
compiler_flags = ET_JNI_COMPILER_FLAGS + [
7785
"-DEXECUTORCH_BUILD_EXTENSION_TRAINING",
@@ -98,11 +106,9 @@ non_fbcode_target(_kind = fb_android_cxx_library,
98106

99107
non_fbcode_target(_kind = fb_android_cxx_library,
100108
name = "executorch_llama_jni",
101-
srcs = [
102-
"jni_layer.cpp",
103-
"jni_layer_llama.cpp",
104-
"jni_layer_runtime.cpp",
105-
],
109+
exclude_files = ["log.cpp"]
110+
shared_srcs_filtered = [f for f in shared_srcs if f not in exclude_files]
111+
srcs = shared_srcs_filtered + ["jni_layer_llama.cpp"]
106112
allow_jni_merging = False,
107113
compiler_flags = ET_JNI_COMPILER_FLAGS + [
108114
"-DEXECUTORCH_BUILD_LLAMA_JNI",
@@ -145,6 +151,10 @@ runtime.export_file(
145151
name = "jni_layer_runtime.cpp",
146152
)
147153

154+
runtime.export_file(
155+
name = "jni_helper.cpp",
156+
)
157+
148158
runtime.cxx_library(
149159
name = "jni_headers",
150160
exported_headers = [

extension/android/jni/jni_helper.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "jni_helper.h"
10+
11+
namespace executorch::jni_helper {
12+
13+
void throwExecutorchException(uint32_t errorCode, const std::string& details) {
14+
// Get the current JNI environment
15+
auto env = facebook::jni::Environment::current();
16+
17+
// Find the Java ExecutorchRuntimeException class
18+
static auto exceptionClass = facebook::jni::findClassLocal(
19+
"org/pytorch/executorch/ExecutorchRuntimeException");
20+
21+
// Find the static factory method: makeExecutorchException(int, String)
22+
static auto makeExceptionMethod = exceptionClass->getStaticMethod<
23+
facebook::jni::local_ref<facebook::jni::JThrowable>(
24+
int, facebook::jni::alias_ref<facebook::jni::JString>)>(
25+
"makeExecutorchException",
26+
"(ILjava/lang/String;)Lorg/pytorch/executorch/ExecutorchRuntimeException;");
27+
28+
auto jDetails = facebook::jni::make_jstring(details);
29+
// Call the factory method to create the exception object
30+
auto exception = makeExceptionMethod(exceptionClass, errorCode, jDetails);
31+
facebook::jni::throwNewJavaException(exception.get());
32+
}
33+
34+
} // namespace executorch::jni_helper

extension/android/jni/jni_helper.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <fbjni/fbjni.h>
12+
#include <string>
13+
14+
namespace executorch::jni_helper {

/**
 * Throws the Java exception that corresponds to the given error code.
 *
 * The exception object is created by the Java factory method
 * ExecutorchRuntimeException.makeExecutorchException(int, String), so its
 * concrete type depends on the code: INVALID_ARGUMENT yields an
 * ExecutorchInvalidArgumentException (a java.lang.IllegalArgumentException,
 * which is not an ExecutorchRuntimeException); every other code yields an
 * ExecutorchRuntimeException carrying that code.
 *
 * @param errorCode The error code from the C++ Executorch runtime.
 * @param details Additional details to include in the exception message.
 */
void throwExecutorchException(uint32_t errorCode, const std::string& details);

} // namespace executorch::jni_helper

0 commit comments

Comments
 (0)