
refactor(ops): Split nn.py to match native nn/ structure #145

@m96-chan

Description

Problem

src/pygpukit/ops/nn.py is 885 lines long and should be split to match the native nn/ structure from #133.

Current State

src/pygpukit/ops/
└── nn.py  (885 lines - all NN ops)

Proposed Structure (aligned with #133)

src/pygpukit/ops/nn/
├── __init__.py          (exports)
├── activation.py        (gelu, silu, relu, sigmoid, tanh)
├── norm.py              (layernorm, rmsnorm)
├── attention.py         (sdpa_causal, sdpa_causal_fixed_cache)
└── rope.py              (rope_inplace, rope_inplace_f32table)
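A minimal sketch of how nn/__init__.py could keep the public import path unchanged after the split, assuming it simply re-exports every op from its submodule. Only the module and function names come from the layout above; the stub bodies and the throwaway `demo_ops` package name are hypothetical. The script writes the package into a temporary directory and imports it to confirm that flat imports such as `from ... .nn import gelu` still resolve.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Proposed layout from this issue: submodule -> public ops it would own.
LAYOUT = {
    "activation": ["gelu", "silu", "relu", "sigmoid", "tanh"],
    "norm": ["layernorm", "rmsnorm"],
    "attention": ["sdpa_causal", "sdpa_causal_fixed_cache"],
    "rope": ["rope_inplace", "rope_inplace_f32table"],
}

# What nn/__init__.py could look like: re-export every op so existing
# flat imports from the old nn.py keep working after the split.
INIT_PY = '''\
from .activation import gelu, relu, sigmoid, silu, tanh
from .attention import sdpa_causal, sdpa_causal_fixed_cache
from .norm import layernorm, rmsnorm
from .rope import rope_inplace, rope_inplace_f32table

__all__ = [
    "gelu", "silu", "relu", "sigmoid", "tanh",
    "layernorm", "rmsnorm",
    "sdpa_causal", "sdpa_causal_fixed_cache",
    "rope_inplace", "rope_inplace_f32table",
]
'''


def build_demo_package(root: Path) -> None:
    """Write a throwaway demo_ops.nn package; each stub op just returns its own name."""
    nn_dir = root / "demo_ops" / "nn"
    nn_dir.mkdir(parents=True)
    (root / "demo_ops" / "__init__.py").write_text("")
    (nn_dir / "__init__.py").write_text(INIT_PY)
    for module, names in LAYOUT.items():
        stubs = "".join(f"def {n}(*args, **kwargs):\n    return {n!r}\n\n" for n in names)
        (nn_dir / f"{module}.py").write_text(stubs)


with tempfile.TemporaryDirectory() as tmp:
    build_demo_package(Path(tmp))
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        nn = importlib.import_module("demo_ops.nn")
        from demo_ops.nn import gelu, rmsnorm  # flat import still works
        exported = sorted(nn.__all__)
    finally:
        sys.path.remove(tmp)

print(exported)
print(gelu.__module__)  # the op lives in the activation submodule
```

The design choice here is that callers never need to know which file an op lives in: the package namespace stays flat while each submodule mirrors its native counterpart.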

Benefits

  • Mirrors native/ops/nn/ structure
  • Each category < 200 lines
  • Easier to maintain Python/C++ correspondence
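One way to keep the "under 200 lines per category" goal from drifting after the split is a small size check. This is a hypothetical helper (`oversized_modules` is not part of pygpukit); it counts lines in every .py file of a directory and reports any that reach the limit. The demo runs against a temporary directory so it is self-contained; in a real checkout it would point at src/pygpukit/ops/nn.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def oversized_modules(pkg_dir: Path, limit: int = 200) -> list[tuple[str, int]]:
    """Return (file name, line count) for each .py file in pkg_dir that is not under `limit` lines."""
    report = []
    for path in sorted(pkg_dir.glob("*.py")):
        n_lines = len(path.read_text(encoding="utf-8").splitlines())
        if n_lines >= limit:
            report.append((path.name, n_lines))
    return report


# Self-contained demo: one module within budget, one over it.
with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp)
    (pkg / "norm.py").write_text("x = 1\n" * 120, encoding="utf-8")
    (pkg / "attention.py").write_text("x = 1\n" * 230, encoding="utf-8")
    result = oversized_modules(pkg)

print(result)  # [('attention.py', 230)]
```

Wrapped in a test that asserts `oversized_modules(Path("src/pygpukit/ops/nn")) == []`, a check like this would flag a category that has outgrown its file before it turns into another 885-line module.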

Related
