
Parallel Apache Arrow + DuckDB Solution to the One Trillion Row Challenge 26 mins 15 secs on Dell Workstation #3

@MurrayData

Description

I wanted to test my new Dell Precision 7960 workstation on this task.

The hardware spec is: Intel(R) Xeon(R) w5-3435X CPU, 16 cores/32 threads, max boost 4.7 GHz; 512 GB DDR5 LRDIMMs running at 4400 MHz; 4 x Samsung 990 Pro 2 TB Gen 4 NVMe drives in RAID 0 on a Dell UltraSpeed card in a PCIe 5.0 x16 slot; NVIDIA RTX A6000 (Ampere) GPU with 48 GB.

I tried several approaches, but settled on a native Apache Arrow group-by solution using parallel workers to execute the chunks. The first-stage aggregation uses Apache Arrow tables to compute the min, max, sum and count of temperature for each station in a group by.

import pyarrow as pa
import pyarrow.dataset  # makes pa.dataset available

def aggregate_chunk(start, inc, files):
    """First-stage aggregation: group one batch of files by station."""
    last = min(len(files), start + inc)
    station_ds = pa.dataset.dataset(files[start:last])
    station_table = station_ds.to_table()
    result_table = pa.TableGroupBy(station_table, 'station').aggregate(
        [('measure', 'min'), ('measure', 'max'),
         ('measure', 'sum'), ('measure', 'count')])
    return result_table
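The pool driver below maps a single-argument callable `f` over the chunk start indices, so `inc` and `files` have to be bound to `aggregate_chunk` first. The original binding isn't shown; one way to do it (a sketch, with a stand-in body) is `functools.partial`:

```python
from functools import partial

# Stand-in mirroring aggregate_chunk(start, inc, files): pool.imap
# supplies only `start`, so `inc` and `files` must be bound beforehand.
def aggregate_chunk(start, inc, files):
    last = min(len(files), start + inc)
    return files[start:last]  # the real function aggregates these files

files = [f'measurements_{i}.parquet' for i in range(25)]  # hypothetical names
inc = 10
f = partial(aggregate_chunk, inc=inc, files=files)

# f now takes a single start index, as pool.imap requires;
# the final batch is truncated to the file count
print(f(20))
```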

The chunked first-stage aggregations are dispatched to a pool of parallel workers:

%%time
import multiprocessing as mp
import tqdm

cpus = mp.cpu_count()
pool_size = cpus // 4
print(f'CPU thread count: {cpus}\nParallel workers pool size: {pool_size}\nFile batch size: {inc}')

# f is aggregate_chunk with inc and files bound; imap passes each
# start index from the range to it and yields results in order
with mp.Pool(pool_size) as pool:
    results = list(tqdm.tqdm(pool.imap(f, range(0, max_value, inc)), total=max_value // inc))

Which generates the following output:

CPU thread count: 32
Parallel workers pool size: 8
File batch size: 10
100%|██████████| 10000/10000 [26:14<00:00,  6.35it/s]
CPU times: user 6min 11s, sys: 1min 55s, total: 8min 7s
Wall time: 26min 14s

Interestingly, the optimal configuration, found by trial and error, was a smaller file batch size (10 files) combined with 8 parallel workers.

Following concatenation of the group-by tables, a second-stage aggregation is run using DuckDB to group and sort by station name, compute the min and max across the aggregate chunks, and compute the mean by dividing the aggregate sum by the aggregate count.

query = """
        SELECT station,
        MIN(measure_min) AS measure_min,
        MAX(measure_max) AS measure_max,
        SUM(measure_sum) / SUM(measure_count) AS measure_mean
        FROM summary_table
        GROUP BY station
        ORDER BY station
        """
print(duckdb_explain(con, query))

Explanation:

physical_plan
┌───────────────────────────┐
│          ORDER_BY         │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│          ORDERS:          │
│ summary_table.station ASC │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│          station          │
│        measure_min        │
│        measure_max        │
│        measure_mean       │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│       HASH_GROUP_BY       │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│             #0            │
│          min(#1)          │
│          max(#2)          │
│          sum(#3)          │
│          sum(#4)          │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│          station          │
│        measure_min        │
│        measure_max        │
│        measure_sum        │
│       measure_count       │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│        ARROW_SCAN         │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│          station          │
│        measure_min        │
│        measure_max        │
│        measure_sum        │
│       measure_count       │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│           EC: 1           │
└───────────────────────────┘

Execute the query:

%%time
result = con.execute(query)

CPU times: user 773 ms, sys: 152 ms, total: 926 ms
Wall time: 285 ms

Generates the output:

           station  measure_min  measure_max  measure_mean
0             Abha        -43.8        -21.6     18.000158
1          Abidjan        -34.4        -13.7     25.999815
2           Abéché        -33.6        -10.2     29.400132
3            Accra        -34.1        -13.2     26.400023
4      Addis Ababa        -58.0        -22.7     16.000132
..             ...          ...          ...           ...
407       Yinchuan        -53.1        -30.5      9.000168
408         Zagreb        -52.8        -29.2     10.699659
409  Zanzibar City        -34.2        -13.7     26.000115
410         Ürümqi        -52.9        -32.0      7.399789
411          İzmir        -45.3        -21.7     17.899969

[412 rows x 4 columns]

Total elapsed time:

import datetime as dt
import time

# `start` was captured with time.time() before the first-stage run
print(f'Time taken: {dt.timedelta(seconds=time.time()-start)}')

Time taken: 0:26:15.065602
