
Commit 99cf136

Merge pull request #547 from komaksym/add_new_q_grad_checkpointing
Add new question: Gradient checkpointing
2 parents 2fefe1c + 073ce12 commit 99cf136


13 files changed, +133 -0 lines changed

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
## Problem

Write a Python function `checkpoint_forward` that takes a list of numpy functions (each representing a layer or operation) and an input numpy array, and returns the final output by applying each function in sequence. To simulate gradient checkpointing, the function should not store intermediate activations; instead, it should recompute them as needed (for this problem, just apply the functions in sequence as usual). Only use standard Python and numpy. The returned array should be of type float and have the same shape as the output of the last function.
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
{
    "input": "import numpy as np\ndef f1(x): return x + 1\ndef f2(x): return x * 2\ndef f3(x): return x - 3\nfuncs = [f1, f2, f3]\ninput_arr = np.array([1.0, 2.0])\noutput = checkpoint_forward(funcs, input_arr)\nprint(output)",
    "output": "[1. 3.]",
    "reasoning": "The input [1.0, 2.0] is passed through f1: [2.0, 3.0], then f2: [4.0, 6.0], then f3: [1.0, 3.0]. The final output is [1. 3.]."
}
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# **Gradient Checkpointing**

## **1. Definition**

Gradient checkpointing is a technique used in deep learning to reduce memory usage during training by selectively storing only a subset of intermediate activations (checkpoints) and recomputing the others as needed during the backward pass. This allows training of larger models or using larger batch sizes without exceeding memory limits.

## **2. Why Use Gradient Checkpointing?**

* **Reduce Memory Usage:** By storing fewer activations, memory requirements are reduced, enabling training of deeper or larger models.
* **Enable Larger Batches/Models:** Makes it possible to fit larger models or use larger batch sizes on limited hardware.
* **Tradeoff:** The main tradeoff is increased computation time, as some activations must be recomputed during the backward pass.
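
As a rough rule of thumb for how far this tradeoff can be pushed (the commonly quoted "sqrt-N" checkpointing scheme, mentioned here only as context, not as part of the exercise below): keeping a checkpoint every $\sqrt{N}$ layers reduces peak activation memory from $O(N)$ to about $O(\sqrt{N})$, at the cost of roughly one extra forward pass of computation.

$$
\text{activation memory: } O(N) \;\longrightarrow\; O(\sqrt{N}), \qquad \text{extra compute: } \approx \text{one forward pass}
$$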

## **3. Gradient Checkpointing Mechanism**

Suppose a model consists of $N$ layers, each represented by a function $f_i$. Normally, the forward pass stores all intermediate activations:

$$
A_0 = x \\
A_1 = f_1(A_0) \\
A_2 = f_2(A_1) \\
\ldots \\
A_N = f_N(A_{N-1})
$$

With gradient checkpointing, only a subset of $A_i$ are stored (the checkpoints). The others are recomputed as needed during backpropagation. In the simplest case, you can store only the input and output, and recompute all intermediates when needed.

**Example:**

If you have three functions $f_1, f_2, f_3$ and input $x$:
* Forward: $A_1 = f_1(x)$, $A_2 = f_2(A_1)$, $A_3 = f_3(A_2)$
* With checkpointing, you might only store $x$ and $A_3$, and recompute $A_1$ and $A_2$ as needed.
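
As a concrete sketch of this idea (it goes slightly beyond the exercise in this PR; the helper names `forward_with_checkpoints` and `recompute_activation` are illustrative only), the snippet below keeps an activation only every few layers and rebuilds any missing one from the nearest stored checkpoint:

```python
import numpy as np

def forward_with_checkpoints(funcs, x, segment=2):
    """Forward pass that stores an activation only every `segment` layers."""
    checkpoints = {0: x}                      # layer index -> stored activation
    for i, f in enumerate(funcs, start=1):
        x = f(x)
        if i % segment == 0 or i == len(funcs):
            checkpoints[i] = x                # keep only a subset of activations
    return x, checkpoints

def recompute_activation(funcs, checkpoints, i):
    """Rebuild A_i on demand from the nearest checkpoint at or before layer i."""
    start = max(j for j in checkpoints if j <= i)
    x = checkpoints[start]
    for f in funcs[start:i]:                  # re-apply the skipped layers
        x = f(x)
    return x

# Example with f1, f2, f3 from above:
def f1(x): return x + 1
def f2(x): return x * 2
def f3(x): return x - 3

out, cps = forward_with_checkpoints([f1, f2, f3], np.array([1.0, 2.0]))
print(out)                                         # [1. 3.]
print(recompute_activation([f1, f2, f3], cps, 1))  # A_1 = [2. 3.], never stored
```

During backpropagation, a real implementation would call something like `recompute_activation` whenever the gradient of a layer needs an activation that was not kept.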

## **4. Applications of Gradient Checkpointing**

Gradient checkpointing is widely used in training:

* **Very Deep Neural Networks:** Transformers, ResNets, and other architectures with many layers.
* **Large-Scale Models:** Language models, vision models, and more.
* **Memory-Constrained Environments:** When hardware cannot fit all activations in memory.
* **Any optimization problem** where memory is a bottleneck during training.

Gradient checkpointing is a powerful tool to enable training of large models on limited hardware, at the cost of extra computation.
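
For reference, deep learning frameworks expose this directly. A minimal PyTorch sketch (assuming a reasonably recent PyTorch version; this is separate from the numpy exercise in this PR):

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
x = torch.randn(4, 16, requires_grad=True)

# The block's intermediate activations are not stored during the forward pass;
# they are recomputed when backward() needs them.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```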
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
{
    "id": "188",
    "title": "Gradient Checkpointing",
    "difficulty": "easy",
    "category": "Machine Learning",
    "video": "",
    "likes": "0",
    "dislikes": "0",
    "contributor": [
        {
            "profile_link": "https://github.com/komaksym",
            "name": "komaksym"
        }
    ]
}
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
def your_function(...):
    ...
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
def your_function(...):
    pass
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
[
    {
        "test": "print(your_function(...))",
        "expected_output": "..."
    }
]
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import numpy as np

def checkpoint_forward(funcs, input_arr):
    """
    Applies a list of functions in sequence to the input array, simulating gradient checkpointing by not storing intermediates.

    Args:
        funcs (list of callables): List of functions to apply in sequence.
        input_arr (np.ndarray): Input numpy array.

    Returns:
        np.ndarray: The output after applying all functions, same shape as output of last function.
    """
    x = input_arr
    for f in funcs:
        x = f(x)
    return x.astype(float)
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import numpy as np

# Implement your function below.
def checkpoint_forward(funcs, input_arr):
    """
    Applies a list of functions in sequence to the input array, simulating gradient checkpointing by not storing intermediates.

    Args:
        funcs (list of callables): List of functions to apply in sequence.
        input_arr (np.ndarray): Input numpy array.

    Returns:
        np.ndarray: The output after applying all functions, same shape as output of last function.
    """
    pass
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
[
    {
        "test": "import numpy as np\ndef f1(x): return x + 1\ndef f2(x): return x * 2\ndef f3(x): return x - 3\nfuncs = [f1, f2, f3]\ninput_arr = np.array([1.0, 2.0])\nprint(checkpoint_forward(funcs, input_arr))",
        "expected_output": "[1. 3.]"
    },
    {
        "test": "import numpy as np\ndef f1(x): return x * 0\ndef f2(x): return x + 10\nfuncs = [f1, f2]\ninput_arr = np.array([5.0, 7.0])\nprint(checkpoint_forward(funcs, input_arr))",
        "expected_output": "[10. 10.]"
    },
    {
        "test": "import numpy as np\ndef f1(x): return x / 2\ndef f2(x): return x ** 2\nfuncs = [f1, f2]\ninput_arr = np.array([4.0, 8.0])\nprint(checkpoint_forward(funcs, input_arr))",
        "expected_output": "[ 4. 16.]"
    },
    {
        "test": "import numpy as np\ndef f1(x): return x - 1\nfuncs = [f1]\ninput_arr = np.array([10.0, 20.0])\nprint(checkpoint_forward(funcs, input_arr))",
        "expected_output": "[ 9. 19.]"
    },
    {
        "test": "import numpy as np\nfuncs = []\ninput_arr = np.array([1.0, 2.0])\nprint(checkpoint_forward(funcs, input_arr))",
        "expected_output": "[1. 2.]"
    }
]
