The idea is to implement a benchmark function, e.g., one that runs a predefined ML task with Torch, to evaluate the compute capability of a specific machine. This capability can be quantified, e.g., as FLOPS (floating-point operations per second).
Then, at model partition time (on the starter node or in the src/llama/prepare_model.py script), each node could be assigned a number of layers proportional to this value, so that the model is split in a balanced way.
Ideally, the benchmark would be run once on each node, and the resulting FLOPS value would be stored in the config.json and used to initialize the node.
Then, before partitioning the model, the starter node would request this value from each node through a GET request.
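The proportional split described above could be sketched roughly as follows. This is only an illustration: `split_layers` and its parameters are hypothetical names, not existing code in the repository, and the rounding strategy (largest remainder) is one possible choice among several.

```python
def split_layers(n_layers, node_flops):
    """Assign n_layers to nodes in proportion to their measured FLOPS.

    Hypothetical helper: node_flops is a list with one positive value per node,
    in the same order in which the nodes should receive consecutive layers.
    """
    if not node_flops or any(f <= 0 for f in node_flops):
        raise ValueError("each node needs a positive FLOPS value")
    total = sum(node_flops)
    # Ideal (fractional) number of layers per node.
    ideal = [n_layers * f / total for f in node_flops]
    # Start from the integer part of each share.
    counts = [int(x) for x in ideal]
    # Hand the remaining layers to the nodes with the largest fractional parts.
    leftover = n_layers - sum(counts)
    by_remainder = sorted(
        range(len(ideal)), key=lambda i: ideal[i] - counts[i], reverse=True
    )
    for i in by_remainder[:leftover]:
        counts[i] += 1
    return counts
```

Note that with this rounding a very slow node can end up with zero layers; whether that is acceptable, or whether every node should get at least one layer, is a design decision left open here.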
Example benchmark:
    import time

    import torch


    def measure_flops(device, matrix_size=1024, num_trials=10):
        device = torch.device(device)

        def sync():
            # CUDA kernels are launched asynchronously: wait for them before reading the clock
            if device.type == "cuda":
                torch.cuda.synchronize(device)

        # Create random matrices
        A = torch.randn(matrix_size, matrix_size, device=device)
        B = torch.randn(matrix_size, matrix_size, device=device)
        # Warm-up to ensure the device is ready (for GPUs)
        torch.matmul(A, B)
        sync()
        # Measure the time taken for the operation
        start_time = time.perf_counter()
        for _ in range(num_trials):
            torch.matmul(A, B)
        sync()
        end_time = time.perf_counter()
        # Calculate the average time per multiplication
        avg_time = (end_time - start_time) / num_trials
        # A square matmul costs about 2 * matrix_size^3 floating point operations;
        # dividing by the average time gives operations per second (FLOPS)
        flops = 2 * (matrix_size ** 3) / avg_time
        return flops, avg_time
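Caching the benchmark result in the node configuration, as proposed above, might look like the sketch below. The helper names and the `"flops"` key are assumptions made for illustration; the actual layout of config.json is not specified here, and the sketch simply assumes its top level is a JSON object.

```python
import json
from pathlib import Path


def store_flops(config_path, flops, key="flops"):
    """Write the measured FLOPS into a JSON config file, keeping other fields.

    Hypothetical helper: `key` is an assumed field name, not an existing one.
    """
    path = Path(config_path)
    config = json.loads(path.read_text()) if path.exists() else {}
    config[key] = flops
    path.write_text(json.dumps(config, indent=2))
    return config


def load_flops(config_path, key="flops"):
    """Return the cached FLOPS value, or None if no benchmark result is stored."""
    path = Path(config_path)
    if not path.exists():
        return None
    return json.loads(path.read_text()).get(key)
```

With something like this, a node would only need to run the benchmark when `load_flops` returns `None`, and could otherwise serve the stored value to the starter node.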