@@ -1283,12 +1283,18 @@ class LargeTensorLinear(torch.nn.Module):
     def __init__(self):
         super().__init__()
         hidden_dim = 4096
-        self.linear1 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_1 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_2 = torch.nn.Linear(512, hidden_dim)
+        self.linear1_3 = torch.nn.Linear(512, hidden_dim)
         self.linear2 = torch.nn.Linear(hidden_dim, 512)
+        self.linear3 = torch.nn.Linear(hidden_dim, 512)
+        self.linear4 = torch.nn.Linear(hidden_dim, 512)

     def forward(self, x):
-        x1 = self.linear1(x) + self.linear1(x)
-        return self.linear2(x1)
+        x1 = self.linear1_1(x) + self.linear1_1(x)
+        x2 = self.linear1_2(x) + self.linear1_2(x)
+        x3 = self.linear1_3(x) + self.linear1_3(x)
+        return self.linear2(x1) * self.linear3(x2) * self.linear4(x3)


 class LayerNorm(torch.nn.Module):
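Note: a minimal usage sketch for the reworked LargeTensorLinear, assuming a (1, 512) example input (the input shape is not part of this diff):

    model = LargeTensorLinear().eval()
    x = torch.randn(1, 512)  # assumed example input
    out = model(x)           # three 512->4096 branches, each projected back to 512 and multiplied together
    print(out.shape)         # torch.Size([1, 512])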
@@ -1371,6 +1377,19 @@ def forward(self, x):
         return x + N


+class LinalgVectorNorm(torch.nn.Module):
+    def __init__(self, ord=2.0, dim=None, keepdim=False):
+        super().__init__()
+        self.ord = ord
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        return torch.linalg.vector_norm(
+            x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
+        )
+
+
 class Linear(torch.nn.Module):
     def __init__(self, use_bias: bool = True):
         super().__init__()
@@ -1380,17 +1399,24 @@ def forward(self, x):
         return self.linear(x)


-class LinalgVectorNorm(torch.nn.Module):
-    def __init__(self, ord=2.0, dim=None, keepdim=False):
+class LinearNonConstantWeight(torch.nn.Module):
+    def __init__(self):
         super().__init__()
-        self.ord = ord
-        self.dim = dim
-        self.keepdim = keepdim
+        self.input_dim = 512
+        self.output_dim = 128
+        self.linear = torch.nn.Linear(self.input_dim, 3 * self.output_dim, True).eval()

     def forward(self, x):
-        return torch.linalg.vector_norm(
-            x, ord=self.ord, dim=self.dim, keepdim=self.keepdim
+        w_q, w_k, w_v = self.linear.weight.split(
+            [self.output_dim, self.output_dim, self.output_dim]
         )
+        b_q, b_k, b_v = self.linear.bias.split(
+            [self.output_dim, self.output_dim, self.output_dim]
+        )
+        q = torch.nn.functional.linear(x, w_q, b_q)
+        k = torch.nn.functional.linear(x, w_k, b_k)
+        v = torch.nn.functional.linear(x, w_v, b_v)
+        return q * k * v


 class Log(torch.nn.Module):
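The new LinearNonConstantWeight models a fused QKV projection whose weight and bias are split at run time instead of being folded into three separate constants. A minimal sketch of what the forward pass computes, assuming a (2, 512) example input:

    model = LinearNonConstantWeight()
    x = torch.randn(2, 512)  # assumed example input
    out = model(x)           # q, k, v are each (2, 128); the output is their elementwise product
    print(out.shape)         # torch.Size([2, 128])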
@@ -1814,10 +1840,11 @@ def forward(self, x):


 class RmsNorm(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, eps=None):
         super().__init__()
-        self.eps = 1e-5
-        self.rms = torch.nn.RMSNorm([4], 1e-5)
+        self.rms = torch.nn.RMSNorm([4])
+        if eps:
+            self.rms = torch.nn.RMSNorm([4], eps)

     def forward(self, x):
         return self.rms(x)
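With the new eps argument, RmsNorm can be built either with PyTorch's default epsilon or with an explicit value. A minimal sketch of the two construction paths:

    norm_default = RmsNorm()              # torch.nn.RMSNorm([4]) with the framework's default eps
    norm_explicit = RmsNorm(eps=1e-5)     # torch.nn.RMSNorm([4], 1e-5), the previously hard-coded value
    y = norm_explicit(torch.randn(2, 4))  # normalizes over the trailing dimension of size 4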
@@ -2149,6 +2176,32 @@ def forward(self, x):
         return a + self.idx_source[b]


+class Triu(torch.nn.Module):
+    def __init__(self, diagonal: Optional[int] = None):
+        super().__init__()
+        self.diagonal = diagonal
+
+    def forward(self, x):
+        if self.diagonal:
+            return torch.triu(x, diagonal=self.diagonal)
+        return torch.triu(x)
+
+
+class TriuConstant(torch.nn.Module):
+    def __init__(self, diagonal, constant_dtype=torch.float32):
+        super().__init__()
+        self.diagonal = diagonal
+        self.constant_dtype = constant_dtype
+        self.register_buffer("mask", torch.ones((5, 5), dtype=constant_dtype))
+
+    def forward(self, x):
+        mask = torch.triu(self.mask, diagonal=self.diagonal)
+        if self.constant_dtype == torch.bool:
+            mask = torch.zeros(x.shape, dtype=x.dtype).masked_fill_(mask, -10000.0)
+        # Add x to avoid having no input in the graph
+        return mask + x
+
+
 class Unbind(torch.nn.Module):
     def __init__(self):
         super().__init__()
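Triu applies torch.triu directly to the activation, while TriuConstant applies it to a constant buffer and, for a bool buffer, turns the result into an additive -10000.0 attention-style mask. A minimal sketch, assuming a (5, 5) input to match the registered (5, 5) buffer:

    x = torch.randn(5, 5)  # assumed example input
    upper = Triu(diagonal=1)(x)                                       # keeps entries strictly above the main diagonal
    masked = TriuConstant(diagonal=1, constant_dtype=torch.bool)(x)   # adds -10000.0 above the diagonal, leaves x unchanged elsewhere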