@@ -380,15 +380,16 @@ class MergerTest3T1LD : public MergerTestBase {
380380// /
381381// / Tests with both undef and dense input.
382382// /
383- class  MergerTest3T1LU  : public  MergerTestBase  {
383+ 
384+ class  MergerTest4T1LU  : public  MergerTestBase  {
384385protected: 
385386  //  Our three tensors (two inputs, one output).
386-   const  unsigned  t0 = 0 , t1 = 1 , t2 = 2 ;
387+   const  unsigned  t0 = 0 , t1 = 1 , t2 = 2 , t3 =  3 ;
387388
388389  //  Our single loop.
389390  const  unsigned  l0 = 0 ;
390391
391-   MergerTest3T1LU () : MergerTestBase(3 , 1 ) {
392+   MergerTest4T1LU () : MergerTestBase(4 , 1 ) {
392393    //  Tensor 0: undef input vector.
393394    merger.addExp (Kind::kTensor , t0, -1u );
394395    merger.setDimLevelFormat (t0, l0, DimLevelFormat (DimLvlType::kUndef ));
@@ -397,43 +398,110 @@ class MergerTest3T1LU : public MergerTestBase {
397398    merger.addExp (Kind::kTensor , t1, -1u );
398399    merger.setDimLevelFormat (t1, l0, DimLevelFormat (DimLvlType::kDense ));
399400
400-     //  Tensor 2: dense output  vector.
401+     //  Tensor 2: undef input  vector.
401402    merger.addExp (Kind::kTensor , t2, -1u );
402-     merger.setDimLevelFormat (t2, l0, DimLevelFormat (DimLvlType::kDense ));
403+     merger.setDimLevelFormat (t2, l0, DimLevelFormat (DimLvlType::kUndef ));
404+ 
405+     //  Tensor 3: dense output vector.
406+     merger.addExp (Kind::kTensor , t3, -1u );
407+     merger.setDimLevelFormat (t3, l0, DimLevelFormat (DimLvlType::kDense ));
408+   }
409+ };
410+ 
411+ // /
412+ // / Tests with operation on sparse output.
413+ // /
414+ 
415+ class  MergerTest3T1L_SO  : public  MergerTestBase  {
416+ protected: 
417+   //  Our three tensors (two inputs, one output, one synthetic).
418+   const  unsigned  t0 = 0 , t1 = 1 , t2 = 2 , t3 = 3 ;
419+ 
420+   //  Our single loop.
421+   const  unsigned  l0 = 0 ;
422+ 
423+   MergerTest3T1L_SO () : MergerTestBase(3 , 1 ) {
424+     merger.setHasSparseOut (true );
425+ 
426+     //  Tensor 0: undef input vector.
427+     merger.addExp (Kind::kTensor , t0, -1u );
428+     merger.setDimLevelFormat (t0, l0, DimLevelFormat (DimLvlType::kUndef ));
429+ 
430+     //  Tensor 1: undef input vector.
431+     merger.addExp (Kind::kTensor , t1, -1u );
432+     merger.setDimLevelFormat (t1, l0, DimLevelFormat (DimLvlType::kUndef ));
433+ 
434+     //  Tensor 2: sparse output vector.
435+     merger.addExp (Kind::kTensor , t2, -1u );
436+     merger.setDimLevelFormat (t2, l0, DimLevelFormat (DimLvlType::kCompressed ));
403437  }
404438};
439+ 
405440} //  namespace
406441
407- // / Vector multiplication (conjunction) of 2  vectors, i.e.;
408- // /   a(i) = b(i) * c(i)
442+ // / Vector multiplication (conjunction) of 3  vectors, i.e.;
443+ // /   a(i) = b(i) * c(i) * d(i) 
409444// / which should form the single lattice point
410445// / {
411- // /   lat( i_00_U i_01_D / (tensor_0 * tensor_1) )
446+ // /   lat( i_00_U i_01_D i_02_U  / (tensor_0 * tensor_1 * tensor2 ) )
412447// / }
413448// / after optimization, the dense dimesion should be kept, despite it appears
414- // / after  the undef dimension 
449+ // / in  the middle 
415450// / {
416- // /   lat( i_01_D / (tensor_0 * tensor_1) )
451+ // /   lat( i_01_D / (tensor_0 * tensor_1 * tensor2 ) )
417452// / }
418- #define  IMPL_MERGER_TEST_CONJ (OP )                                              \
419-   TEST_F (MergerTest3T1LU, vector_##OP) {                                       \
420-     auto  e = OP##Expr (t0, t1);                                                 \
453+ #define  IMPL_MERGER_TEST_CONJ_CONJ_UNDEF (CONJ1, CONJ2 )                         \
454+   TEST_F (MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
455+     auto  em = CONJ1##Expr (t0, t1);                                             \
456+     auto  e = CONJ2##Expr (em, t2);                                              \
421457    auto  p0 = tensorPattern (t0);                                               \
422458    auto  p1 = tensorPattern (t1);                                               \
459+     auto  p2 = tensorPattern (t2);                                               \
423460    auto  s = merger.buildLattices (e, l0);                                      \
424-                                                                                \
425461    expectNumLatPoints (s, 1 );                                                  \
426-     expectLatPoint (s, lat (0 ), OP##Pattern (p0, p1),                             \
427-                    loopsToBits ({{l0, t0}, {l0, t1}}));                         \
428-                                                                                \
462+     expectLatPoint (s, lat (0 ), CONJ2##Pattern (CONJ1##Pattern (p0, p1), p2),      \
463+                    loopsToBits ({{l0, t0}, {l0, t1}, {l0, t2}}));               \
429464    s = merger.optimizeSet (s);                                                 \
430465    expectNumLatPoints (s, 1 );                                                  \
431-     expectLatPoint (s, lat (0 ), OP ##Pattern (p0, p1), loopsToBits ({{l0, t1}}),     \
432-                    true );                                                       \
466+     expectLatPoint (s, lat (0 ), CONJ2 ##Pattern (CONJ1## Pattern ( p0, p1), p2),       \
467+                    loopsToBits ({{l0, t1}}),  true );                             \
433468  }
434- FOREVERY_COMMON_CONJ_BINOP (IMPL_MERGER_TEST_CONJ)
435469
436- #undef  IMPL_MERGER_TEST_CONJ
470+ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP (IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
471+ 
472+ #undef  IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
473+ 
474+ // / Vector multiplication (conjunction) of 2 vectors, i.e.;
475+ // /   o(i) = b(i) * c(i) * o(i)
476+ // / which should form the single lattice point (note how a synthetic tensor
477+ // / i_03_U is created for the sparse output)
478+ // / {
479+ // /   lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
480+ // / }
481+ // / after optimization, the synthetic tensor should be preserved.
482+ // / {
483+ // /   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
484+ // / }
485+ #define  IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT (CONJ1, CONJ2 )                    \
486+   TEST_F (MergerTest3T1L_SO, vector_##CONJ1##_##CONJ2) {                        \
487+     auto  em = CONJ1##Expr (t0, t1);                                             \
488+     auto  e = CONJ2##Expr (em, t2);                                              \
489+     auto  p0 = tensorPattern (t0);                                               \
490+     auto  p1 = tensorPattern (t1);                                               \
491+     auto  p2 = tensorPattern (t2);                                               \
492+     auto  s = merger.buildLattices (e, l0);                                      \
493+     expectNumLatPoints (s, 1 );                                                  \
494+     expectLatPoint (s, lat (0 ), CONJ2##Pattern (CONJ1##Pattern (p0, p1), p2),      \
495+                    loopsToBits ({{l0, t0}, {l0, t1}, {l0, t3}}));               \
496+     s = merger.optimizeSet (s);                                                 \
497+     expectNumLatPoints (s, 1 );                                                  \
498+     expectLatPoint (s, lat (0 ), CONJ2##Pattern (CONJ1##Pattern (p0, p1), p2),      \
499+                    loopsToBits ({{l0, t3}}), true );                             \
500+   }
501+ 
502+ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP (IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
503+ 
504+ #undef  IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
437505
438506// / Vector addition (disjunction) of 2 vectors. i.e.;
439507// /   a(i) = b(i) + c(i)
0 commit comments