Skip to content

Commit 5901d2c

Browse files
authored
Fix #230: Modify Barrier::wait to return BarrierWaitResult (#254)
* Fix #230: Modify Barrier::wait to return BarrierWaitResult * Move Barrier implementation to barrier.rs * Create simple test * Bug fix: resolve infinite loop * Handles the case where Barrier::wait is called more than the specified number of times * Fix typo * Fix typo * Verify the algorithm of Barrier::wait with TLA+ * Fix typo * Fix typo * Fix the implementation of Barrier::wait
1 parent 2db0df9 commit 5901d2c

File tree

6 files changed

+237
-43
lines changed

6 files changed

+237
-43
lines changed

applications/tests/test_measure_channel/src/lib.rs

Lines changed: 5 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@ extern crate alloc;
44

55
use alloc::{format, sync::Arc, vec::Vec};
66
use awkernel_async_lib::{
7-
channel::bounded,
8-
pubsub::{self, Attribute, Publisher, Subscriber},
9-
scheduler::SchedulerType,
10-
spawn, uptime_nano,
7+
channel::bounded, scheduler::SchedulerType, spawn, sync::barrier::Barrier, uptime_nano,
118
};
12-
use core::{sync::atomic::AtomicUsize, time::Duration};
9+
use core::time::Duration;
1310
use serde::Serialize;
1411

1512
const NUM_TASKS: [usize; 11] = [1000, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000];
@@ -27,41 +24,6 @@ struct MeasureResult {
2724
average: f64,
2825
}
2926

30-
#[derive(Clone)]
31-
struct Barrier {
32-
count: Arc<AtomicUsize>,
33-
tx: Arc<Publisher<()>>,
34-
rx: Subscriber<()>,
35-
}
36-
37-
impl Barrier {
38-
fn new(count: usize) -> Self {
39-
let attr = Attribute {
40-
queue_size: 1,
41-
..Attribute::default()
42-
};
43-
let (tx, rx) = pubsub::create_pubsub(attr);
44-
45-
Self {
46-
count: Arc::new(AtomicUsize::new(count)),
47-
tx: Arc::new(tx),
48-
rx,
49-
}
50-
}
51-
52-
async fn wait(&mut self) {
53-
if self
54-
.count
55-
.fetch_sub(1, core::sync::atomic::Ordering::Relaxed)
56-
== 1
57-
{
58-
self.tx.send(()).await;
59-
} else {
60-
self.rx.recv().await;
61-
}
62-
}
63-
}
64-
6527
pub async fn run() {
6628
let mut result = alloc::vec::Vec::with_capacity(NUM_TASKS.len());
6729
for num_task in NUM_TASKS.iter() {
@@ -76,15 +38,15 @@ pub async fn run() {
7638
}
7739

7840
async fn measure_task(num_task: usize, num_bytes: usize) -> MeasureResult {
79-
let barrier = Barrier::new(num_task * 2);
41+
let barrier = Arc::new(Barrier::new(num_task * 2));
8042
let mut server_join = alloc::vec::Vec::new();
8143
let mut client_join = alloc::vec::Vec::new();
8244

8345
for i in 0..num_task {
8446
let (tx1, rx1) = bounded::new::<Vec<u8>>(bounded::Attribute::default());
8547
let (tx2, rx2) = bounded::new::<Vec<u8>>(bounded::Attribute::default());
8648

87-
let mut barrier2 = barrier.clone();
49+
let barrier2 = barrier.clone();
8850
let hdl = spawn(
8951
format!("{i}-server").into(),
9052
async move {
@@ -108,7 +70,7 @@ async fn measure_task(num_task: usize, num_bytes: usize) -> MeasureResult {
10870

10971
server_join.push(hdl);
11072

111-
let mut barrier2 = barrier.clone();
73+
let barrier2 = barrier.clone();
11274
let hdl = spawn(
11375
format!("{i}-client").into(),
11476
async move {

awkernel_async_lib/src/sync.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub use awkernel_lib::sync::mutex as raw_mutex;
2+
pub mod barrier;
23
pub mod mutex;
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
use super::mutex::AsyncLock;
2+
use crate::pubsub::{self, Attribute, Publisher, Subscriber};
3+
use alloc::{vec, vec::Vec};
4+
5+
struct BarrierState {
6+
count: usize,
7+
}
8+
9+
/// A barrier enables multiple threads to synchronize the beginning of some computation.
10+
pub struct Barrier {
11+
lock: AsyncLock<BarrierState>,
12+
num_threads: usize,
13+
tx: Publisher<()>,
14+
rxs: Vec<Subscriber<()>>,
15+
}
16+
17+
/// `BarrierWaitResult` is returned by `Barrier::wait` when all threads in `Barrier` have rendezvoused.
18+
pub struct BarrierWaitResult(bool);
19+
20+
impl BarrierWaitResult {
21+
pub fn is_reader(&self) -> bool {
22+
self.0
23+
}
24+
}
25+
26+
impl Barrier {
27+
/// Creates a new barrier that can block a given number of threads.
28+
pub fn new(n: usize) -> Self {
29+
let attr = Attribute {
30+
queue_size: 1,
31+
..Attribute::default()
32+
};
33+
let (tx, rx) = pubsub::create_pubsub(attr);
34+
35+
let mut rxs = vec![rx.clone(); n - 2];
36+
rxs.push(rx);
37+
38+
Self {
39+
lock: AsyncLock::new(BarrierState { count: 0 }),
40+
num_threads: n,
41+
tx,
42+
rxs,
43+
}
44+
}
45+
46+
/// Blocks the current thread until all threads have redezvoused here.
47+
/// A single (arbitrary) thread will receive `BarrierWaitResult(true)` when returning from this function, and other threads will receive `BarrierWaitResult(false)`.
48+
pub async fn wait(&self) -> BarrierWaitResult {
49+
let mut lock = self.lock.lock().await;
50+
let count = lock.count;
51+
if count < (self.num_threads - 1) {
52+
lock.count += 1;
53+
drop(lock);
54+
self.rxs[count].recv().await;
55+
BarrierWaitResult(false)
56+
} else {
57+
lock.count = 0;
58+
drop(lock);
59+
self.tx.send(()).await;
60+
BarrierWaitResult(true)
61+
}
62+
}
63+
}
64+
65+
#[cfg(test)]
66+
mod tests {
67+
use super::*;
68+
use alloc::sync::Arc;
69+
use core::sync::atomic::{AtomicUsize, Ordering};
70+
71+
#[test]
72+
fn test_simple_async_barrier() {
73+
let barrier = Arc::new(Barrier::new(10));
74+
let num_waits = Arc::new(AtomicUsize::new(0));
75+
let num_leaders = Arc::new(AtomicUsize::new(0));
76+
let tasks = crate::mini_task::Tasks::new();
77+
78+
for _ in 0..10 {
79+
let barrier = barrier.clone();
80+
let num_waits = num_waits.clone();
81+
let num_leaders = num_leaders.clone();
82+
let task = async move {
83+
num_waits.fetch_add(1, Ordering::Relaxed);
84+
85+
if barrier.wait().await.is_reader() {
86+
num_leaders.fetch_add(1, Ordering::Relaxed);
87+
}
88+
// Verify that Barrier synchronizes the specified number of threads
89+
assert_eq!(num_waits.load(Ordering::Relaxed), 10);
90+
91+
// It is safe to call Barrier::wait again
92+
barrier.wait().await;
93+
};
94+
95+
tasks.spawn(task);
96+
}
97+
tasks.run();
98+
99+
// Verify that only one thread receives BarrierWaitResult(true)
100+
assert_eq!(num_leaders.load(Ordering::Relaxed), 1);
101+
}
102+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Specification of Barrier implementation
2+
## How to run
3+
4+
1. Download tla2tools (https://github.com/tlaplus/tlaplus/releases)
5+
6+
2. Translate PlusCal to TLA+
7+
```bash
8+
java -cp tla2tools.jar pcal.trans barrier.tla
9+
```
10+
11+
3. Run TLC
12+
```bash
13+
java -jar tla2tools.jar -config barrier.cfg barrier.tla
14+
```
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
SPECIFICATION Spec
2+
\* Add statements after this line.
3+
CONSTANT Threads = {1, 2, 3, 4}
4+
CONSTANT N = 2
5+
INVARIANT BarrierInvariant
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
----------------- MODULE barrier ----------------
2+
EXTENDS TLC, Integers, FiniteSets, Sequences
3+
CONSTANTS Threads, N
4+
ASSUME N \in Nat
5+
ASSUME Threads \in SUBSET Nat
6+
7+
\* It is obvious that a deadlock wil ocuur if this conditon is not satisfied.
8+
ASSUME Cardinality(Threads) % N = 0
9+
10+
(*--algorithm barrier
11+
12+
\* `count` records how many times `wait` has been called.
13+
\* `num_blocked` records the number of blocked threads.
14+
variables
15+
count = 0,
16+
num_blocked = 0,
17+
blocked = FALSE;
18+
19+
\* If `count` < N, then the thread is blocked. otherwise, all blocked threads are awakened.
20+
\* This property implies that Barrier does not block more than N threads.
21+
define
22+
BarrierInvariant == num_blocked = count % N
23+
end define;
24+
25+
\* Note that `wait` is modeled as an atomic operation.
26+
\* Therefore, the implementation needs to use mechanisms such as locks.
27+
procedure wait() begin
28+
Wait:
29+
count := count + 1;
30+
if count % N /= 0 then
31+
num_blocked := num_blocked + 1;
32+
blocked := TRUE;
33+
Await:
34+
await ~blocked;
35+
return;
36+
else
37+
num_blocked := 0;
38+
blocked := FALSE;
39+
return;
40+
end if;
41+
end procedure;
42+
43+
fair process thread \in Threads begin
44+
Body:
45+
call wait();
46+
end process;
47+
48+
end algorithm*)
49+
\* BEGIN TRANSLATION (chksum(pcal) = "78d1002e" /\ chksum(tla) = "8098b806")
50+
VARIABLES pc, count, num_blocked, blocked, stack
51+
52+
(* define statement *)
53+
BarrierInvariant == num_blocked = count % N
54+
55+
56+
vars == << pc, count, num_blocked, blocked, stack >>
57+
58+
ProcSet == (Threads)
59+
60+
Init == (* Global variables *)
61+
/\ count = 0
62+
/\ num_blocked = 0
63+
/\ blocked = FALSE
64+
/\ stack = [self \in ProcSet |-> << >>]
65+
/\ pc = [self \in ProcSet |-> "Body"]
66+
67+
Wait(self) == /\ pc[self] = "Wait"
68+
/\ count' = count + 1
69+
/\ IF count' % N /= 0
70+
THEN /\ num_blocked' = num_blocked + 1
71+
/\ blocked' = TRUE
72+
/\ pc' = [pc EXCEPT ![self] = "Await"]
73+
/\ stack' = stack
74+
ELSE /\ num_blocked' = 0
75+
/\ blocked' = FALSE
76+
/\ pc' = [pc EXCEPT ![self] = Head(stack[self]).pc]
77+
/\ stack' = [stack EXCEPT ![self] = Tail(stack[self])]
78+
79+
Await(self) == /\ pc[self] = "Await"
80+
/\ ~blocked
81+
/\ pc' = [pc EXCEPT ![self] = Head(stack[self]).pc]
82+
/\ stack' = [stack EXCEPT ![self] = Tail(stack[self])]
83+
/\ UNCHANGED << count, num_blocked, blocked >>
84+
85+
wait(self) == Wait(self) \/ Await(self)
86+
87+
Body(self) == /\ pc[self] = "Body"
88+
/\ stack' = [stack EXCEPT ![self] = << [ procedure |-> "wait",
89+
pc |-> "Done" ] >>
90+
\o stack[self]]
91+
/\ pc' = [pc EXCEPT ![self] = "Wait"]
92+
/\ UNCHANGED << count, num_blocked, blocked >>
93+
94+
thread(self) == Body(self)
95+
96+
(* Allow infinite stuttering to prevent deadlock on termination. *)
97+
Terminating == /\ \A self \in ProcSet: pc[self] = "Done"
98+
/\ UNCHANGED vars
99+
100+
Next == (\E self \in ProcSet: wait(self))
101+
\/ (\E self \in Threads: thread(self))
102+
\/ Terminating
103+
104+
Spec == /\ Init /\ [][Next]_vars
105+
/\ \A self \in Threads : WF_vars(thread(self)) /\ WF_vars(wait(self))
106+
107+
Termination == <>(\A self \in ProcSet: pc[self] = "Done")
108+
109+
\* END TRANSLATION
110+
====

0 commit comments

Comments
 (0)