
Conversation

@Dan-Flores (Contributor) commented Sep 17, 2025

By passing a non-integer (floating-point) step to torch.arange, a floating point error can sometimes occur, leading to sampling_range_end being included in the clip timestamps. This is an issue because sampling_range_end is then filtered out in _build_all_clips_timestamps, leaving an empty list that a policy may operate on.

To avoid this, this PR drops the last value of clip_start_seconds when a floating point error occurs, and sets the final value to sampling_range_end.

The updated approach checks whether the last value of clip_start_seconds is greater than or equal to sampling_range_end. If it is, a boolean mask is applied to keep only the values strictly less than sampling_range_end.
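To make the failure mode concrete, here is a minimal sketch; the numeric values are made up for illustration and are not taken from this PR:

import torch

# Hypothetical values: with a non-integer step, accumulated floating point
# error can make the last value returned by torch.arange land on (or past)
# the endpoint that is documented as excluded.
sampling_range_start = 0.0
sampling_range_end = 9.3
seconds_between_clip_starts = 0.3

clip_start_seconds = torch.arange(
    sampling_range_start,
    sampling_range_end,  # excluded, in theory
    seconds_between_clip_starts,
)

# Depending on rounding, the last start may be >= sampling_range_end; such a
# start is later filtered out in _build_all_clips_timestamps, leaving an
# empty clip for the policy to operate on.
print(clip_start_seconds[-1].item(), sampling_range_end)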

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 17, 2025
    sampling_range_start,
    sampling_range_end,  # excluded
-   seconds_between_clip_starts,
+   round(seconds_between_clip_starts, 6),
@scotts (Contributor) commented Sep 18, 2025

Okay, wow, this is subtle. Apparently there's also a warning in the docs: https://docs.pytorch.org/docs/2.8/generated/torch.arange.html

What I'm not clear on is why does the explicit rounding work, and why do we expect it to always work? I'm wondering if it's safer to do something like:

clip_start_seconds = torch.arange(
    sampling_range_start,
    sampling_range_end,  # excluded
    seconds_between_clip_starts,
)

# 1. Is it okay for `clip_start_seconds` to be empty? The code
#    below allows for and checks for it.
# 2. Regardless of 1, we should explain here why we're doing it,
#    with a link to the above docs.
if clip_start_seconds and not clip_start_seconds[-1]:
    clip_start_seconds = clip_start_seconds[:-1]

num_clips = len(clip_start_seconds)

I'm genuinely unsure what's the best approach. @Dan-Flores, thoughts?

Edit: note I initially left out the not above.

@Dan-Flores (Contributor, Author) commented:

> why does the explicit rounding work, and why do we expect it to always work?

My thinking was that, with the rounding approach, the floating-point precision errors would not occur, but I can imagine scenarios where this is not true.
In this case, I believe the number of times the floating-point step is used is low enough that the accumulated error in the last clip_start_seconds is not rounded up to 1, but just below 1, so it is not filtered out of the clip's timestamps.

> Is it okay for clip_start_seconds to be empty?

I don't think clip_start_seconds should be empty, since we assert that the start and end times are valid in _validate_sampling_range_time_based.

> we should explain here why we're doing it, with a link to the above docs.

Agreed, let's add the documentation here.
Since the docs suggest the small epsilon approach, I think that approach is the most reasonable at this time.
I'm not sure we always want to remove the last item from the clip_start_seconds list, especially when users do not provide a floating-point step.
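For reference, reusing the variable names from the snippet above, the docs' epsilon workaround might look roughly like this (the 1e-6 value is only illustrative, not a tuned choice):

# Sketch only: subtract a tiny epsilon from the exclusive endpoint so that
# floating point rounding inside torch.arange cannot push a value onto it.
eps = 1e-6  # illustrative, not a tuned value
clip_start_seconds = torch.arange(
    sampling_range_start,
    sampling_range_end - eps,  # excluded
    seconds_between_clip_starts,
)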

@scotts (Contributor) commented Sep 18, 2025

> In this case, I believe the number of times the floating-point step is used is low enough that the accumulated error in the last clip_start_seconds is not rounded up to 1, but just below 1, so it is not filtered out of the clip's timestamps.

Ah, got it.

> I'm not sure we always want to remove the last item from the clip_start_seconds list, especially when users do not provide a floating-point step.

Agreed, but that's not what this code does:

if clip_start_seconds and not clip_start_seconds[-1]:
    clip_start_seconds = clip_start_seconds[:-1]

First we check if clip_start_seconds is empty. If it is, we short-circuit out. I think we can probably eliminate that check, though, based on what you said. Then we check not clip_start_seconds[-1]. That will return true if the last element of clip_start_seconds is empty, and false otherwise. We can probably shorten it to:

if not clip_start_seconds[-1]:
    clip_start_seconds = clip_start_seconds[:-1]

Rather than trying to massage the input to torch.arange(), we're instead massaging the output. That is, we explicitly check if the last element of the list is itself an empty list, and if it is, we remove it. It's easier for me to convince myself that will work all the time, rather than trying to tweak the input to avoid the rounding errors in torch.arange().

@Dan-Flores (Contributor, Author) commented Sep 18, 2025

I see, my mistake! We can use this approach, but one slight adjustment is that a value should be removed from the torch.arange output when the last value is equal to sampling_range_end.

if clip_start_seconds[-1] == sampling_range_end:
    clip_start_seconds = clip_start_seconds[:-1]

The conversion to a clip with an empty timestamps list occurs in _build_all_clips_timestamps, but we should remove the bad clip_start_seconds before we calculate num_clips, so at the same location as your code snippet.

Let me update the PR to use this approach.
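Roughly, assuming num_clips is computed from the length of clip_start_seconds as in your snippet, the placement would be (a sketch only, not the exact code):

clip_start_seconds = torch.arange(
    sampling_range_start,
    sampling_range_end,  # excluded
    seconds_between_clip_starts,
)
# Drop a start that landed exactly on sampling_range_end due to floating
# point rounding, before num_clips is derived from the list length.
if clip_start_seconds[-1] == sampling_range_end:
    clip_start_seconds = clip_start_seconds[:-1]

num_clips = len(clip_start_seconds)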

@scotts (Contributor) commented:

@Dan-Flores, in our in-person conversation, we concluded doing exact equality was only safe if we knew for certain that clip_start_seconds[-1] would never be sampling_range_end + epsilon. Looking at the docs, I'm not sure if we can conclude that. The whole note (emphasis mine):

> Note: When using floating-point dtypes (especially reduced precision types like bfloat16), the results may be affected by floating-point rounding behavior. Some values in the sequence might not be exactly representable in certain floating-point formats, which can lead to repeated values or unexpected rounding. For precise sequences, it is recommended to use integer dtypes instead of floating-point dtypes.
>
> Note that non-integer step is subject to floating point rounding errors when comparing against end; to avoid inconsistency, we advise subtracting a small epsilon from end in such cases.

I'm now thinking that we might have to do as they suggest.

@scotts (Contributor) commented:

Which I'm now realizing is more in line with what you originally suggested, but in the other numeric direction. :)

@Dan-Flores Dan-Flores marked this pull request as ready for review September 18, 2025 17:56
@Dan-Flores Dan-Flores force-pushed the fix_sampling_index_error branch from dca4dea to 46a0b5b on September 21, 2025 20:58
    sampling_range_start,
    sampling_range_end,  # excluded
-   seconds_between_clip_starts,
+   seconds_between_clip_starts - 1e-6,
@NicolasHug (Member) commented Sep 22, 2025

Thanks for investigating this and for the fix, @Dan-Flores!

With the current changes, we've got some failures on pre-existing tests. This is because this PR applies the eps to step, but the arange docs suggest applying it to end, i.e. to sampling_range_end.

Locally, applying the eps to sampling_range_end instead seems to pass all tests, including the new ones you added.

But I'd like to see if we can find a fix that wouldn't involve adding eps (despite the torch docs suggesting that as a workaround). When we were designing and implementing the samplers, I remember we had a lot of discussions around precisely avoiding that.

Intuitively, I think the fix could be something like:

clip_start_seconds = torch.arange(...)  # unchanged!
# As mentioned in the docs, torch.arange may return values equal to or above
# `end` because of floating point precision errors, so we have to manually
# ensure all values are strictly lower than `sampling_range_end`.
if clip_start_seconds[-1] >= sampling_range_end:
    clip_start_seconds = clip_start_seconds[:-1]
    # If we're paranoid and suspect there may be more than one value outside
    # of the range, we could do:
    # clip_start_seconds = clip_start_seconds[clip_start_seconds < sampling_range_end]

This also passes all tests locally. I understand you and @scotts already discussed this solution (or a variant of it) and decided not to go for it, but I think using >= instead of == addresses the original concerns you had?

@scotts (Contributor) commented:

@NicolasHug, your proposed route is what I felt better about doing based on first principles; I generally think it's cleaner and easier to verify when we correct the output. I waffled when I read the PyTorch docs. :) I'm happy with this approach; it's close to what I suggested already.

I do find myself wondering if we should check for more than one value out of range. Do we have any reason to think it should be limited to just the last returned value?

@NicolasHug (Member) commented Sep 22, 2025

> Do we have any reason to think it should be limited to just the last returned value?

I assume a natural implementation of range would be such that only the last value may be out of the range, but that's purely speculative, and an implementation detail anyway, I suppose. If we assume there may be more than one value out of the range, we can use clip_start_seconds = clip_start_seconds[clip_start_seconds < sampling_range_end]

@Dan-Flores (Contributor, Author) commented:

I'm not sure how we could have multiple values out of the range, but I am happy with this solution as well, since it effectively covers the case of a single value out of range too.
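For completeness, that direction would look roughly like this (a sketch, not necessarily the exact merged code):

clip_start_seconds = torch.arange(
    sampling_range_start,
    sampling_range_end,  # excluded
    seconds_between_clip_starts,
)
# torch.arange can overshoot `end` when the step is non-integer, due to
# floating point rounding (see the torch.arange docs), so defensively keep
# only the starts that are strictly below sampling_range_end.
clip_start_seconds = clip_start_seconds[clip_start_seconds < sampling_range_end]
num_clips = len(clip_start_seconds)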

@NicolasHug (Member) left a comment

Thanks a lot for the great reproducing example and for the fix, @Dan-Flores! Approving now, assuming the rest of the CI is green.

Comment on lines 199 to 201
# The torch.arange documentation warns that floating point rounding errors
# are possible for non-integer steps when comparing to end.
# docs.pytorch.org/docs/2.8/generated/torch.arange.html
@NicolasHug (Member) commented:

This comment can probably be removed now that we have the one below

@NicolasHug NicolasHug added the bug Something isn't working label Sep 22, 2025
@Dan-Flores Dan-Flores merged commit 6f906f4 into pytorch:main Sep 23, 2025
47 checks passed