diff --git a/app/dl/dl.go b/app/dl/dl.go index e603250ea..b50b8d9ed 100644 --- a/app/dl/dl.go +++ b/app/dl/dl.go @@ -45,6 +45,10 @@ type Options struct { // serve Serve bool Port int + + // size filter + MinSize string + MaxSize string } type parser struct { diff --git a/app/dl/iter.go b/app/dl/iter.go index 44503e248..6e3933266 100644 --- a/app/dl/iter.go +++ b/app/dl/iter.go @@ -53,6 +53,10 @@ type iter struct { opts Options delay time.Duration + // size filter + minSize int64 + maxSize int64 + mu *sync.Mutex finished map[int]struct{} fingerprint string @@ -90,6 +94,16 @@ func newIter(pool dcpool.Pool, manager *peers.Manager, dialog [][]*tmessage.Dial includeMap := filterMap.New(opts.Include, fsutil.AddPrefixDot) excludeMap := filterMap.New(opts.Exclude, fsutil.AddPrefixDot) + minSize, err := utils.Byte.ParseBinaryBytes(opts.MinSize) + if err != nil { + return nil, errors.Wrap(err, "parse min size") + } + + maxSize, err := utils.Byte.ParseBinaryBytes(opts.MaxSize) + if err != nil { + return nil, errors.Wrap(err, "parse max size") + } + // to keep fingerprint stable sortDialogs(dialogs, opts.Desc) @@ -102,6 +116,8 @@ func newIter(pool dcpool.Pool, manager *peers.Manager, dialog [][]*tmessage.Dial exclude: excludeMap, tpl: tpl, delay: delay, + minSize: minSize, + maxSize: maxSize, mu: &sync.Mutex{}, finished: make(map[int]struct{}), @@ -216,6 +232,11 @@ func (i *iter) processSingle(ctx context.Context, message *tg.Message, from peer return false, true } + // check size + if i.shouldSkip(ctx, item.Size) { + return false, true + } + // process include and exclude ext := filepath.Ext(item.Name) if _, ok = i.include[ext]; len(i.include) > 0 && !ok { @@ -427,3 +448,23 @@ func fingerprint(dialogs []*tmessage.Dialog) string { return fmt.Sprintf("%x", sha256.Sum256(buf.Bytes())) } + +func (i *iter) shouldSkip(ctx context.Context, size int64) bool { + if i.minSize > 0 && size < i.minSize { + logctx.From(ctx).Debug("Skip file due to min-size limit", + zap.Int64("size", size), + zap.Int64("min_size", i.minSize), + ) + return true + } + + if i.maxSize > 0 && size > i.maxSize { + logctx.From(ctx).Debug("Skip file due to max-size limit", + zap.Int64("size", size), + zap.Int64("max_size", i.maxSize), + ) + return true + } + + return false +} diff --git a/app/dl/iter_test.go b/app/dl/iter_test.go index 97deadfe1..d340b6cfd 100644 --- a/app/dl/iter_test.go +++ b/app/dl/iter_test.go @@ -229,3 +229,34 @@ func TestIterContextCancellation(t *testing.T) { } }) } + +func TestIterShouldSkip(t *testing.T) { + tests := []struct { + name string + minSize int64 + maxSize int64 + size int64 + want bool + }{ + {name: "no limit", minSize: 0, maxSize: 0, size: 100, want: false}, + {name: "min limit - pass", minSize: 50, maxSize: 0, size: 100, want: false}, + {name: "min limit - skip", minSize: 150, maxSize: 0, size: 100, want: true}, + {name: "max limit - pass", minSize: 0, maxSize: 150, size: 100, want: false}, + {name: "max limit - skip", minSize: 0, maxSize: 50, size: 100, want: true}, + {name: "both limits - pass", minSize: 50, maxSize: 150, size: 100, want: false}, + {name: "both limits - skip min", minSize: 150, maxSize: 200, size: 100, want: true}, + {name: "both limits - skip max", minSize: 50, maxSize: 80, size: 100, want: true}, + {name: "exact min - pass", minSize: 100, maxSize: 0, size: 100, want: false}, + {name: "exact max - pass", minSize: 0, maxSize: 100, size: 100, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &iter{ + minSize: tt.minSize, + maxSize: tt.maxSize, + } + assert.Equal(t, tt.want, i.shouldSkip(context.Background(), tt.size)) + }) + } +} diff --git a/cmd/dl.go b/cmd/dl.go index 6540ee8e9..01af9b467 100644 --- a/cmd/dl.go +++ b/cmd/dl.go @@ -61,6 +61,9 @@ func NewDownload() *cobra.Command { cmd.Flags().BoolVar(&opts.Takeout, "takeout", false, "takeout sessions let you export data from your account with lower flood wait limits.") cmd.Flags().BoolVar(&opts.Group, "group", false, "auto detect grouped message and download all of them") + cmd.Flags().StringVar(&opts.MinSize, "min-size", "", "min size of file to download. Example: 10MB, 1GB") + cmd.Flags().StringVar(&opts.MaxSize, "max-size", "", "max size of file to download. Example: 10MB, 1GB") + // resume flags, if both false then ask user cmd.Flags().BoolVar(&opts.Continue, _continue, false, "continue the last download directly") cmd.Flags().BoolVar(&opts.Restart, restart, false, "restart the last download directly") diff --git a/pkg/utils/byte.go b/pkg/utils/byte.go index d1c14c0ad..48eb86ca1 100644 --- a/pkg/utils/byte.go +++ b/pkg/utils/byte.go @@ -1,6 +1,10 @@ package utils -import "fmt" +import ( + "fmt" + "strconv" + "strings" +) type _byte struct{} @@ -21,3 +25,36 @@ func (b _byte) FormatBinaryBytes(n int64) string { } return fmt.Sprintf("%.2f TB", float64(n)/1024/1024/1024/1024) } + +func (b _byte) ParseBinaryBytes(s string) (int64, error) { + if s == "" { + return 0, nil + } + + s = strings.TrimSpace(s) + s = strings.ToUpper(s) + + var multiplier int64 = 1 + if strings.HasSuffix(s, "TB") { + multiplier = 1024 * 1024 * 1024 * 1024 + s = strings.TrimSuffix(s, "TB") + } else if strings.HasSuffix(s, "GB") { + multiplier = 1024 * 1024 * 1024 + s = strings.TrimSuffix(s, "GB") + } else if strings.HasSuffix(s, "MB") { + multiplier = 1024 * 1024 + s = strings.TrimSuffix(s, "MB") + } else if strings.HasSuffix(s, "KB") { + multiplier = 1024 + s = strings.TrimSuffix(s, "KB") + } else if strings.HasSuffix(s, "B") { + s = strings.TrimSuffix(s, "B") + } + + val, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0, err + } + + return int64(val * float64(multiplier)), nil +} diff --git a/pkg/utils/byte_test.go b/pkg/utils/byte_test.go new file mode 100644 index 000000000..8506a711a --- /dev/null +++ b/pkg/utils/byte_test.go @@ -0,0 +1,38 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_byte_ParseBinaryBytes(t *testing.T) { + tests := []struct { + name string + s string + want int64 + wantErr bool + }{ + {name: "bytes", s: "100B", want: 100, wantErr: false}, + {name: "bytes lower", s: "100b", want: 100, wantErr: false}, + {name: "kilobytes", s: "1KB", want: 1024, wantErr: false}, + {name: "kilobytes lower", s: "1kb", want: 1024, wantErr: false}, + {name: "megabytes", s: "10MB", want: 10 * 1024 * 1024, wantErr: false}, + {name: "gigabytes", s: "1.5GB", want: int64(1.5 * 1024 * 1024 * 1024), wantErr: false}, + {name: "raw number", s: "100", want: 100, wantErr: false}, + {name: "invalid unit", s: "100ZB", want: 0, wantErr: true}, + {name: "invalid format", s: "abc", want: 0, wantErr: true}, + {name: "empty", s: "", want: 0, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Byte.ParseBinaryBytes(tt.s) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +}