Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 35 additions & 225 deletions internal/fetch/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,38 @@ import (
"fmt"
"io"
"strconv"
"sync"
"time"

"github.com/ryanfowler/fetch/internal/core"
"github.com/ryanfowler/fetch/internal/progress"
)

// progressBar is a wrapper around an io.Reader that displays a progress bar
// to stderr. When reading is complete, the Close method MUST be called.
// progressBar wraps a progress.Bar with fetch-specific close behavior
// (native progress emission and download summary).
type progressBar struct {
r io.Reader
printer *core.Printer
bytesRead int64
totalBytes int64
chRead chan int64
start time.Time
wg sync.WaitGroup
bar *progress.Bar
printer *core.Printer
}

func newProgressBar(r io.Reader, p *core.Printer, totalBytes int64) *progressBar {
pr := &progressBar{
r: r,
printer: p,
totalBytes: totalBytes,
chRead: make(chan int64, 1),
start: time.Now(),
var onRender func(int64)
if core.IsStdoutTerm {
onRender = func(pct int64) {
emitProgress(1, int(pct), p)
}
}
return &progressBar{
bar: progress.NewBar(r, p, totalBytes, onRender),
printer: p,
}
pr.wg.Add(1)
go pr.renderLoop()
return pr
}

func (pb *progressBar) Read(p []byte) (int, error) {
n, err := pb.r.Read(p)
if n > 0 {
pb.chRead <- int64(n)
}
return n, err
return pb.bar.Read(p)
}

func (pb *progressBar) Close(path string, err error) {
// Close the reader channel and wait for the loop to exit.
close(pb.chRead)
pb.wg.Wait()
bytesRead, elapsed := pb.bar.Stop()

p := pb.printer

Expand All @@ -60,122 +49,37 @@ func (pb *progressBar) Close(path string, err error) {
p.WriteString("\n\n")
} else {
// Replace the progress bar with a summary.
writeFinalProgress(p, pb.bytesRead, time.Since(pb.start), 32, path)
writeFinalProgress(p, bytesRead, elapsed, 32, path)
}
p.Flush()
}

func (pb *progressBar) renderLoop() {
defer pb.wg.Done()

lastUpdateTime := pb.start
var chTimeout <-chan time.Time
for {
select {
case <-chTimeout:
chTimeout = nil
case n, ok := <-pb.chRead:
if !ok {
// Reader channel has been closed, exit.
pb.render()
return
}
pb.bytesRead += n

if chTimeout != nil {
// We're waiting on a timeout to re-render.
continue
}

// Check if enough time has passed since the last
// render. If not, set a timeout and continue.
now := time.Now()
dur := lastUpdateTime.Add(100 * time.Millisecond).Sub(now)
if dur > 0 {
chTimeout = time.After(dur)
continue
}
lastUpdateTime = now
}

pb.render()
}
// progressSpinner wraps a progress.Spinner with fetch-specific close behavior
// (native progress emission and download summary).
type progressSpinner struct {
spinner *progress.Spinner
printer *core.Printer
}

func (pb *progressBar) render() {
const barWidth = 30
percentage := pb.bytesRead * 100 / pb.totalBytes
completedWidth := min(barWidth*percentage/100, barWidth)

p := pb.printer

// Render native progress bar.
func newProgressSpinner(r io.Reader, p *core.Printer) *progressSpinner {
var onStart func()
if core.IsStdoutTerm {
emitProgress(1, int(percentage), p)
}

p.WriteString("\r")

p.Set(core.Bold)
p.WriteString("[")
p.Set(core.Green)
for range completedWidth {
p.WriteString("=")
}
p.Reset()
for range barWidth - completedWidth {
p.WriteString(" ")
}
p.Set(core.Bold)
p.WriteString("] ")

pctStr := strconv.FormatInt(percentage, 10)
for i := len(pctStr); i < 3; i++ {
p.WriteString(" ")
onStart = func() {
emitProgress(3, 0, p)
}
}
p.WriteString(pctStr)
p.WriteString("%")
p.Reset()

p.WriteString(" (")
size := formatSize(pb.bytesRead)
for range 7 - len(size) {
p.WriteString(" ")
return &progressSpinner{
spinner: progress.NewSpinner(r, p, onStart),
printer: p,
}
p.WriteString(size)
p.WriteString(" / ")
p.WriteString(formatSize(pb.totalBytes))
p.WriteString(")")
p.Flush()
}

// progressSpinner is a wrapper around an io.Reader that displays a progress
// spinner to stderr. When reading is complete, the Close method MUST be called.
type progressSpinner struct {
r io.Reader
printer *core.Printer
bytesRead int64
chRead chan int64
position int64
wg sync.WaitGroup
start time.Time
}

func newProgressSpinner(r io.Reader, p *core.Printer) *progressSpinner {
ps := &progressSpinner{
r: r,
printer: p,
chRead: make(chan int64, 1),
start: time.Now(),
}
ps.wg.Add(1)
go ps.renderLoop()
return ps
func (ps *progressSpinner) Read(p []byte) (int, error) {
return ps.spinner.Read(p)
}

func (ps *progressSpinner) Close(path string, err error) {
close(ps.chRead)
ps.wg.Wait()
bytesRead, elapsed := ps.spinner.Stop()

p := ps.printer

Expand All @@ -188,86 +92,11 @@ func (ps *progressSpinner) Close(path string, err error) {
p.WriteString("\n\n")
} else {
// Replace the progress spinner with a summary.
writeFinalProgress(p, ps.bytesRead, time.Since(ps.start), 20, path)
writeFinalProgress(p, bytesRead, elapsed, 20, path)
}
p.Flush()
}

func (ps *progressSpinner) Read(p []byte) (int, error) {
n, err := ps.r.Read(p)
if n > 0 {
ps.chRead <- int64(n)
}
return n, err
}

func (ps *progressSpinner) renderLoop() {
defer ps.wg.Done()

// Render native progress bar.
if core.IsStdoutTerm {
emitProgress(3, 0, ps.printer)
}

ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
ps.render()
ps.position++
case n, ok := <-ps.chRead:
if !ok {
// Reader channel has been closed, exit.
ps.render()
return
}
ps.bytesRead += n
}
}
}

func (ps *progressSpinner) render() {
const width = 20

var value string
var offset int
position := ps.position % (width * 2)
if position < width {
value = "=>"
offset = int(position)
} else {
value = "<="
offset = int(width*2 - position - 1)
}

p := ps.printer
p.WriteString("\r")
p.Set(core.Bold)
p.WriteString("[")
for range offset {
p.WriteString(" ")
}
p.Set(core.Green)
p.WriteString(value)
p.Reset()
for range width - offset - 1 {
p.WriteString(" ")
}
p.Set(core.Bold)
p.WriteString("]")
p.Reset()

p.WriteString(" ")
size := formatSize(ps.bytesRead)
for range 7 - len(size) {
p.WriteString(" ")
}
p.WriteString(size)

p.Flush()
}

type progressStatic struct {
r io.Reader
printer *core.Printer
Expand Down Expand Up @@ -299,25 +128,6 @@ func (ps *progressStatic) Close(path string, err error) {
ps.printer.Flush()
}

// formatSize converts bytes to a human-readable string.
func formatSize(bytes int64) string {
const units = "KMGTPE"
const unit = 1024
if bytes < unit {
return strconv.FormatInt(bytes, 10) + "B"
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= 1000; n /= unit {
div *= unit
exp++
}
value := float64(bytes) / float64(div)
if exp >= len(units) {
return "NaN"
}
return strconv.FormatFloat(value, 'f', 1, 64) + string(units[exp]) + "B"
}

func formatDuration(d time.Duration) string {
switch {
case d < time.Second:
Expand All @@ -338,7 +148,7 @@ func writeFinalProgress(p *core.Printer, bytesRead int64, dur time.Duration, toC

p.WriteString("Downloaded ")
p.Set(core.Bold)
p.WriteString(formatSize(bytesRead))
p.WriteString(progress.FormatSize(bytesRead))
p.Reset()
p.WriteString(" in ")
p.Set(core.Italic)
Expand Down
Loading