Skip to content

Add vmap for BroadcastAxes#3344

Open
angeloskath wants to merge 1 commit intomainfrom
baxes-vmap
Open

Add vmap for BroadcastAxes#3344
angeloskath wants to merge 1 commit intomainfrom
baxes-vmap

Conversation

@angeloskath
Copy link
Copy Markdown
Member

A simplified and correct version of #3319. The main point is that it moves the batch axis first and takes advantage of the BroadcastAxes primitive's negative axes handling (which was made for this tbh) to avoid all axis math. Unfortunately we may need to expand to move the vectorized axis first if the dims are smaller.

In addition to the above it makes sure that the BroadcastAxes primitive is used even when it is a noop to have a consistent graph that works all the time. Basically dealing with the broadcast version of #3202.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant