Skip to content

Commit 8cedf41

Browse files
authored
feat: state management (#29)
1 parent 3e4052f commit 8cedf41

File tree

3 files changed

+161
-24
lines changed

3 files changed

+161
-24
lines changed

crates/bit_rev/src/peer_connection.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::{
2020
peer::PeerAddr,
2121
peer_state::{PeerState, PeerStates},
2222
protocol::{Protocol, ProtocolError},
23-
session::PieceWork,
23+
session::{DownloadState, PieceWork},
2424
utils,
2525
};
2626

@@ -226,6 +226,7 @@ pub struct PeerHandler {
226226
requests_sem: Semaphore,
227227
peer: PeerAddr,
228228
torrent_downloaded_state: Arc<TorrentDownloadedState>,
229+
download_state: Arc<Mutex<DownloadState>>,
229230
}
230231

231232
impl PeerHandler {
@@ -237,6 +238,7 @@ impl PeerHandler {
237238
peers_state: Arc<PeerStates>,
238239
//pieces: Vec<PieceWork>,
239240
torrent_downloaded_state: Arc<TorrentDownloadedState>,
241+
download_state: Arc<Mutex<DownloadState>>,
240242
) -> Self {
241243
Self {
242244
unchoke_notify: unchoked_notify,
@@ -249,6 +251,7 @@ impl PeerHandler {
249251
peer_writer_tx,
250252
peer,
251253
torrent_downloaded_state,
254+
download_state,
252255
//torrent_downloaded_state: Arc::new(TorrentDownloadedState {
253256
//
254257
// semaphore: Semaphore::new(1),
@@ -278,6 +281,14 @@ impl PeerHandler {
278281
}
279282
}
280283

284+
pub fn get_download_state(&self) -> DownloadState {
285+
*self.download_state.lock().unwrap()
286+
}
287+
288+
pub fn is_downloading(&self) -> bool {
289+
self.get_download_state() == DownloadState::Downloading
290+
}
291+
281292
// The job of this is to request chunks and also to keep peer alive.
282293
// The moment this ends, the peer is disconnected.
283294
pub async fn task_peer_chunk_requester(&self) -> Result<(), anyhow::Error> {
@@ -307,6 +318,11 @@ impl PeerHandler {
307318
};
308319

309320
loop {
321+
// Wait while not downloading
322+
while !self.is_downloading() {
323+
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
324+
}
325+
310326
update_interest(self, true)?;
311327

312328
trace!("waiting for unchoke");
@@ -335,6 +351,14 @@ impl PeerHandler {
335351

336352
let mut offset: u32 = 0;
337353
while offset < piece.length {
354+
// Check download state before requesting each block
355+
if !self.is_downloading() {
356+
// Wait while not downloading
357+
while !self.is_downloading() {
358+
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
359+
}
360+
}
361+
338362
loop {
339363
match (tokio::time::timeout(
340364
Duration::from_secs(5),

crates/bit_rev/src/session.rs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::sync::Arc;
1+
use std::sync::{Arc, Mutex};
22

33
use crate::file::{self, TorrentMeta};
44
use crate::peer_state::PeerStates;
@@ -8,6 +8,13 @@ use crate::utils;
88
use dashmap::DashMap;
99
use flume::Receiver;
1010

11+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
12+
pub enum DownloadState {
13+
Init,
14+
Downloading,
15+
Paused,
16+
}
17+
1118
#[derive(Debug, Clone, Copy)]
1219
pub struct PieceWork {
1320
pub index: u32,
@@ -31,6 +38,7 @@ pub struct State {
3138

3239
pub struct Session {
3340
pub streams: DashMap<[u8; 20], TrackerPeers>,
41+
pub download_state: Arc<Mutex<DownloadState>>,
3442
}
3543

3644
pub struct AddTorrentOptions {
@@ -70,9 +78,56 @@ impl Session {
7078
pub fn new() -> Self {
7179
Self {
7280
streams: DashMap::new(),
81+
download_state: Arc::new(Mutex::new(DownloadState::Init)),
82+
}
83+
}
84+
85+
pub fn start_downloading(&self) {
86+
{
87+
let mut state = self.download_state.lock().unwrap();
88+
*state = DownloadState::Downloading;
89+
}
90+
for entry in self.streams.iter() {
91+
entry.value().set_download_state(DownloadState::Downloading);
7392
}
7493
}
7594

95+
pub fn pause(&self) {
96+
{
97+
let mut state = self.download_state.lock().unwrap();
98+
*state = DownloadState::Paused;
99+
}
100+
for entry in self.streams.iter() {
101+
entry.value().set_download_state(DownloadState::Paused);
102+
}
103+
}
104+
105+
pub fn resume(&self) {
106+
{
107+
let mut state = self.download_state.lock().unwrap();
108+
*state = DownloadState::Downloading;
109+
}
110+
for entry in self.streams.iter() {
111+
entry.value().set_download_state(DownloadState::Downloading);
112+
}
113+
}
114+
115+
pub fn get_download_state(&self) -> DownloadState {
116+
*self.download_state.lock().unwrap()
117+
}
118+
119+
pub fn is_paused(&self) -> bool {
120+
self.get_download_state() == DownloadState::Paused
121+
}
122+
123+
pub fn is_downloading(&self) -> bool {
124+
self.get_download_state() == DownloadState::Downloading
125+
}
126+
127+
pub fn is_init(&self) -> bool {
128+
self.get_download_state() == DownloadState::Init
129+
}
130+
76131
pub async fn add_torrent(
77132
&self,
78133
add_torrent: AddTorrentOptions,
@@ -91,6 +146,7 @@ impl Session {
91146
peer_states,
92147
have_broadcast.clone(),
93148
pr_rx.clone(),
149+
self.download_state.clone(),
94150
);
95151

96152
let pieces_of_work = (0..(torrent.piece_hashes.len()) as u64)
@@ -128,6 +184,9 @@ impl Session {
128184
self.streams
129185
.insert(torrent.info_hash, tracker_stream.clone());
130186

187+
// Start downloading
188+
self.start_downloading();
189+
131190
Ok(AddTorrentResult {
132191
torrent,
133192
torrent_meta,

crates/bit_rev/src/tracker_peers.rs

Lines changed: 76 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use serde_bencode::de;
22
use std::sync::{atomic::AtomicBool, Arc, Mutex};
3-
use tokio::{select, sync::Semaphore};
3+
use tokio::{select, sync::Semaphore, time::sleep};
44
use tracing::{debug, error};
55

66
use crate::{
@@ -11,7 +11,7 @@ use crate::{
1111
},
1212
peer_state::PeerStates,
1313
protocol_udp::request_udp_peers,
14-
session::{PieceResult, PieceWork},
14+
session::{DownloadState, PieceResult, PieceWork},
1515
};
1616

1717
#[derive(Debug, Clone)]
@@ -23,6 +23,7 @@ pub struct TrackerPeers {
2323
pub piece_rx: flume::Receiver<FullPiece>,
2424
pub pr_rx: flume::Receiver<PieceResult>,
2525
pub have_broadcast: Arc<tokio::sync::broadcast::Sender<u32>>,
26+
pub download_state: Arc<Mutex<DownloadState>>,
2627
}
2728

2829
impl TrackerPeers {
@@ -33,6 +34,7 @@ impl TrackerPeers {
3334
peer_states: Arc<PeerStates>,
3435
have_broadcast: Arc<tokio::sync::broadcast::Sender<u32>>,
3536
pr_rx: flume::Receiver<PieceResult>,
37+
download_state: Arc<Mutex<DownloadState>>,
3638
) -> TrackerPeers {
3739
let (sender, receiver) = flume::unbounded();
3840
TrackerPeers {
@@ -43,9 +45,31 @@ impl TrackerPeers {
4345
pr_rx,
4446
peer_states,
4547
have_broadcast,
48+
download_state,
4649
}
4750
}
4851

52+
pub fn set_download_state(&self, state: DownloadState) {
53+
let mut current_state = self.download_state.lock().unwrap();
54+
*current_state = state;
55+
}
56+
57+
pub fn get_download_state(&self) -> DownloadState {
58+
*self.download_state.lock().unwrap()
59+
}
60+
61+
pub fn is_paused(&self) -> bool {
62+
self.get_download_state() == DownloadState::Paused
63+
}
64+
65+
pub fn is_downloading(&self) -> bool {
66+
self.get_download_state() == DownloadState::Downloading
67+
}
68+
69+
pub fn is_init(&self) -> bool {
70+
self.get_download_state() == DownloadState::Init
71+
}
72+
4973
pub async fn connect(&self, pieces_of_work: Vec<PieceWork>) {
5074
let info_hash = self.torrent_meta.info_hash;
5175
let peer_id = self.peer_id;
@@ -70,6 +94,7 @@ impl TrackerPeers {
7094
let peer_states = self.peer_states.clone();
7195
let piece_tx = self.piece_tx.clone();
7296
let have_broadcast = self.have_broadcast.clone();
97+
let download_state = self.download_state.clone();
7398
let torrent_downloaded_state = Arc::new(TorrentDownloadedState {
7499
semaphore: Semaphore::new(1),
75100
pieces: pieces_of_work
@@ -84,13 +109,21 @@ impl TrackerPeers {
84109
});
85110
tokio::spawn(async move {
86111
loop {
112+
// Wait while not downloading
113+
while {
114+
let state = *download_state.lock().unwrap();
115+
state != DownloadState::Downloading
116+
} {
117+
sleep(std::time::Duration::from_millis(100)).await;
118+
}
87119
// Handle TCP trackers
88120
for tracker in tcp_trackers.clone() {
89121
let torrent_meta = torrent_meta.clone();
90122
let peer_states = peer_states.clone();
91123
let piece_tx = piece_tx.clone();
92124
let have_broadcast = have_broadcast.clone();
93125
let torrent_downloaded_state = torrent_downloaded_state.clone();
126+
let download_state = download_state.clone();
94127
tokio::spawn(async move {
95128
let url = file::build_tracker_url(&torrent_meta, &peer_id, 6881, &tracker)
96129
.map_err(|e| {
@@ -105,12 +138,16 @@ impl TrackerPeers {
105138
Ok(new_peers) => {
106139
process_peers(
107140
new_peers,
108-
info_hash,
109-
peer_id,
110-
peer_states.clone(),
111-
piece_tx.clone(),
112-
have_broadcast.clone(),
113-
torrent_downloaded_state.clone(),
141+
PeerProcessorConfig {
142+
info_hash,
143+
peer_id,
144+
peer_states: peer_states.clone(),
145+
piece_tx: piece_tx.clone(),
146+
have_broadcast: have_broadcast.clone(),
147+
torrent_downloaded_state: torrent_downloaded_state
148+
.clone(),
149+
download_state: download_state.clone(),
150+
},
114151
)
115152
.await;
116153

@@ -141,6 +178,7 @@ impl TrackerPeers {
141178
let piece_tx = piece_tx.clone();
142179
let have_broadcast = have_broadcast.clone();
143180
let torrent_downloaded_state = torrent_downloaded_state.clone();
181+
let download_state = download_state.clone();
144182
tokio::spawn(async move {
145183
match request_udp_peers(&tracker, &torrent_meta, &peer_id, 6881).await {
146184
Ok(udp_response) => {
@@ -156,12 +194,15 @@ impl TrackerPeers {
156194

157195
process_peers(
158196
new_peers,
159-
info_hash,
160-
peer_id,
161-
peer_states.clone(),
162-
piece_tx.clone(),
163-
have_broadcast.clone(),
164-
torrent_downloaded_state.clone(),
197+
PeerProcessorConfig {
198+
info_hash,
199+
peer_id,
200+
peer_states: peer_states.clone(),
201+
piece_tx: piece_tx.clone(),
202+
have_broadcast: have_broadcast.clone(),
203+
torrent_downloaded_state: torrent_downloaded_state.clone(),
204+
download_state: download_state.clone(),
205+
},
165206
)
166207
.await;
167208

@@ -185,24 +226,36 @@ impl TrackerPeers {
185226
}
186227
}
187228

188-
async fn process_peers(
189-
new_peers: Vec<std::net::SocketAddr>,
229+
struct PeerProcessorConfig {
190230
info_hash: [u8; 20],
191231
peer_id: [u8; 20],
192232
peer_states: Arc<PeerStates>,
193233
piece_tx: flume::Sender<FullPiece>,
194234
have_broadcast: Arc<tokio::sync::broadcast::Sender<u32>>,
195235
torrent_downloaded_state: Arc<TorrentDownloadedState>,
196-
) {
236+
download_state: Arc<Mutex<DownloadState>>,
237+
}
238+
239+
async fn process_peers(new_peers: Vec<std::net::SocketAddr>, config: PeerProcessorConfig) {
240+
let info_hash = config.info_hash;
241+
let peer_id = config.peer_id;
242+
197243
for peer in new_peers {
198-
if peer_states.clone().states.contains_key(&peer) {
244+
// Skip processing new peers if not downloading
245+
let current_state = *config.download_state.lock().unwrap();
246+
if current_state != DownloadState::Downloading {
247+
continue;
248+
}
249+
250+
if config.peer_states.clone().states.contains_key(&peer) {
199251
continue;
200252
}
201253

202-
let piece_tx = piece_tx.clone();
203-
let have_broadcast = have_broadcast.clone();
204-
let torrent_downloaded_state = torrent_downloaded_state.clone();
205-
let peer_states = peer_states.clone();
254+
let piece_tx = config.piece_tx.clone();
255+
let have_broadcast = config.have_broadcast.clone();
256+
let torrent_downloaded_state = config.torrent_downloaded_state.clone();
257+
let peer_states = config.peer_states.clone();
258+
let download_state = config.download_state.clone();
206259

207260
tokio::spawn(async move {
208261
let unchoke_notify = tokio::sync::Notify::new();
@@ -215,6 +268,7 @@ async fn process_peers(
215268
peer_writer_tx.clone(),
216269
peer_states.clone(),
217270
torrent_downloaded_state.clone(),
271+
download_state.clone(),
218272
));
219273

220274
let peer_connection =

0 commit comments

Comments
 (0)