@@ -271,203 +271,97 @@ kernel void copy_offset_nonarray(texture2d<half, access::read> in[[texture(0)]],
271271 out.write(in.read(gid_), gid_, gid.z + offset_buf[0]);
272272}
273273
274- kernel void append_features_off0(texture2d_array<half, access::read> in[[texture(0)]],
275- texture2d_array<half, access::read_write> out[[texture(1)]],
276- constant ushort* offset_buf[[buffer(0)]],
277- ushort3 gid[[thread_position_in_grid]]) {
278- if (gid.x >= out.get_width() || gid.y >= out.get_height() || gid.z >= offset_buf[4]) {
279- return;
280- }
281- ushort2 gid_ = gid.xy;
282-
283- ushort batch = gid.z / offset_buf[0];
284- ushort feature = gid.z % offset_buf[0];
285- ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
286- ushort inz = batch * offset_buf[3] + feature;
287-
288- half4 intex1 = in.read(gid_, inz);
289- half4 outtex = intex1;
290-
291- out.write(outtex, gid_, outz);
292- }
293-
294- kernel void append_features_off0_nonarray(texture2d<half, access::read> in[[texture(0)]],
295- texture2d_array<half, access::read_write> out[[texture(1)]],
296- constant ushort* offset_buf[[buffer(0)]],
297- ushort3 gid[[thread_position_in_grid]]) {
298- if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
299- return;
300- }
274+ constant bool store_features_out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);  // selects the array output binding; presumably ushort_arg_3 is an image count and ushort_arg_2 a channel count -- TODO confirm
275+ constant bool store_features_out_is_tex = !store_features_out_is_arr;  // exact complement, so exactly one of out_tex / out_arr is enabled at texture(1)
276+ kernel void store_features(texture2d_array<half, access::read> in[[texture(0)]],  // store_features: each thread copies one texel from a computed slice of `in` to the output
277+ texture2d<half, access::write> out_tex[[texture(1), function_constant(store_features_out_is_tex)]],  // output when a plain 2D texture is selected
278+ texture2d_array<half, access::write> out_arr[[texture(1), function_constant(store_features_out_is_arr)]],  // output when a texture array is selected; written at slice gid.z
279+ constant ushort* offset_buf[[buffer(0)]],  // [0] is added to the source slice index, [1] multiplies gid.z in the source slice index
280+ ushort3 gid[[thread_position_in_grid]]) {  // NOTE(review): no early return for gid outside the output, unlike the removed kernels -- confirm the dispatch grid matches the texture extent
301281 ushort2 gid_ = gid.xy;
302- out.write(in.read(gid_), gid_, offset_buf[2]);
303- }
304-
305- kernel void append_features_off1(texture2d_array<half, access::read> in[[texture(0)]],
306- texture2d_array<half, access::read_write> out[[texture(1)]],
307- constant ushort* offset_buf[[buffer(0)]],
308- ushort3 gid[[thread_position_in_grid]]) {
309- if (gid.x >= out.get_width() || gid.y >= out.get_height() || gid.z >= offset_buf[4]) {
310- return;
311- }
282+ if (store_features_out_is_arr)  // both branches read the same source slice: gid.z * offset_buf[1] + offset_buf[0]
283+ out_arr.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_, gid.z);  // array output: destination slice is gid.z
284+ else
285+ out_tex.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_);  // 2D output: no destination slice index
286+ }
287+
288+ constant bool append_features_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);  // selects the array input binding; presumably ushort_arg_7 is an image count and ushort_arg_6 a channel count -- TODO confirm
289+ constant bool append_features_in_is_tex = !append_features_in_is_arr;  // exact complement, so exactly one of in_tex / in_arr is enabled at texture(0)
290+ kernel void append_features(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_in_is_tex)]],  // append_features: each thread copies one input texel into slice outz of `out`
291+ texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_in_is_arr)]],  // array input, read at slice inz
292+ texture2d_array<half, access::write> out[[texture(1)]],  // write-only destination; existing contents are never read here
293+ constant ushort* offset_buf[[buffer(0)]],  // [0] divides gid.z into batch/feature, [1] scales batch for outz, [2] is added to outz, [3] scales batch for inz
294+ ushort3 gid[[thread_position_in_grid]]) {  // NOTE(review): no early return for out-of-range gid, unlike the removed kernels -- confirm the dispatch grid is exact
312295 ushort2 gid_ = gid.xy;
313296
314297 ushort batch = gid.z / offset_buf[0];  // quotient of gid.z by offset_buf[0]
315298 ushort feature = gid.z % offset_buf[0];  // remainder of gid.z by offset_buf[0]
316299 ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;  // destination slice in `out`
317300 ushort inz = batch * offset_buf[3] + feature;  // source slice; only used on the array-input path
318301
319- half4 outtex = out.read(gid_, outz);
320- half4 intex1 = in.read(gid_, inz);
321- if (feature == 0) {
322- outtex.y = intex1.x;
323- outtex.z = intex1.y;
324- outtex.w = intex1.z;
325- out.write(outtex, gid_, outz);
326- return;
327- }
328- half4 intex0 = in.read(gid_, inz-1);
329- outtex.x = intex0.w;
330- outtex.y = intex1.x;
331- outtex.z = intex1.y;
332- outtex.w = intex1.z;
333-
334- out.write(outtex, gid_, outz);
335- }
336-
337- kernel void append_features_off1_nonarray(texture2d<half, access::read> in[[texture(0)]],
338- texture2d_array<half, access::read_write> out[[texture(1)]],
339- constant ushort* offset_buf[[buffer(0)]],
340- ushort3 gid[[thread_position_in_grid]]) {
341- if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
342- return;
343- }
344- ushort2 gid_ = gid.xy;
345-
346- ushort feature = gid.z;
347- ushort outz = offset_buf[2] + feature;
348-
349- half4 outtex = out.read(gid_, outz);
350- half4 intex = in.read(gid_);
351- if (feature == 0) {
352- outtex.y = intex.x;
353- outtex.z = intex.y;
354- outtex.w = intex.z;
302+ half4 intex;  // assigned in both branches below before use
303+ if (append_features_in_is_arr) {  // condition is built from the ushort_arg_* constants declared above
304+ intex = in_arr.read(gid_, inz);  // array input: fetch slice inz
355305 }
356306 else {
357- outtex.x = intex.w;
358- }
359-
360- out.write(outtex, gid_, outz);
361- }
362-
363- kernel void append_features_off2(texture2d_array<half, access::read> in[[texture(0)]],
364- texture2d_array<half, access::read_write> out[[texture(1)]],
365- constant ushort* offset_buf[[buffer(0)]],
366- ushort3 gid[[thread_position_in_grid]]) {
367- if (gid.x >= out.get_width() || gid.y >= out.get_height() || gid.z >= offset_buf[4]) {
368- return;
369- }
307+ intex = in_tex.read(gid_);  // 2D input: single texture, inz is ignored
308+ }
309+ out.write(intex, gid_, outz);  // all four channels are copied unchanged, with no channel shifting in this kernel
310+ }
311+
312+ constant bool prev_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
313+ constant bool prev_is_tex = !prev_is_arr;
314+ constant bool append_features_off_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
315+ constant bool append_features_off_in_is_tex = !append_features_off_in_is_arr;
316+ kernel void append_features_off(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_off_in_is_tex)]],
317+ texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_off_in_is_arr)]],
318+ texture2d<half, access::read> prev_tex[[texture(1), function_constant(prev_is_tex)]],
319+ texture2d_array<half, access::read> prev_arr[[texture(1), function_constant(prev_is_arr)]],
320+ texture2d_array<half, access::write> out[[texture(2)]],
321+ constant ushort* offset_buf[[buffer(0)]],
322+ ushort3 gid[[thread_position_in_grid]]) {
370323 ushort2 gid_ = gid.xy;
371324
372325 ushort batch = gid.z / offset_buf[0];
373326 ushort feature = gid.z % offset_buf[0];
374327 ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
375328 ushort inz = batch * offset_buf[3] + feature;
376-
377- half4 outtex = out.read(gid_, outz);
378- half4 intex1 = in.read(gid_, inz);
329+ half4 outtex;
330+ if (prev_is_arr)
331+ outtex = prev_arr.read(gid_, batch);
332+ else
333+ outtex = prev_tex.read(gid_);
334+ half4 intex1;
335+ if (append_features_in_is_arr)
336+ intex1 = in_arr.read(gid_, inz);
337+ else
338+ intex1 = in_tex.read(gid_);
379339 if (feature == 0) {
380- outtex.z = intex1.x;
381- outtex.w = intex1.y;
340+ if (offset_buf[5] == 1)
341+ outtex.yzw = intex1.xyz;
342+ else if (offset_buf[5] == 2)
343+ outtex.zw = intex1.xy;
344+ else
345+ outtex.w = intex1.x;
382346 out.write(outtex, gid_, outz);
383347 return;
384348 }
385- half4 intex0 = in.read(gid_, inz-1);
386- outtex.x = intex0.z;
387- outtex.y = intex0.w;
388- outtex.z = intex1.x;
389- outtex.w = intex1.y;
390-
391- out.write(outtex, gid_, outz);
392- }
393-
394- kernel void append_features_off2_nonarray(texture2d<half, access::read> in[[texture(0)]],
395- texture2d_array<half, access::read_write> out[[texture(1)]],
396- constant ushort* offset_buf[[buffer(0)]],
397- ushort3 gid[[thread_position_in_grid]]) {
398- if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
399- return;
349+ half4 intex0;
350+ if (append_features_in_is_arr)
351+ intex0 = in_arr.read(gid_, inz-1);
352+ else
353+ intex0 = intex1;
354+ if (offset_buf[5] == 1) {
355+ outtex.x = intex0.w;
356+ outtex.yzw = intex1.xyz;
400357 }
401- ushort2 gid_ = gid.xy;
402-
403- ushort feature = gid.z;
404- ushort outz = offset_buf[2] + feature;
405-
406- half4 outtex = out.read(gid_, outz);
407- half4 intex = in.read(gid_);
408- if (feature == 0) {
409- outtex.z = intex.x;
410- outtex.w = intex.y;
358+ else if (offset_buf[5] == 2) {
359+ outtex.xy = intex0.zw;
360+ outtex.zw = intex1.xy;
411361 }
412362 else {
413- outtex.x = intex.z;
414- outtex.y = intex.w;
415- }
416-
417- out.write(outtex, gid_, outz);
418- }
419-
420- kernel void append_features_off3(texture2d_array<half, access::read> in[[texture(0)]],
421- texture2d_array<half, access::read_write> out[[texture(1)]],
422- constant ushort* offset_buf[[buffer(0)]],
423- ushort3 gid[[thread_position_in_grid]]) {
424- if (gid.x >= out.get_width() || gid.y >= out.get_height() || gid.z >= offset_buf[4]) {
425- return;
426- }
427- ushort2 gid_ = gid.xy;
428-
429- ushort batch = gid.z / offset_buf[0];
430- ushort feature = gid.z % offset_buf[0];
431- ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
432- ushort inz = batch * offset_buf[3] + feature;
433-
434- half4 outtex = out.read(gid_, outz);
435- half4 intex1 = in.read(gid_, inz);
436- if (feature == 0) {
363+ outtex.xyz = intex0.yzw;
437364 outtex.w = intex1.x;
438- out.write(outtex, gid_, outz);
439- return;
440- }
441- half4 intex0 = in.read(gid_, inz-1);
442- outtex.x = intex0.y;
443- outtex.y = intex0.z;
444- outtex.z = intex0.w;
445- outtex.w = intex1.x;
446-
447- out.write(outtex, gid_, outz);
448- }
449-
450- kernel void append_features_off3_nonarray(texture2d<half, access::read> in[[texture(0)]],
451- texture2d_array<half, access::read_write> out[[texture(1)]],
452- constant ushort* offset_buf[[buffer(0)]],
453- ushort3 gid[[thread_position_in_grid]]) {
454- if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
455- return;
456- }
457- ushort2 gid_ = gid.xy;
458-
459- ushort feature = gid.z;
460- ushort outz = offset_buf[2] + feature;
461-
462- half4 outtex = out.read(gid_, outz);
463- half4 intex = in.read(gid_);
464- if (feature == 0) {
465- outtex.w = intex.x;
466- }
467- else {
468- outtex.x = intex.y;
469- outtex.y = intex.z;
470- outtex.z = intex.w;
471365 }
472366
473367 out.write(outtex, gid_, outz);
0 commit comments