diff --git a/Cargo.lock b/Cargo.lock index bbdbcb3..dfe0b73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -230,6 +230,7 @@ checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" name = "cli" version = "0.1.0" dependencies = [ + "anyhow", "bit_rev", "console-subscriber", "flume", diff --git a/crates/bit_rev/src/session.rs b/crates/bit_rev/src/session.rs index 166d177..b7f5e60 100644 --- a/crates/bit_rev/src/session.rs +++ b/crates/bit_rev/src/session.rs @@ -1,8 +1,11 @@ use std::sync::Arc; +use crate::file::{self, TorrentMeta}; +use crate::peer_state::PeerStates; use crate::torrent::Torrent; use crate::tracker_peers::TrackerPeers; use crate::utils; +use dashmap::DashMap; use flume::Receiver; #[derive(Debug, Clone, Copy)] @@ -27,19 +30,68 @@ pub struct State { } pub struct Session { - pub tracker_stream: TrackerPeers, + pub streams: DashMap<[u8; 20], TrackerPeers>, +} + +pub struct AddTorrentOptions { + torrent_meta: TorrentMeta, +} + +impl AddTorrentOptions { + fn from_meta(torrent_meta: TorrentMeta) -> Self { + Self { torrent_meta } + } + + fn from_path(path: &str) -> Self { + let torrent_meta = file::from_filename(path).unwrap(); + Self { torrent_meta } + } +} + +impl From for AddTorrentOptions { + fn from(torrent_meta: TorrentMeta) -> Self { + Self::from_meta(torrent_meta) + } +} + +impl From<&str> for AddTorrentOptions { + fn from(path: &str) -> Self { + Self::from_path(path) + } +} + +pub struct AddTorrentResult { + pub torrent: Torrent, + pub torrent_meta: TorrentMeta, pub pr_rx: Receiver, } impl Session { - pub async fn download_torrent( - torrent: Torrent, - tracker_stream: TrackerPeers, - have_broadcast: Arc>, - ) -> Self { - let piece_rx = tracker_stream.piece_rx.clone(); + pub fn new() -> Self { + Self { + streams: DashMap::new(), + } + } + + pub async fn add_torrent( + &self, + add_torrent: AddTorrentOptions, + ) -> anyhow::Result { + let torrent = Torrent::new(&add_torrent.torrent_meta.clone()); + let torrent_meta = add_torrent.torrent_meta.clone(); let (pr_tx, pr_rx) = flume::bounded::(torrent.piece_hashes.len()); - //let (pr_tx, pr_rx) = flume::unbounded::(); + let have_broadcast = Arc::new(tokio::sync::broadcast::channel(128).0); + let peer_states = Arc::new(PeerStates::default()); + let random_peers = utils::generate_peer_id(); + + let tracker_stream = TrackerPeers::new( + torrent_meta.clone(), + 15, + random_peers, + peer_states, + have_broadcast.clone(), + pr_rx.clone(), + ); let pieces_of_work = (0..(torrent.piece_hashes.len()) as u64) .map(|index| { @@ -55,6 +107,7 @@ impl Session { tracker_stream.connect(pieces_of_work).await; let have_broadcast = have_broadcast.clone(); + let piece_rx = tracker_stream.piece_rx.clone(); tokio::spawn(async move { loop { @@ -72,9 +125,19 @@ impl Session { } }); - Self { - tracker_stream, + self.streams + .insert(torrent.info_hash, tracker_stream.clone()); + + Ok(AddTorrentResult { + torrent, + torrent_meta, pr_rx, - } + }) + } +} + +impl Default for Session { + fn default() -> Self { + Self::new() } } diff --git a/crates/bit_rev/src/tracker_peers.rs b/crates/bit_rev/src/tracker_peers.rs index c2ed9af..8ff12d4 100644 --- a/crates/bit_rev/src/tracker_peers.rs +++ b/crates/bit_rev/src/tracker_peers.rs @@ -11,7 +11,7 @@ use crate::{ }, peer_state::PeerStates, protocol_udp::request_udp_peers, - session::PieceWork, + session::{PieceResult, PieceWork}, }; #[derive(Debug, Clone)] @@ -21,6 +21,7 @@ pub struct TrackerPeers { pub peer_states: Arc, pub piece_tx: flume::Sender, pub piece_rx: flume::Receiver, + pub pr_rx: flume::Receiver, pub have_broadcast: Arc>, } @@ -31,6 +32,7 @@ impl TrackerPeers { peer_id: [u8; 20], peer_states: Arc, have_broadcast: Arc>, + pr_rx: flume::Receiver, ) -> TrackerPeers { let (sender, receiver) = flume::unbounded(); TrackerPeers { @@ -38,6 +40,7 @@ impl TrackerPeers { peer_id, piece_tx: sender, piece_rx: receiver, + pr_rx, peer_states, have_broadcast, } diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index a77dac3..cba89ce 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -12,6 +12,7 @@ default = [] tokio-console = ["console-subscriber"] [dependencies] +anyhow.workspace = true tokio.workspace = true bit_rev.workspace = true indicatif.workspace = true diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 87bee28..42ddb75 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -10,13 +10,7 @@ use tokio::{ }; use tracing::trace; -use bit_rev::{ - file::{self, TorrentMeta}, - session::Session, - torrent::Torrent, - tracker_peers::TrackerPeers, - utils, -}; +use bit_rev::{session::Session, utils}; #[tokio::main] async fn main() { @@ -29,34 +23,17 @@ async fn main() { let filename = std::env::args().nth(1).expect("No torrent path given"); let output = std::env::args().nth(2); - let torrent_meta = file::from_filename(&filename).unwrap(); - - download_file(torrent_meta, output).await + if let Err(err) = download_file(&filename, output).await { + eprintln!("Error: {:?}", err); + } } -pub async fn download_file(torrent_meta: TorrentMeta, out_file: Option) { - let random_peers = utils::generate_peer_id(); +pub async fn download_file(filename: &str, out_file: Option) -> anyhow::Result<()> { + let session = Session::new(); - let torrent = Torrent::new(&torrent_meta.clone()); - - let peer_states = Arc::new(bit_rev::peer_state::PeerStates::default()); - let (have_broadcast, _) = tokio::sync::broadcast::channel(128); - let have_broadcast = Arc::new(have_broadcast); - - //TODO: move it to a download manager state - let tracker_stream = TrackerPeers::new( - torrent_meta.clone(), - 15, - random_peers, - peer_states, - have_broadcast.clone(), - ); - - //TODO: I think this is really bad - - //TODO: return more than just the buffer - let downloader = - Session::download_torrent(torrent.clone(), tracker_stream.clone(), have_broadcast).await; + let add_torrent_result = session.add_torrent(filename.into()).await?; + let torrent = add_torrent_result.torrent.clone(); + let torrent_meta = add_torrent_result.torrent_meta; let total_size = torrent.length as u64; let pb = ProgressBar::new(total_size); @@ -74,7 +51,7 @@ pub async fn download_file(torrent_meta: TorrentMeta, out_file: Option) Some(name) => name, None => torrent_meta.clone().torrent_file.info.name.clone(), }; - let mut file = File::create(out_filename).await.unwrap(); + let mut file = File::create(out_filename).await?; // File let total_downloaded = Arc::new(AtomicU64::new(0)); @@ -91,7 +68,7 @@ pub async fn download_file(torrent_meta: TorrentMeta, out_file: Option) let mut hashset = std::collections::HashSet::new(); while hashset.len() < torrent.piece_hashes.len() { - let pr = downloader.pr_rx.recv_async().await.unwrap(); + let pr = add_torrent_result.pr_rx.recv_async().await?; hashset.insert(pr.index); let (start, end) = utils::calculate_bounds_for_piece(&torrent, pr.index as usize); @@ -102,11 +79,13 @@ pub async fn download_file(torrent_meta: TorrentMeta, out_file: Option) end, pr.length ); - file.seek(SeekFrom::Start(start as u64)).await.unwrap(); - file.write_all(pr.buf.as_slice()).await.unwrap(); + file.seek(SeekFrom::Start(start as u64)).await?; + file.write_all(pr.buf.as_slice()).await?; total_downloaded.fetch_add(pr.length as u64, std::sync::atomic::Ordering::Relaxed); } - file.sync_all().await.unwrap() + file.sync_all().await?; + + Ok(()) }