From 0b69d799a1f5cb336ec25768088fc2cf61feec53 Mon Sep 17 00:00:00 2001 From: Felipe Cardozo Date: Tue, 30 Sep 2025 01:22:15 -0300 Subject: [PATCH] feat: state management --- crates/bit_rev/src/peer_connection.rs | 26 ++++++- crates/bit_rev/src/session.rs | 61 ++++++++++++++++- crates/bit_rev/src/tracker_peers.rs | 98 +++++++++++++++++++++------ 3 files changed, 161 insertions(+), 24 deletions(-) diff --git a/crates/bit_rev/src/peer_connection.rs b/crates/bit_rev/src/peer_connection.rs index 978a332..40c1a83 100644 --- a/crates/bit_rev/src/peer_connection.rs +++ b/crates/bit_rev/src/peer_connection.rs @@ -20,7 +20,7 @@ use crate::{ peer::PeerAddr, peer_state::{PeerState, PeerStates}, protocol::{Protocol, ProtocolError}, - session::PieceWork, + session::{DownloadState, PieceWork}, utils, }; @@ -226,6 +226,7 @@ pub struct PeerHandler { requests_sem: Semaphore, peer: PeerAddr, torrent_downloaded_state: Arc, + download_state: Arc>, } impl PeerHandler { @@ -237,6 +238,7 @@ impl PeerHandler { peers_state: Arc, //pieces: Vec, torrent_downloaded_state: Arc, + download_state: Arc>, ) -> Self { Self { unchoke_notify: unchoked_notify, @@ -249,6 +251,7 @@ impl PeerHandler { peer_writer_tx, peer, torrent_downloaded_state, + download_state, //torrent_downloaded_state: Arc::new(TorrentDownloadedState { // // semaphore: Semaphore::new(1), @@ -278,6 +281,14 @@ impl PeerHandler { } } + pub fn get_download_state(&self) -> DownloadState { + *self.download_state.lock().unwrap() + } + + pub fn is_downloading(&self) -> bool { + self.get_download_state() == DownloadState::Downloading + } + // The job of this is to request chunks and also to keep peer alive. // The moment this ends, the peer is disconnected. pub async fn task_peer_chunk_requester(&self) -> Result<(), anyhow::Error> { @@ -307,6 +318,11 @@ impl PeerHandler { }; loop { + // Wait while not downloading + while !self.is_downloading() { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + update_interest(self, true)?; trace!("waiting for unchoke"); @@ -335,6 +351,14 @@ impl PeerHandler { let mut offset: u32 = 0; while offset < piece.length { + // Check download state before requesting each block + if !self.is_downloading() { + // Wait while not downloading + while !self.is_downloading() { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + } + loop { match (tokio::time::timeout( Duration::from_secs(5), diff --git a/crates/bit_rev/src/session.rs b/crates/bit_rev/src/session.rs index 4a7f065..9327ad3 100644 --- a/crates/bit_rev/src/session.rs +++ b/crates/bit_rev/src/session.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use crate::file::{self, TorrentMeta}; use crate::peer_state::PeerStates; @@ -8,6 +8,13 @@ use crate::utils; use dashmap::DashMap; use flume::Receiver; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DownloadState { + Init, + Downloading, + Paused, +} + #[derive(Debug, Clone, Copy)] pub struct PieceWork { pub index: u32, @@ -31,6 +38,7 @@ pub struct State { pub struct Session { pub streams: DashMap<[u8; 20], TrackerPeers>, + pub download_state: Arc>, } pub struct AddTorrentOptions { @@ -70,9 +78,56 @@ impl Session { pub fn new() -> Self { Self { streams: DashMap::new(), + download_state: Arc::new(Mutex::new(DownloadState::Init)), + } + } + + pub fn start_downloading(&self) { + { + let mut state = self.download_state.lock().unwrap(); + *state = DownloadState::Downloading; + } + for entry in self.streams.iter() { + entry.value().set_download_state(DownloadState::Downloading); } } + pub fn pause(&self) { + { + let mut state = self.download_state.lock().unwrap(); + *state = DownloadState::Paused; + } + for entry in self.streams.iter() { + entry.value().set_download_state(DownloadState::Paused); + } + } + + pub fn resume(&self) { + { + let mut state = self.download_state.lock().unwrap(); + *state = DownloadState::Downloading; + } + for entry in self.streams.iter() { + entry.value().set_download_state(DownloadState::Downloading); + } + } + + pub fn get_download_state(&self) -> DownloadState { + *self.download_state.lock().unwrap() + } + + pub fn is_paused(&self) -> bool { + self.get_download_state() == DownloadState::Paused + } + + pub fn is_downloading(&self) -> bool { + self.get_download_state() == DownloadState::Downloading + } + + pub fn is_init(&self) -> bool { + self.get_download_state() == DownloadState::Init + } + pub async fn add_torrent( &self, add_torrent: AddTorrentOptions, @@ -91,6 +146,7 @@ impl Session { peer_states, have_broadcast.clone(), pr_rx.clone(), + self.download_state.clone(), ); let pieces_of_work = (0..(torrent.piece_hashes.len()) as u64) @@ -128,6 +184,9 @@ impl Session { self.streams .insert(torrent.info_hash, tracker_stream.clone()); + // Start downloading + self.start_downloading(); + Ok(AddTorrentResult { torrent, torrent_meta, diff --git a/crates/bit_rev/src/tracker_peers.rs b/crates/bit_rev/src/tracker_peers.rs index e291b34..e63e40a 100644 --- a/crates/bit_rev/src/tracker_peers.rs +++ b/crates/bit_rev/src/tracker_peers.rs @@ -1,6 +1,6 @@ use serde_bencode::de; use std::sync::{atomic::AtomicBool, Arc, Mutex}; -use tokio::{select, sync::Semaphore}; +use tokio::{select, sync::Semaphore, time::sleep}; use tracing::{debug, error}; use crate::{ @@ -11,7 +11,7 @@ use crate::{ }, peer_state::PeerStates, protocol_udp::request_udp_peers, - session::{PieceResult, PieceWork}, + session::{DownloadState, PieceResult, PieceWork}, }; #[derive(Debug, Clone)] @@ -23,6 +23,7 @@ pub struct TrackerPeers { pub piece_rx: flume::Receiver, pub pr_rx: flume::Receiver, pub have_broadcast: Arc>, + pub download_state: Arc>, } impl TrackerPeers { @@ -33,6 +34,7 @@ impl TrackerPeers { peer_states: Arc, have_broadcast: Arc>, pr_rx: flume::Receiver, + download_state: Arc>, ) -> TrackerPeers { let (sender, receiver) = flume::unbounded(); TrackerPeers { @@ -43,9 +45,31 @@ impl TrackerPeers { pr_rx, peer_states, have_broadcast, + download_state, } } + pub fn set_download_state(&self, state: DownloadState) { + let mut current_state = self.download_state.lock().unwrap(); + *current_state = state; + } + + pub fn get_download_state(&self) -> DownloadState { + *self.download_state.lock().unwrap() + } + + pub fn is_paused(&self) -> bool { + self.get_download_state() == DownloadState::Paused + } + + pub fn is_downloading(&self) -> bool { + self.get_download_state() == DownloadState::Downloading + } + + pub fn is_init(&self) -> bool { + self.get_download_state() == DownloadState::Init + } + pub async fn connect(&self, pieces_of_work: Vec) { let info_hash = self.torrent_meta.info_hash; let peer_id = self.peer_id; @@ -70,6 +94,7 @@ impl TrackerPeers { let peer_states = self.peer_states.clone(); let piece_tx = self.piece_tx.clone(); let have_broadcast = self.have_broadcast.clone(); + let download_state = self.download_state.clone(); let torrent_downloaded_state = Arc::new(TorrentDownloadedState { semaphore: Semaphore::new(1), pieces: pieces_of_work @@ -84,6 +109,13 @@ impl TrackerPeers { }); tokio::spawn(async move { loop { + // Wait while not downloading + while { + let state = *download_state.lock().unwrap(); + state != DownloadState::Downloading + } { + sleep(std::time::Duration::from_millis(100)).await; + } // Handle TCP trackers for tracker in tcp_trackers.clone() { let torrent_meta = torrent_meta.clone(); @@ -91,6 +123,7 @@ impl TrackerPeers { let piece_tx = piece_tx.clone(); let have_broadcast = have_broadcast.clone(); let torrent_downloaded_state = torrent_downloaded_state.clone(); + let download_state = download_state.clone(); tokio::spawn(async move { let url = file::build_tracker_url(&torrent_meta, &peer_id, 6881, &tracker) .map_err(|e| { @@ -105,12 +138,16 @@ impl TrackerPeers { Ok(new_peers) => { process_peers( new_peers, - info_hash, - peer_id, - peer_states.clone(), - piece_tx.clone(), - have_broadcast.clone(), - torrent_downloaded_state.clone(), + PeerProcessorConfig { + info_hash, + peer_id, + peer_states: peer_states.clone(), + piece_tx: piece_tx.clone(), + have_broadcast: have_broadcast.clone(), + torrent_downloaded_state: torrent_downloaded_state + .clone(), + download_state: download_state.clone(), + }, ) .await; @@ -141,6 +178,7 @@ impl TrackerPeers { let piece_tx = piece_tx.clone(); let have_broadcast = have_broadcast.clone(); let torrent_downloaded_state = torrent_downloaded_state.clone(); + let download_state = download_state.clone(); tokio::spawn(async move { match request_udp_peers(&tracker, &torrent_meta, &peer_id, 6881).await { Ok(udp_response) => { @@ -156,12 +194,15 @@ impl TrackerPeers { process_peers( new_peers, - info_hash, - peer_id, - peer_states.clone(), - piece_tx.clone(), - have_broadcast.clone(), - torrent_downloaded_state.clone(), + PeerProcessorConfig { + info_hash, + peer_id, + peer_states: peer_states.clone(), + piece_tx: piece_tx.clone(), + have_broadcast: have_broadcast.clone(), + torrent_downloaded_state: torrent_downloaded_state.clone(), + download_state: download_state.clone(), + }, ) .await; @@ -185,24 +226,36 @@ impl TrackerPeers { } } -async fn process_peers( - new_peers: Vec, +struct PeerProcessorConfig { info_hash: [u8; 20], peer_id: [u8; 20], peer_states: Arc, piece_tx: flume::Sender, have_broadcast: Arc>, torrent_downloaded_state: Arc, -) { + download_state: Arc>, +} + +async fn process_peers(new_peers: Vec, config: PeerProcessorConfig) { + let info_hash = config.info_hash; + let peer_id = config.peer_id; + for peer in new_peers { - if peer_states.clone().states.contains_key(&peer) { + // Skip processing new peers if not downloading + let current_state = *config.download_state.lock().unwrap(); + if current_state != DownloadState::Downloading { + continue; + } + + if config.peer_states.clone().states.contains_key(&peer) { continue; } - let piece_tx = piece_tx.clone(); - let have_broadcast = have_broadcast.clone(); - let torrent_downloaded_state = torrent_downloaded_state.clone(); - let peer_states = peer_states.clone(); + let piece_tx = config.piece_tx.clone(); + let have_broadcast = config.have_broadcast.clone(); + let torrent_downloaded_state = config.torrent_downloaded_state.clone(); + let peer_states = config.peer_states.clone(); + let download_state = config.download_state.clone(); tokio::spawn(async move { let unchoke_notify = tokio::sync::Notify::new(); @@ -215,6 +268,7 @@ async fn process_peers( peer_writer_tx.clone(), peer_states.clone(), torrent_downloaded_state.clone(), + download_state.clone(), )); let peer_connection =