We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 073a455 commit e54af2bCopy full SHA for e54af2b
questions/195_gradient-checkpointing/solution.py
@@ -1,3 +1,17 @@
1
-def your_function(...):
2
- # reference implementation
3
- ...
+import numpy as np
+
+def checkpoint_forward(funcs, input_arr):
4
+ """
5
+ Applies a list of functions in sequence to the input array, simulating gradient checkpointing by not storing intermediates.
6
7
+ Args:
8
+ funcs (list of callables): List of functions to apply in sequence.
9
+ input_arr (np.ndarray): Input numpy array.
10
11
+ Returns:
12
+ np.ndarray: The output after applying all functions, same shape as output of last function.
13
14
+ x = input_arr
15
+ for f in funcs:
16
+ x = f(x)
17
+ return x.astype(float)
0 commit comments