@@ -1283,12 +1283,18 @@ class LargeTensorLinear(torch.nn.Module):
     def __init__(self):
         super().__init__()
         hidden_dim = 4096
-        self.linear1 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_1 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_2 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_3 = torch.nn.Linear(512, hidden_dim)
         self.linear2 = torch.nn.Linear(hidden_dim, 512)
+        self.linear3 = torch.nn.Linear(hidden_dim, 512)
+        self.linear4 = torch.nn.Linear(hidden_dim, 512)
 
     def forward(self, x):
-        x1 = self.linear1(x) + self.linear1(x)
-        return self.linear2(x1)
+        x1 = self.linear1_1(x) + self.linear1_1(x)
+        x2 = self.linear1_2(x) + self.linear1_2(x)
+        x3 = self.linear1_3(x) + self.linear1_3(x)
+        return self.linear2(x1) * self.linear3(x2) * self.linear4(x3)
 
 
 class LayerNorm(torch.nn.Module):
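# A minimal sketch of how the reworked LargeTensorLinear might be exercised,
# assuming the class above is in scope. The (1, 512) input shape and the
# eager-mode check are illustrative assumptions, not taken from this diff.
import torch

model = LargeTensorLinear().eval()
x = torch.randn(1, 512)
with torch.no_grad():
    out = model(x)
# Each branch maps 512 -> 4096 -> 512, and the three branch outputs are
# multiplied elementwise, so the result keeps the (1, 512) shape.
assert out.shape == (1, 512)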
@@ -1380,6 +1386,26 @@ def forward(self, x):
         return self.linear(x)
 
 
+class LinearNonConstantWeight(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.input_dim = 512
+        self.output_dim = 128
+        self.linear = torch.nn.Linear(self.input_dim, 3 * self.output_dim, True).eval()
+
+    def forward(self, x):
+        w_q, w_k, w_v = self.linear.weight.split(
+            [self.output_dim, self.output_dim, self.output_dim]
+        )
+        b_q, b_k, b_v = self.linear.bias.split(
+            [self.output_dim, self.output_dim, self.output_dim]
+        )
+        q = torch.nn.functional.linear(x, w_q, b_q)
+        k = torch.nn.functional.linear(x, w_k, b_k)
+        v = torch.nn.functional.linear(x, w_v, b_v)
+        return q * k * v
+
+
 class LinalgVectorNorm(torch.nn.Module):
     def __init__(self, ord=2.0, dim=None, keepdim=False):
         super().__init__()
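# A minimal sketch of running LinearNonConstantWeight, assuming the class above
# is in scope; the (2, 512) batch is an illustrative assumption. The point of
# the model is that the Q/K/V weights and biases are produced by splitting the
# fused linear's parameters at forward time, so they are not standalone
# constant tensors in the graph.
import torch

model = LinearNonConstantWeight().eval()
x = torch.randn(2, 512)
with torch.no_grad():
    out = model(x)  # q * k * v, each of shape (2, 128)
assert out.shape == (2, 128)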
@@ -1814,10 +1840,11 @@ def forward(self, x):
 
 
 class RmsNorm(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, eps=None):
         super().__init__()
-        self.eps = 1e-5
-        self.rms = torch.nn.RMSNorm([4], 1e-5)
+        self.rms = torch.nn.RMSNorm([4])
+        if eps:
+            self.rms = torch.nn.RMSNorm([4], eps)
 
     def forward(self, x):
         return self.rms(x)
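# A minimal sketch of the two RmsNorm configurations the new constructor
# allows, assuming torch.nn.RMSNorm is available (PyTorch 2.4+) and the class
# above is in scope. The (2, 4) input is an illustrative assumption; the
# normalized shape [4] comes from the class definition.
import torch

x = torch.randn(2, 4)
default_eps = RmsNorm()(x)            # relies on RMSNorm's own default eps
explicit_eps = RmsNorm(eps=1e-5)(x)   # matches the previously hard-coded 1e-5
assert default_eps.shape == explicit_eps.shape == (2, 4)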
@@ -2141,6 +2168,32 @@ def forward(self, x):
         return a + self.idx_source[b]
 
 
+class Triu(torch.nn.Module):
+    def __init__(self, diagonal: Optional[int] = None):
+        super().__init__()
+        self.diagonal = diagonal
+
+    def forward(self, x):
+        if self.diagonal:
+            return torch.triu(x, diagonal=self.diagonal)
+        return torch.triu(x)
+
+
+class TriuConstant(torch.nn.Module):
+    def __init__(self, diagonal, constant_dtype=torch.float32):
+        super().__init__()
+        self.diagonal = diagonal
+        self.constant_dtype = constant_dtype
+        self.register_buffer("mask", torch.ones((5, 5), dtype=constant_dtype))
+
+    def forward(self, x):
+        mask = torch.triu(self.mask, diagonal=self.diagonal)
+        if self.constant_dtype == torch.bool:
+            mask = torch.zeros(x.shape, dtype=x.dtype).masked_fill_(mask, -10000.0)
+        # Add x to avoid no input in graph
+        return mask + x
+
+
 class Unbind(torch.nn.Module):
     def __init__(self):
         super().__init__()
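# A minimal sketch of the two triu variants, assuming the classes above are in
# scope. Shapes, diagonals, and dtypes below are illustrative assumptions.
import torch

x = torch.randn(5, 5)

# Plain torch.triu on the input, with and without an explicit diagonal.
upper = Triu()(x)
upper_off_diag = Triu(diagonal=1)(x)

# Float constant mask: forward returns triu(ones(5, 5)) + x.
float_mask_out = TriuConstant(diagonal=0)(x)

# Bool constant mask: the upper-triangular mask is converted into an additive
# -10000.0 mask (as in attention masking) before being added to x.
bool_mask_out = TriuConstant(diagonal=1, constant_dtype=torch.bool)(x)
assert bool_mask_out.shape == x.shape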