Skip to content

Commit 33e2f6a

Browse files
authored
[WebNN EP] Support external data (microsoft#22263)
### Description This PR introduces support for registering external data inside WebNN EP. ### Motivation and Context - The WebNN EP needs to register the initializers at graph compilation stage, for initializers from external data, it can't leverage the general external data loader framework because the graph compilation of WebNN EP is executed before external data loader called. - Exposes the `utils::GetExternalDataInfo`, it is useful for WebNN EP to read the external tensor's infomation. - Define a new `registerMLConstant` in JSEP to create WebNN constants from external data in WebNN backend, with the info of tensor as parameters, as well as the `Module.MountedFiles`, which holds all preloaded external files.
1 parent ffaddea commit 33e2f6a

File tree

6 files changed

+178
-104
lines changed

6 files changed

+178
-104
lines changed

js/web/lib/wasm/jsep/backend-webnn.ts

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,69 @@ export class WebNNBackend {
163163
return id;
164164
}
165165

166+
// Register WebNN Constant operands from external data.
167+
public registerMLConstant(
168+
externalFilePath: string,
169+
dataOffset: number,
170+
dataLength: number,
171+
builder: MLGraphBuilder,
172+
desc: MLOperandDescriptor,
173+
mountedFiles: Map<string, Uint8Array> | undefined,
174+
): MLOperand {
175+
// If available, "Module.MountedFiles" is a Map for all preloaded files.
176+
if (!mountedFiles) {
177+
throw new Error('External mounted files are not available.');
178+
}
179+
180+
let filePath = externalFilePath;
181+
if (externalFilePath.startsWith('./')) {
182+
filePath = externalFilePath.substring(2);
183+
}
184+
const fileData = mountedFiles.get(filePath);
185+
if (!fileData) {
186+
throw new Error(`File with name ${filePath} not found in preloaded files.`);
187+
}
188+
189+
if (dataOffset + dataLength > fileData.byteLength) {
190+
throw new Error('Out of bounds: data offset and length exceed the external file data size.');
191+
}
192+
193+
const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer;
194+
let bufferView: ArrayBufferView;
195+
switch (desc.dataType) {
196+
case 'float32':
197+
bufferView = new Float32Array(buffer);
198+
break;
199+
case 'float16':
200+
bufferView = new Uint16Array(buffer);
201+
break;
202+
case 'int32':
203+
bufferView = new Int32Array(buffer);
204+
break;
205+
case 'uint32':
206+
bufferView = new Uint32Array(buffer);
207+
break;
208+
case 'int64':
209+
bufferView = new BigInt64Array(buffer);
210+
break;
211+
case 'uint64':
212+
bufferView = new BigUint64Array(buffer);
213+
break;
214+
case 'int8':
215+
bufferView = new Int8Array(buffer);
216+
break;
217+
case 'uint8':
218+
bufferView = new Uint8Array(buffer);
219+
break;
220+
default:
221+
throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`);
222+
}
223+
224+
LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}}`);
225+
226+
return builder.constant(desc, bufferView);
227+
}
228+
166229
public flush(): void {
167230
// Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
168231
}

onnxruntime/core/framework/tensorprotoutils.cc

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -165,37 +165,6 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t
165165
DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2)
166166
DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2)
167167

168-
static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
169-
const std::filesystem::path& tensor_proto_dir,
170-
std::basic_string<ORTCHAR_T>& external_file_path,
171-
onnxruntime::FileOffsetType& file_offset,
172-
SafeInt<size_t>& tensor_byte_size) {
173-
ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
174-
"Tensor does not have external data to read from.");
175-
176-
ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
177-
"External data type cannot be UNDEFINED or STRING.");
178-
179-
std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
180-
ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
181-
182-
const auto& location = external_data_info->GetRelPath();
183-
184-
external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location)
185-
: (tensor_proto_dir / location);
186-
187-
ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
188-
const size_t external_data_length = external_data_info->GetLength();
189-
ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
190-
"TensorProto: ", tensor_proto.name(),
191-
" external data size mismatch. Computed size: ", *&tensor_byte_size,
192-
", external_data.length: ", external_data_length);
193-
194-
file_offset = external_data_info->GetOffset();
195-
196-
return Status::OK();
197-
}
198-
199168
// Read external data for tensor in unint8_t* form and return Status::OK() if the data is read successfully.
200169
// Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr
201170
// then uses the current directory instead.
@@ -261,6 +230,37 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo
261230

262231
namespace utils {
263232

233+
Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
234+
const std::filesystem::path& tensor_proto_dir,
235+
std::basic_string<ORTCHAR_T>& external_file_path,
236+
onnxruntime::FileOffsetType& file_offset,
237+
SafeInt<size_t>& tensor_byte_size) {
238+
ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
239+
"Tensor does not have external data to read from.");
240+
241+
ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
242+
"External data type cannot be UNDEFINED or STRING.");
243+
244+
std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
245+
ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
246+
247+
const auto& location = external_data_info->GetRelPath();
248+
249+
external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location)
250+
: (tensor_proto_dir / location);
251+
252+
ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
253+
const size_t external_data_length = external_data_info->GetLength();
254+
ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
255+
"TensorProto: ", tensor_proto.name(),
256+
" external data size mismatch. Computed size: ", *&tensor_byte_size,
257+
", external_data.length: ", external_data_length);
258+
259+
file_offset = external_data_info->GetOffset();
260+
261+
return Status::OK();
262+
}
263+
264264
void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) {
265265
tensor_proto.set_raw_data(std::move(param));
266266
}

onnxruntime/core/framework/tensorprotoutils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@
2323

2424
namespace onnxruntime {
2525
namespace utils {
26+
/**
27+
* This function is used to get the external data info from the given tensor proto.
28+
* @param tensor_proto given initializer tensor
29+
* @param tensor_proto_dir directory of the tensor proto file
30+
* @param external_file_path output external file path
31+
* @param file_offset output tensor offset
32+
* @param tensor_byte_size output tensor byte size
33+
* @returns Status::OK() if the function is executed successfully
34+
*/
35+
Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
36+
const std::filesystem::path& tensor_proto_dir,
37+
std::basic_string<ORTCHAR_T>& external_file_path,
38+
onnxruntime::FileOffsetType& file_offset,
39+
SafeInt<size_t>& tensor_byte_size);
2640
/**
2741
* This function is used to convert the endianess of Tensor data.
2842
* Mostly, will be used in big endian system to support the model file

onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,6 @@
1212

1313
namespace onnxruntime {
1414
namespace webnn {
15-
16-
// Shared functions.
17-
bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node,
18-
const logging::Logger& logger) {
19-
for (const auto* node_arg : node.InputDefs()) {
20-
const auto& input_name(node_arg->Name());
21-
if (!Contains(initializers, input_name))
22-
continue;
23-
24-
const auto& tensor = *initializers.at(input_name);
25-
if (tensor.has_data_location() &&
26-
tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
27-
LOGS(logger, VERBOSE) << "Initializer [" << input_name
28-
<< "] with external data location are not currently supported";
29-
return true;
30-
}
31-
}
32-
33-
return false;
34-
}
35-
3615
// Add operator related.
3716

3817
Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
@@ -58,10 +37,6 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
5837
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
5938
return false;
6039

61-
// We do not support external initializers for now.
62-
if (HasExternalInitializer(initializers, node, logger))
63-
return false;
64-
6540
if (!HasSupportedOpSet(node, logger))
6641
return false;
6742

onnxruntime/core/providers/webnn/builders/model_builder.cc

Lines changed: 65 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -112,56 +112,73 @@ Status ModelBuilder::RegisterInitializers() {
112112
auto num_elements = SafeInt<size_t>(Product(shape));
113113
emscripten::val view = emscripten::val::undefined();
114114
std::byte* tensor_ptr = nullptr;
115-
if (tensor.has_raw_data()) {
116-
tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
115+
116+
if (utils::HasExternalData(tensor)) {
117+
// Create WebNN Constant from external data.
118+
std::basic_string<ORTCHAR_T> external_file_path;
119+
onnxruntime::FileOffsetType data_offset;
120+
SafeInt<size_t> tensor_byte_size;
121+
ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo(
122+
tensor, graph_viewer_.ModelPath(), external_file_path, data_offset, tensor_byte_size));
123+
124+
auto jsepRegisterMLConstant = emscripten::val::module_property("jsepRegisterMLConstant");
125+
operand = jsepRegisterMLConstant(emscripten::val(external_file_path),
126+
static_cast<int32_t>(data_offset),
127+
static_cast<int32_t>(tensor_byte_size),
128+
wnn_builder_,
129+
desc);
117130
} else {
118-
// Store temporary unpacked_tensor.
119-
unpacked_tensors_.push_back({});
120-
std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
121-
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
122-
tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
123-
}
124-
switch (data_type) {
125-
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
126-
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
127-
view = emscripten::val{emscripten::typed_memory_view(num_elements,
128-
reinterpret_cast<uint8_t*>(tensor_ptr))};
129-
break;
130-
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
131-
view = emscripten::val{emscripten::typed_memory_view(num_elements,
132-
reinterpret_cast<int8_t*>(tensor_ptr))};
133-
break;
134-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
135-
view = emscripten::val{emscripten::typed_memory_view(num_elements,
136-
reinterpret_cast<uint16_t*>(tensor_ptr))};
137-
break;
138-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
139-
view = emscripten::val{emscripten::typed_memory_view(num_elements,
140-
reinterpret_cast<float*>(tensor_ptr))};
141-
break;
142-
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
143-
view = emscripten::val{emscripten::typed_memory_view(num_elements,
144-
reinterpret_cast<int32_t*>(tensor_ptr))};
145-
break;
146-
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
147-
view = emscripten::val{emscripten::typed_memory_view(num_elements,
148-
reinterpret_cast<int64_t*>(tensor_ptr))};
149-
break;
150-
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
151-
view = emscripten::val{emscripten::typed_memory_view(num_elements,
152-
reinterpret_cast<uint32_t*>(tensor_ptr))};
153-
break;
154-
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
155-
view = emscripten::val{emscripten::typed_memory_view(num_elements,
156-
reinterpret_cast<uint64_t*>(tensor_ptr))};
157-
break;
158-
default:
159-
break;
131+
if (tensor.has_raw_data()) {
132+
tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
133+
} else {
134+
// Store temporary unpacked_tensor.
135+
unpacked_tensors_.push_back({});
136+
std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
137+
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
138+
tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
139+
}
140+
switch (data_type) {
141+
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
142+
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
143+
view = emscripten::val{emscripten::typed_memory_view(num_elements,
144+
reinterpret_cast<uint8_t*>(tensor_ptr))};
145+
break;
146+
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
147+
view = emscripten::val{emscripten::typed_memory_view(num_elements,
148+
reinterpret_cast<int8_t*>(tensor_ptr))};
149+
break;
150+
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
151+
view = emscripten::val{emscripten::typed_memory_view(num_elements,
152+
reinterpret_cast<uint16_t*>(tensor_ptr))};
153+
break;
154+
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
155+
view = emscripten::val{emscripten::typed_memory_view(num_elements,
156+
reinterpret_cast<float*>(tensor_ptr))};
157+
break;
158+
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
159+
view = emscripten::val{emscripten::typed_memory_view(num_elements,
160+
reinterpret_cast<int32_t*>(tensor_ptr))};
161+
break;
162+
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
163+
view = emscripten::val{emscripten::typed_memory_view(num_elements,
164+
reinterpret_cast<int64_t*>(tensor_ptr))};
165+
break;
166+
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
167+
view = emscripten::val{emscripten::typed_memory_view(num_elements,
168+
reinterpret_cast<uint32_t*>(tensor_ptr))};
169+
break;
170+
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
171+
view = emscripten::val{emscripten::typed_memory_view(num_elements,
172+
reinterpret_cast<uint64_t*>(tensor_ptr))};
173+
break;
174+
default:
175+
break;
176+
}
177+
178+
// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
179+
// buffers in JS side. Simply create a copy to fix it.
180+
operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
160181
}
161-
162-
// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
163-
// buffers in JS side. Simply create a copy to fix it.
164-
operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
165182
} else {
166183
// TODO: support other type.
167184
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,

onnxruntime/wasm/pre-jsep.js

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,5 +235,10 @@ Module['jsepInit'] = (name, params) => {
235235
Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => {
236236
return backend['registerMLTensor'](tensor, dataType, shape);
237237
}
238+
239+
Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => {
240+
return backend['registerMLConstant'](
241+
externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles);
242+
}
238243
}
239244
};

0 commit comments

Comments
 (0)