Skip to content

Commit f2adbff

Browse files
Stephen Jiafacebook-github-bot
authored andcommitted
[Metal] Do not use read/write textures in concat shaders (pytorch#61074)
Summary: Pull Request resolved: pytorch#61074 `read_write` textures are not available on some devices, such as iPhone 7. This prevents the concat op from functioning on those devices. This diff rewrites the concat shaders such that they do not depend on `read_write` textures. Test Plan: Test on device: run squeezenet and/or the operator tests ``` arc focus2 pp-ios ``` Test on Mac ``` buck test pp-macos ``` Test specifically on iPhone7, either device or simulator. Reviewed By: xta0 Differential Revision: D29501656 fbshipit-source-id: de4a059953ab4b0abf38b6ecb3f665323dcdbea1
1 parent 80bdfd6 commit f2adbff

File tree

2 files changed

+157
-214
lines changed

2 files changed

+157
-214
lines changed

aten/src/ATen/native/metal/MetalShaders.h

Lines changed: 67 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -271,203 +271,97 @@ kernel void copy_offset_nonarray(texture2d<half, access::read> in[[texture(0)]],
271271
out.write(in.read(gid_), gid_, gid.z + offset_buf[0]);
272272
}
273273
274-
kernel void append_features_off0(texture2d_array<half, access::read> in[[texture(0)]],
275-
texture2d_array<half, access::read_write> out[[texture(1)]],
276-
constant ushort* offset_buf[[buffer(0)]],
277-
ushort3 gid[[thread_position_in_grid]]) {
278-
if (gid.x >= out.get_width() || gid.y >= out.get_height() || gid.z >= offset_buf[4]) {
279-
return;
280-
}
281-
ushort2 gid_ = gid.xy;
282-
283-
ushort batch = gid.z / offset_buf[0];
284-
ushort feature = gid.z % offset_buf[0];
285-
ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
286-
ushort inz = batch * offset_buf[3] + feature;
287-
288-
half4 intex1 = in.read(gid_, inz);
289-
half4 outtex = intex1;
290-
291-
out.write(outtex, gid_, outz);
292-
}
293-
294-
kernel void append_features_off0_nonarray(texture2d<half, access::read> in[[texture(0)]],
295-
texture2d_array<half, access::read_write> out[[texture(1)]],
296-
constant ushort* offset_buf[[buffer(0)]],
297-
ushort3 gid[[thread_position_in_grid]]) {
298-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
299-
return;
300-
}
274+
constant bool store_features_out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
275+
constant bool store_features_out_is_tex = !store_features_out_is_arr;
276+
kernel void store_features(texture2d_array<half, access::read> in[[texture(0)]],
277+
texture2d<half, access::write> out_tex[[texture(1), function_constant(store_features_out_is_tex)]],
278+
texture2d_array<half, access::write> out_arr[[texture(1), function_constant(store_features_out_is_arr)]],
279+
constant ushort* offset_buf[[buffer(0)]],
280+
ushort3 gid[[thread_position_in_grid]]) {
301281
ushort2 gid_ = gid.xy;
302-
out.write(in.read(gid_), gid_, offset_buf[2]);
303-
}
304-
305-
kernel void append_features_off1(texture2d_array<half, access::read> in[[texture(0)]],
306-
texture2d_array<half, access::read_write> out[[texture(1)]],
307-
constant ushort* offset_buf[[buffer(0)]],
308-
ushort3 gid[[thread_position_in_grid]]) {
309-
if (gid.x >= out.get_width() || gid.y >= out.get_height() || gid.z >= offset_buf[4]) {
310-
return;
311-
}
282+
if (store_features_out_is_arr)
283+
out_arr.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_, gid.z);
284+
else
285+
out_tex.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_);
286+
}
287+
288+
constant bool append_features_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
289+
constant bool append_features_in_is_tex = !append_features_in_is_arr;
290+
kernel void append_features(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_in_is_tex)]],
291+
texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_in_is_arr)]],
292+
texture2d_array<half, access::write> out[[texture(1)]],
293+
constant ushort* offset_buf[[buffer(0)]],
294+
ushort3 gid[[thread_position_in_grid]]) {
312295
ushort2 gid_ = gid.xy;
313296
314297
ushort batch = gid.z / offset_buf[0];
315298
ushort feature = gid.z % offset_buf[0];
316299
ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
317300
ushort inz = batch * offset_buf[3] + feature;
318301
319-
half4 outtex = out.read(gid_, outz);
320-
half4 intex1 = in.read(gid_, inz);
321-
if (feature == 0) {
322-
outtex.y = intex1.x;
323-
outtex.z = intex1.y;
324-
outtex.w = intex1.z;
325-
out.write(outtex, gid_, outz);
326-
return;
327-
}
328-
half4 intex0 = in.read(gid_, inz-1);
329-
outtex.x = intex0.w;
330-
outtex.y = intex1.x;
331-
outtex.z = intex1.y;
332-
outtex.w = intex1.z;
333-
334-
out.write(outtex, gid_, outz);
335-
}
336-
337-
kernel void append_features_off1_nonarray(texture2d<half, access::read> in[[texture(0)]],
338-
texture2d_array<half, access::read_write> out[[texture(1)]],
339-
constant ushort* offset_buf[[buffer(0)]],
340-
ushort3 gid[[thread_position_in_grid]]) {
341-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
342-
return;
343-
}
344-
ushort2 gid_ = gid.xy;
345-
346-
ushort feature = gid.z;
347-
ushort outz = offset_buf[2] + feature;
348-
349-
half4 outtex = out.read(gid_, outz);
350-
half4 intex = in.read(gid_);
351-
if (feature == 0) {
352-
outtex.y = intex.x;
353-
outtex.z = intex.y;
354-
outtex.w = intex.z;
302+
half4 intex;
303+
if (append_features_in_is_arr) {
304+
intex = in_arr.read(gid_, inz);
355305
}
356306
else {
357-
outtex.x = intex.w;
358-
}
359-
360-
out.write(outtex, gid_, outz);
361-
}
362-
363-
kernel void append_features_off2(texture2d_array<half, access::read> in[[texture(0)]],
364-
texture2d_array<half, access::read_write> out[[texture(1)]],
365-
constant ushort* offset_buf[[buffer(0)]],
366-
ushort3 gid[[thread_position_in_grid]]) {
367-
if (gid.x >= out.get_width() || gid.y >= out.get_height() || gid.z >= offset_buf[4]) {
368-
return;
369-
}
307+
intex = in_tex.read(gid_);
308+
}
309+
out.write(intex, gid_, outz);
310+
}
311+
312+
constant bool prev_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
313+
constant bool prev_is_tex = !prev_is_arr;
314+
constant bool append_features_off_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
315+
constant bool append_features_off_in_is_tex = !append_features_off_in_is_arr;
316+
kernel void append_features_off(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_off_in_is_tex)]],
317+
texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_off_in_is_arr)]],
318+
texture2d<half, access::read> prev_tex[[texture(1), function_constant(prev_is_tex)]],
319+
texture2d_array<half, access::read> prev_arr[[texture(1), function_constant(prev_is_arr)]],
320+
texture2d_array<half, access::write> out[[texture(2)]],
321+
constant ushort* offset_buf[[buffer(0)]],
322+
ushort3 gid[[thread_position_in_grid]]) {
370323
ushort2 gid_ = gid.xy;
371324
372325
ushort batch = gid.z / offset_buf[0];
373326
ushort feature = gid.z % offset_buf[0];
374327
ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
375328
ushort inz = batch * offset_buf[3] + feature;
376-
377-
half4 outtex = out.read(gid_, outz);
378-
half4 intex1 = in.read(gid_, inz);
329+
half4 outtex;
330+
if (prev_is_arr)
331+
outtex = prev_arr.read(gid_, batch);
332+
else
333+
outtex = prev_tex.read(gid_);
334+
half4 intex1;
335+
if (append_features_in_is_arr)
336+
intex1 = in_arr.read(gid_, inz);
337+
else
338+
intex1 = in_tex.read(gid_);
379339
if (feature == 0) {
380-
outtex.z = intex1.x;
381-
outtex.w = intex1.y;
340+
if (offset_buf[5] == 1)
341+
outtex.yzw = intex1.xyz;
342+
else if (offset_buf[5] == 2)
343+
outtex.zw = intex1.xy;
344+
else
345+
outtex.w = intex1.x;
382346
out.write(outtex, gid_, outz);
383347
return;
384348
}
385-
half4 intex0 = in.read(gid_, inz-1);
386-
outtex.x = intex0.z;
387-
outtex.y = intex0.w;
388-
outtex.z = intex1.x;
389-
outtex.w = intex1.y;
390-
391-
out.write(outtex, gid_, outz);
392-
}
393-
394-
kernel void append_features_off2_nonarray(texture2d<half, access::read> in[[texture(0)]],
395-
texture2d_array<half, access::read_write> out[[texture(1)]],
396-
constant ushort* offset_buf[[buffer(0)]],
397-
ushort3 gid[[thread_position_in_grid]]) {
398-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
399-
return;
349+
half4 intex0;
350+
if (append_features_in_is_arr)
351+
intex0 = in_arr.read(gid_, inz-1);
352+
else
353+
intex0 = intex1;
354+
if (offset_buf[5] == 1) {
355+
outtex.x = intex0.w;
356+
outtex.yzw = intex1.xyz;
400357
}
401-
ushort2 gid_ = gid.xy;
402-
403-
ushort feature = gid.z;
404-
ushort outz = offset_buf[2] + feature;
405-
406-
half4 outtex = out.read(gid_, outz);
407-
half4 intex = in.read(gid_);
408-
if (feature == 0) {
409-
outtex.z = intex.x;
410-
outtex.w = intex.y;
358+
else if (offset_buf[5] == 2) {
359+
outtex.xy = intex0.zw;
360+
outtex.zw = intex1.xy;
411361
}
412362
else {
413-
outtex.x = intex.z;
414-
outtex.y = intex.w;
415-
}
416-
417-
out.write(outtex, gid_, outz);
418-
}
419-
420-
kernel void append_features_off3(texture2d_array<half, access::read> in[[texture(0)]],
421-
texture2d_array<half, access::read_write> out[[texture(1)]],
422-
constant ushort* offset_buf[[buffer(0)]],
423-
ushort3 gid[[thread_position_in_grid]]) {
424-
if (gid.x >= out.get_width() || gid.y >= out.get_height() || gid.z >= offset_buf[4]) {
425-
return;
426-
}
427-
ushort2 gid_ = gid.xy;
428-
429-
ushort batch = gid.z / offset_buf[0];
430-
ushort feature = gid.z % offset_buf[0];
431-
ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
432-
ushort inz = batch * offset_buf[3] + feature;
433-
434-
half4 outtex = out.read(gid_, outz);
435-
half4 intex1 = in.read(gid_, inz);
436-
if (feature == 0) {
363+
outtex.xyz = intex0.yzw;
437364
outtex.w = intex1.x;
438-
out.write(outtex, gid_, outz);
439-
return;
440-
}
441-
half4 intex0 = in.read(gid_, inz-1);
442-
outtex.x = intex0.y;
443-
outtex.y = intex0.z;
444-
outtex.z = intex0.w;
445-
outtex.w = intex1.x;
446-
447-
out.write(outtex, gid_, outz);
448-
}
449-
450-
kernel void append_features_off3_nonarray(texture2d<half, access::read> in[[texture(0)]],
451-
texture2d_array<half, access::read_write> out[[texture(1)]],
452-
constant ushort* offset_buf[[buffer(0)]],
453-
ushort3 gid[[thread_position_in_grid]]) {
454-
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
455-
return;
456-
}
457-
ushort2 gid_ = gid.xy;
458-
459-
ushort feature = gid.z;
460-
ushort outz = offset_buf[2] + feature;
461-
462-
half4 outtex = out.read(gid_, outz);
463-
half4 intex = in.read(gid_);
464-
if (feature == 0) {
465-
outtex.w = intex.x;
466-
}
467-
else {
468-
outtex.x = intex.y;
469-
outtex.y = intex.z;
470-
outtex.z = intex.w;
471365
}
472366
473367
out.write(outtex, gid_, outz);

0 commit comments

Comments
 (0)