Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion crates/bit_rev/src/peer_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
peer::PeerAddr,
peer_state::{PeerState, PeerStates},
protocol::{Protocol, ProtocolError},
session::PieceWork,
session::{DownloadState, PieceWork},
utils,
};

Expand Down Expand Up @@ -226,6 +226,7 @@ pub struct PeerHandler {
requests_sem: Semaphore,
peer: PeerAddr,
torrent_downloaded_state: Arc<TorrentDownloadedState>,
download_state: Arc<Mutex<DownloadState>>,
}

impl PeerHandler {
Expand All @@ -237,6 +238,7 @@ impl PeerHandler {
peers_state: Arc<PeerStates>,
//pieces: Vec<PieceWork>,
torrent_downloaded_state: Arc<TorrentDownloadedState>,
download_state: Arc<Mutex<DownloadState>>,
) -> Self {
Self {
unchoke_notify: unchoked_notify,
Expand All @@ -249,6 +251,7 @@ impl PeerHandler {
peer_writer_tx,
peer,
torrent_downloaded_state,
download_state,
//torrent_downloaded_state: Arc::new(TorrentDownloadedState {
//
// semaphore: Semaphore::new(1),
Expand Down Expand Up @@ -278,6 +281,14 @@ impl PeerHandler {
}
}

pub fn get_download_state(&self) -> DownloadState {
*self.download_state.lock().unwrap()
}

pub fn is_downloading(&self) -> bool {
self.get_download_state() == DownloadState::Downloading
}

// The job of this is to request chunks and also to keep peer alive.
// The moment this ends, the peer is disconnected.
pub async fn task_peer_chunk_requester(&self) -> Result<(), anyhow::Error> {
Expand Down Expand Up @@ -307,6 +318,11 @@ impl PeerHandler {
};

loop {
// Wait while not downloading
while !self.is_downloading() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}

update_interest(self, true)?;

trace!("waiting for unchoke");
Expand Down Expand Up @@ -335,6 +351,14 @@ impl PeerHandler {

let mut offset: u32 = 0;
while offset < piece.length {
// Check download state before requesting each block
if !self.is_downloading() {
// Wait while not downloading
while !self.is_downloading() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}

loop {
match (tokio::time::timeout(
Duration::from_secs(5),
Expand Down
61 changes: 60 additions & 1 deletion crates/bit_rev/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use crate::file::{self, TorrentMeta};
use crate::peer_state::PeerStates;
Expand All @@ -8,6 +8,13 @@ use crate::utils;
use dashmap::DashMap;
use flume::Receiver;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DownloadState {
Init,
Downloading,
Paused,
}

#[derive(Debug, Clone, Copy)]
pub struct PieceWork {
pub index: u32,
Expand All @@ -31,6 +38,7 @@ pub struct State {

pub struct Session {
pub streams: DashMap<[u8; 20], TrackerPeers>,
pub download_state: Arc<Mutex<DownloadState>>,
}

pub struct AddTorrentOptions {
Expand Down Expand Up @@ -70,9 +78,56 @@ impl Session {
pub fn new() -> Self {
Self {
streams: DashMap::new(),
download_state: Arc::new(Mutex::new(DownloadState::Init)),
}
}

pub fn start_downloading(&self) {
{
let mut state = self.download_state.lock().unwrap();
*state = DownloadState::Downloading;
}
for entry in self.streams.iter() {
entry.value().set_download_state(DownloadState::Downloading);
}
}

pub fn pause(&self) {
{
let mut state = self.download_state.lock().unwrap();
*state = DownloadState::Paused;
}
for entry in self.streams.iter() {
entry.value().set_download_state(DownloadState::Paused);
}
}

pub fn resume(&self) {
{
let mut state = self.download_state.lock().unwrap();
*state = DownloadState::Downloading;
}
for entry in self.streams.iter() {
entry.value().set_download_state(DownloadState::Downloading);
}
}

pub fn get_download_state(&self) -> DownloadState {
*self.download_state.lock().unwrap()
}

pub fn is_paused(&self) -> bool {
self.get_download_state() == DownloadState::Paused
}

pub fn is_downloading(&self) -> bool {
self.get_download_state() == DownloadState::Downloading
}

pub fn is_init(&self) -> bool {
self.get_download_state() == DownloadState::Init
}

pub async fn add_torrent(
&self,
add_torrent: AddTorrentOptions,
Expand All @@ -91,6 +146,7 @@ impl Session {
peer_states,
have_broadcast.clone(),
pr_rx.clone(),
self.download_state.clone(),
);

let pieces_of_work = (0..(torrent.piece_hashes.len()) as u64)
Expand Down Expand Up @@ -128,6 +184,9 @@ impl Session {
self.streams
.insert(torrent.info_hash, tracker_stream.clone());

// Start downloading
self.start_downloading();

Ok(AddTorrentResult {
torrent,
torrent_meta,
Expand Down
98 changes: 76 additions & 22 deletions crates/bit_rev/src/tracker_peers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde_bencode::de;
use std::sync::{atomic::AtomicBool, Arc, Mutex};
use tokio::{select, sync::Semaphore};
use tokio::{select, sync::Semaphore, time::sleep};
use tracing::{debug, error};

use crate::{
Expand All @@ -11,7 +11,7 @@ use crate::{
},
peer_state::PeerStates,
protocol_udp::request_udp_peers,
session::{PieceResult, PieceWork},
session::{DownloadState, PieceResult, PieceWork},
};

#[derive(Debug, Clone)]
Expand All @@ -23,6 +23,7 @@ pub struct TrackerPeers {
pub piece_rx: flume::Receiver<FullPiece>,
pub pr_rx: flume::Receiver<PieceResult>,
pub have_broadcast: Arc<tokio::sync::broadcast::Sender<u32>>,
pub download_state: Arc<Mutex<DownloadState>>,
}

impl TrackerPeers {
Expand All @@ -33,6 +34,7 @@ impl TrackerPeers {
peer_states: Arc<PeerStates>,
have_broadcast: Arc<tokio::sync::broadcast::Sender<u32>>,
pr_rx: flume::Receiver<PieceResult>,
download_state: Arc<Mutex<DownloadState>>,
) -> TrackerPeers {
let (sender, receiver) = flume::unbounded();
TrackerPeers {
Expand All @@ -43,9 +45,31 @@ impl TrackerPeers {
pr_rx,
peer_states,
have_broadcast,
download_state,
}
}

pub fn set_download_state(&self, state: DownloadState) {
let mut current_state = self.download_state.lock().unwrap();
*current_state = state;
}

pub fn get_download_state(&self) -> DownloadState {
*self.download_state.lock().unwrap()
}

pub fn is_paused(&self) -> bool {
self.get_download_state() == DownloadState::Paused
}

pub fn is_downloading(&self) -> bool {
self.get_download_state() == DownloadState::Downloading
}

pub fn is_init(&self) -> bool {
self.get_download_state() == DownloadState::Init
}

pub async fn connect(&self, pieces_of_work: Vec<PieceWork>) {
let info_hash = self.torrent_meta.info_hash;
let peer_id = self.peer_id;
Expand All @@ -70,6 +94,7 @@ impl TrackerPeers {
let peer_states = self.peer_states.clone();
let piece_tx = self.piece_tx.clone();
let have_broadcast = self.have_broadcast.clone();
let download_state = self.download_state.clone();
let torrent_downloaded_state = Arc::new(TorrentDownloadedState {
semaphore: Semaphore::new(1),
pieces: pieces_of_work
Expand All @@ -84,13 +109,21 @@ impl TrackerPeers {
});
tokio::spawn(async move {
loop {
// Wait while not downloading
while {
let state = *download_state.lock().unwrap();
state != DownloadState::Downloading
} {
sleep(std::time::Duration::from_millis(100)).await;
}
// Handle TCP trackers
for tracker in tcp_trackers.clone() {
let torrent_meta = torrent_meta.clone();
let peer_states = peer_states.clone();
let piece_tx = piece_tx.clone();
let have_broadcast = have_broadcast.clone();
let torrent_downloaded_state = torrent_downloaded_state.clone();
let download_state = download_state.clone();
tokio::spawn(async move {
let url = file::build_tracker_url(&torrent_meta, &peer_id, 6881, &tracker)
.map_err(|e| {
Expand All @@ -105,12 +138,16 @@ impl TrackerPeers {
Ok(new_peers) => {
process_peers(
new_peers,
info_hash,
peer_id,
peer_states.clone(),
piece_tx.clone(),
have_broadcast.clone(),
torrent_downloaded_state.clone(),
PeerProcessorConfig {
info_hash,
peer_id,
peer_states: peer_states.clone(),
piece_tx: piece_tx.clone(),
have_broadcast: have_broadcast.clone(),
torrent_downloaded_state: torrent_downloaded_state
.clone(),
download_state: download_state.clone(),
},
)
.await;

Expand Down Expand Up @@ -141,6 +178,7 @@ impl TrackerPeers {
let piece_tx = piece_tx.clone();
let have_broadcast = have_broadcast.clone();
let torrent_downloaded_state = torrent_downloaded_state.clone();
let download_state = download_state.clone();
tokio::spawn(async move {
match request_udp_peers(&tracker, &torrent_meta, &peer_id, 6881).await {
Ok(udp_response) => {
Expand All @@ -156,12 +194,15 @@ impl TrackerPeers {

process_peers(
new_peers,
info_hash,
peer_id,
peer_states.clone(),
piece_tx.clone(),
have_broadcast.clone(),
torrent_downloaded_state.clone(),
PeerProcessorConfig {
info_hash,
peer_id,
peer_states: peer_states.clone(),
piece_tx: piece_tx.clone(),
have_broadcast: have_broadcast.clone(),
torrent_downloaded_state: torrent_downloaded_state.clone(),
download_state: download_state.clone(),
},
)
.await;

Expand All @@ -185,24 +226,36 @@ impl TrackerPeers {
}
}

async fn process_peers(
new_peers: Vec<std::net::SocketAddr>,
struct PeerProcessorConfig {
info_hash: [u8; 20],
peer_id: [u8; 20],
peer_states: Arc<PeerStates>,
piece_tx: flume::Sender<FullPiece>,
have_broadcast: Arc<tokio::sync::broadcast::Sender<u32>>,
torrent_downloaded_state: Arc<TorrentDownloadedState>,
) {
download_state: Arc<Mutex<DownloadState>>,
}

async fn process_peers(new_peers: Vec<std::net::SocketAddr>, config: PeerProcessorConfig) {
let info_hash = config.info_hash;
let peer_id = config.peer_id;

for peer in new_peers {
if peer_states.clone().states.contains_key(&peer) {
// Skip processing new peers if not downloading
let current_state = *config.download_state.lock().unwrap();
if current_state != DownloadState::Downloading {
continue;
}

if config.peer_states.clone().states.contains_key(&peer) {
continue;
}

let piece_tx = piece_tx.clone();
let have_broadcast = have_broadcast.clone();
let torrent_downloaded_state = torrent_downloaded_state.clone();
let peer_states = peer_states.clone();
let piece_tx = config.piece_tx.clone();
let have_broadcast = config.have_broadcast.clone();
let torrent_downloaded_state = config.torrent_downloaded_state.clone();
let peer_states = config.peer_states.clone();
let download_state = config.download_state.clone();

tokio::spawn(async move {
let unchoke_notify = tokio::sync::Notify::new();
Expand All @@ -215,6 +268,7 @@ async fn process_peers(
peer_writer_tx.clone(),
peer_states.clone(),
torrent_downloaded_state.clone(),
download_state.clone(),
));

let peer_connection =
Expand Down