diff --git a/src/tensor.ts b/src/tensor.ts
index c82d9b5..148761c 100644
--- a/src/tensor.ts
+++ b/src/tensor.ts
@@ -717,7 +717,7 @@ export class Pow {
    */
   forward(a: Tensor, n: number): Tensor {
     // Build cache to use in backward step:
-    this.cache = a;
+    this.cache = [a, n];
 
     // Call recursive function:
     const z = new Tensor(
@@ -737,12 +737,12 @@
   backward(dz: Tensor, z: Tensor) {
     // Get data from cache:
-    const a = this.cache;
+    const [a, n] = this.cache;
 
     // Find gradients relative to "a", and pass them downstream:
     if (requiresGrad(a)) {
-      // d/da(e^a) = e^a, apply the chain rule to the derivative of e^a:
-      const da = new Tensor(_mul(2, _mul(a.data, dz.data)));
+      // d/da(a^n) = na^(n-1), apply the chain rule to the derivative of a^n:
+      const da = new Tensor(_mul(n, _mul(_pow(a.data, n - 1), dz.data)));
       a.backward(da, z);
     }
   }
 
@@ -1725,11 +1725,13 @@ function _matmul(a: Array, b: Array, kernel: any): Array {
 
 function _pow(a: Array | number, n: number): Array | number {
   // If a is a number, exponentiate it. If not, exponentiate all of its elements:
-  let z = a;
-  for (let i = 0; i < n - 1; i++) {
-    z = _mul(z, a);
+  if (typeof a === "number") {
+    return a ** n;
+  } else if (a instanceof Array) {
+    return a.map((element: Array) => _pow(element, n));
+  } else {
+    throw new TypeError("the input data is not a number or an array.");
   }
-  return z;
 }
 
 function _sqrt(a: Array | number): Array | number {