@@ -71,9 +71,10 @@ export abstract class NDArrayMath {
7171 */
7272 enableDebugMode ( ) {
7373 this . debugMode = true ;
74- console . warn ( 'Debugging mode is ON. The output of every math call will ' +
75- 'be downloaded to CPU and checked for NaNs. ' +
76- 'This significantly impacts performance.' ) ;
74+ console . warn (
75+ 'Debugging mode is ON. The output of every math call will ' +
76+ 'be downloaded to CPU and checked for NaNs. ' +
77+ 'This significantly impacts performance.' ) ;
7778 }
7879
7980 /**
@@ -97,7 +98,7 @@ export abstract class NDArrayMath {
9798 endScope ( result : ScopeResult ) {
9899 let arraysToKeep = this . activeScopeNDArraysToKeep ;
99100 if ( result != null ) {
100- arraysToKeep = arraysToKeep . concat ( result as NDArray | NDArray [ ] ) ;
101+ arraysToKeep = arraysToKeep . concat ( result as NDArray | NDArray [ ] ) ;
101102 }
102103 // Dispose the current scope.
103104 for ( let i = 0 ; i < this . activeScope . length ; i ++ ) {
@@ -321,22 +322,15 @@ export abstract class NDArrayMath {
321322 protected abstract cloneInternal < T extends NDArray > ( ndarray : T ) : T ;
322323
323324 /**
324- * Reshapes an NDArray to a new shape. The size of the input NDArray must
325- * match the size of the requested shape.
326- * @param ndarray The input NDArray.
327- * @param newShape The new shape to reshape the NDArray to. Must be the same
328- * size as the NDArray.
325+ * @deprecated Please call reshape() directly on the ndarray object.
329326 */
330327 reshape < T1 extends NDArray , T2 extends NDArray > (
331328 ndarray : T1 , newShape : number [ ] ) : T2 {
332- util . assert (
333- ndarray . size === util . sizeFromShape ( newShape ) ,
334- `Error in reshape: old size ${ ndarray . size } must match new size ` +
335- `${ util . sizeFromShape ( newShape ) } .` ) ;
336- return this . track ( this . reshapeInternal < T1 , T2 > ( ndarray , newShape ) ) ;
329+ console . warn (
330+ 'math.reshape() is deprecated. Please call reshape() ' +
331+ 'directly on the ndarray object' ) ;
332+ return ndarray . reshape ( newShape ) ;
337333 }
338- protected abstract reshapeInternal < T1 extends NDArray , T2 extends NDArray > (
339- ndarray : T1 , newShape : number [ ] ) : T2 ;
340334
341335 /**
342336 * Extracts a slice from a matrix. The operation extraces a slice from input
@@ -1148,7 +1142,8 @@ export abstract class NDArrayMath {
11481142 * @param h Array of previous cell outputs.
11491143 * @return Tuple [nextCellStates, cellOutputs]
11501144 */
1151- multiRNNCell ( lstmCells : LSTMCell [ ] , data : Array2D , c : Array2D [ ] ,
1145+ multiRNNCell (
1146+ lstmCells : LSTMCell [ ] , data : Array2D , c : Array2D [ ] ,
11521147 h : Array2D [ ] ) : [ Array2D [ ] , Array2D [ ] ] {
11531148 util . assert (
11541149 data . shape [ 0 ] === 1 ,
@@ -1187,8 +1182,9 @@ export abstract class NDArrayMath {
11871182 * @param h Previous cell output.
11881183 * @return Tuple [nextCellState, cellOutput]
11891184 */
1190- basicLSTMCell ( forgetBias : Scalar , lstmKernel : Array2D , lstmBias : Array1D ,
1191- data : Array2D , c : Array2D , h : Array2D ) : [ Array2D , Array2D ] {
1185+ basicLSTMCell (
1186+ forgetBias : Scalar , lstmKernel : Array2D , lstmBias : Array1D , data : Array2D ,
1187+ c : Array2D , h : Array2D ) : [ Array2D , Array2D ] {
11921188 const res = this . scope ( ( ) => {
11931189 util . assert (
11941190 data . shape [ 0 ] === 1 ,
@@ -1207,25 +1203,25 @@ export abstract class NDArrayMath {
12071203
12081204 // i = input_gate, j = new_input, f = forget_gate, o = output_gate
12091205 const i = this . slice2D ( res , [ 0 , 0 ] , [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1210- const j = this . slice2D ( res , [ 0 , res . shape [ 1 ] / 4 * 1 ] ,
1211- [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1212- const f = this . slice2D ( res , [ 0 , res . shape [ 1 ] / 4 * 2 ] ,
1213- [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1214- const o = this . slice2D ( res , [ 0 , res . shape [ 1 ] / 4 * 3 ] ,
1215- [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1216-
1217- const newC = this . add (
1218- this . multiplyStrict ( c ,
1219- this . sigmoid ( this . scalarPlusArray ( forgetBias , f ) ) ) ,
1220- this . multiplyStrict ( this . sigmoid ( i ) , this . tanh ( j ) ) ) as Array2D ;
1221- const newH = this . multiplyStrict (
1222- this . tanh ( newC ) , this . sigmoid ( o ) ) as Array2D ;
1206+ const j = this . slice2D (
1207+ res , [ 0 , res . shape [ 1 ] / 4 * 1 ] , [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1208+ const f = this . slice2D (
1209+ res , [ 0 , res . shape [ 1 ] / 4 * 2 ] , [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1210+ const o = this . slice2D (
1211+ res , [ 0 , res . shape [ 1 ] / 4 * 3 ] , [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1212+
1213+ const newC =
1214+ this . add (
1215+ this . multiplyStrict (
1216+ c , this . sigmoid ( this . scalarPlusArray ( forgetBias , f ) ) ) ,
1217+ this . multiplyStrict ( this . sigmoid ( i ) , this . tanh ( j ) ) ) as Array2D ;
1218+ const newH =
1219+ this . multiplyStrict ( this . tanh ( newC ) , this . sigmoid ( o ) ) as Array2D ;
12231220
12241221 return [ newC , newH ] ;
12251222 } ) ;
12261223 return [ res [ 0 ] , res [ 1 ] ] ;
12271224 }
1228-
12291225}
12301226
12311227export enum MatrixOrientation {
0 commit comments