@@ -6,10 +6,10 @@ mod tests;
66
77use std:: path:: Path ;
88
9- use anyhow:: { Context , Result , anyhow, bail} ;
10- use futures_util:: stream:: StreamExt ;
9+ use anyhow:: { Context , Error , Result , anyhow, bail} ;
10+ use futures_util:: { future :: join , stream:: StreamExt } ;
1111use std:: sync:: Arc ;
12- use tokio:: sync:: Semaphore ;
12+ use tokio:: sync:: { Semaphore , mpsc } ;
1313use tracing:: info;
1414
1515use crate :: dist:: component:: {
@@ -153,7 +153,6 @@ impl Manifestation {
153153 let altered = tmp_cx. dist_server != DEFAULT_DIST_SERVER ;
154154
155155 // Download component packages and validate hashes
156- let mut things_to_install: Vec < ( Component , CompressionKind , File ) > = Vec :: new ( ) ;
157156 let mut things_downloaded: Vec < String > = Vec :: new ( ) ;
158157 let components = update. components_urls_and_hashes ( new_manifest) ?;
159158 let components_len = components. len ( ) ;
@@ -172,49 +171,7 @@ impl Manifestation {
172171 . and_then ( |s| s. parse ( ) . ok ( ) )
173172 . unwrap_or ( DEFAULT_MAX_RETRIES ) ;
174173
175- info ! ( "downloading component(s)" ) ;
176- for ( component, _, url, _) in components. clone ( ) {
177- ( download_cfg. notify_handler ) ( Notification :: DownloadingComponent (
178- & component. short_name ( new_manifest) ,
179- & self . target_triple ,
180- component. target . as_ref ( ) ,
181- & url,
182- ) ) ;
183- }
184-
185- let semaphore = Arc :: new ( Semaphore :: new ( concurrent_downloads) ) ;
186- let component_stream =
187- tokio_stream:: iter ( components. into_iter ( ) ) . map ( |( component, format, url, hash) | {
188- let sem = semaphore. clone ( ) ;
189- async move {
190- let _permit = sem. acquire ( ) . await . unwrap ( ) ;
191- self . download_component (
192- component,
193- format,
194- url,
195- hash,
196- altered,
197- tmp_cx,
198- download_cfg,
199- max_retries,
200- new_manifest,
201- )
202- . await
203- }
204- } ) ;
205- if components_len > 0 {
206- let results = component_stream
207- . buffered ( components_len)
208- . collect :: < Vec < _ > > ( )
209- . await ;
210- for result in results {
211- let ( component, format, downloaded_file, hash) = result?;
212- things_downloaded. push ( hash) ;
213- things_to_install. push ( ( component, format, downloaded_file) ) ;
214- }
215- }
216-
217- // Begin transaction
174+ // Begin transaction before the downloads, as installations are interleaved with those
218175 let mut tx = Transaction :: new (
219176 prefix. clone ( ) ,
220177 tmp_cx,
@@ -226,6 +183,16 @@ impl Manifestation {
226183 // to uninstall it first.
227184 tx = self . maybe_handle_v2_upgrade ( & config, tx, download_cfg. process ) ?;
228185
186+ info ! ( "downloading component(s)" ) ;
187+ for ( component, _, url, _) in components. clone ( ) {
188+ ( download_cfg. notify_handler ) ( Notification :: DownloadingComponent (
189+ & component. short_name ( new_manifest) ,
190+ & self . target_triple ,
191+ component. target . as_ref ( ) ,
192+ & url,
193+ ) ) ;
194+ }
195+
229196 // Uninstall components
230197 for component in & update. components_to_uninstall {
231198 let notification = if implicit_modify {
@@ -248,17 +215,79 @@ impl Manifestation {
248215 ) ?;
249216 }
250217
251- // Install components
252- for ( component, format, installer_file) in things_to_install {
253- tx = self . install_component (
254- component,
255- format,
256- installer_file,
257- tmp_cx,
258- download_cfg,
259- new_manifest,
260- tx,
261- ) ?;
218+ if components_len > 0 {
219+ // Create a channel to communicate whenever a download is done and the component can be installed
220+ // The `mpsc` channel was used as we need to send many messages from one producer (download's thread) to one consumer (install's thread)
221+ // This is recommended in the official docs: https://docs.rs/tokio/latest/tokio/sync/index.html#mpsc-channel
222+ let total_components = components. len ( ) ;
223+ let ( download_tx, mut download_rx) =
224+ mpsc:: channel :: < Result < ( Component , CompressionKind , File ) > > ( total_components) ;
225+
226+ let semaphore = Arc :: new ( Semaphore :: new ( concurrent_downloads) ) ;
227+ let component_stream =
228+ tokio_stream:: iter ( components. into_iter ( ) ) . map ( |( component, format, url, hash) | {
229+ let sem = semaphore. clone ( ) ;
230+ let download_tx_cloned = download_tx. clone ( ) ;
231+ async move {
232+ let _permit = sem. acquire ( ) . await . unwrap ( ) ;
233+ self . download_component (
234+ component,
235+ format,
236+ url,
237+ hash,
238+ altered,
239+ tmp_cx,
240+ download_cfg,
241+ max_retries,
242+ new_manifest,
243+ download_tx_cloned,
244+ )
245+ . await
246+ }
247+ } ) ;
248+
249+ let mut stream = component_stream. buffered ( components_len) ;
250+ let ( download_results, install_result) = join (
251+ async {
252+ let mut hashes = Vec :: new ( ) ;
253+ while let Some ( result) = stream. next ( ) . await {
254+ match result {
255+ Ok ( hash) => {
256+ hashes. push ( hash) ;
257+ }
258+ Err ( e) => {
259+ let _ = download_tx. send ( Err ( e) ) . await ;
260+ }
261+ }
262+ }
263+ hashes
264+ } ,
265+ async {
266+ let mut current_tx = tx;
267+ let mut counter = 0 ;
268+ while counter < total_components
269+ && let Some ( message) = download_rx. recv ( ) . await
270+ {
271+ let ( component, format, installer_file) = message?;
272+ let new_tx = self . install_component (
273+ component. clone ( ) ,
274+ format,
275+ installer_file,
276+ tmp_cx,
277+ download_cfg,
278+ new_manifest,
279+ current_tx,
280+ ) ?;
281+ current_tx = new_tx;
282+ counter += 1 ;
283+ }
284+ Ok :: < _ , Error > ( current_tx)
285+ } ,
286+ )
287+ . await ;
288+
289+ things_downloaded = download_results;
290+ tx = install_result?;
262291 }
263292
264293 // Install new distribution manifest
@@ -510,7 +539,8 @@ impl Manifestation {
510539 download_cfg : & DownloadCfg < ' _ > ,
511540 max_retries : usize ,
512541 new_manifest : & Manifest ,
513- ) -> Result < ( Component , CompressionKind , File , String ) > {
542+ notification_tx : mpsc:: Sender < Result < ( Component , CompressionKind , File ) > > ,
543+ ) -> Result < String > {
514544 use tokio_retry:: { RetryIf , strategy:: FixedInterval } ;
515545
516546 let url = if altered {
@@ -539,9 +569,13 @@ impl Manifestation {
539569 . await
540570 . with_context ( || RustupError :: ComponentDownloadFailed ( component. name ( new_manifest) ) ) ?;
541571
542- Ok ( ( component, format, downloaded_file, hash) )
572+ let _ = notification_tx
573+ . send ( Ok ( ( component. clone ( ) , format, downloaded_file) ) )
574+ . await ;
575+ Ok ( hash)
543576 }
544577
578+ #[ allow( clippy:: too_many_arguments) ]
545579 fn install_component < ' a > (
546580 & self ,
547581 component : Component ,
0 commit comments