diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..d23d76b7 --- /dev/null +++ b/.clang-format @@ -0,0 +1,114 @@ +# The primary clang-format config file. +# TODO(afuller): Set these settings when they aren't broken: +# - AllowShortBlocksOnASingleLine: Empty +--- +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveMacros: false +AlignConsecutiveAssignments: false +AlignConsecutiveBitFields: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: DontAlign +AlignTrailingComments: false +AllowAllArgumentsOnNextLine: true +AllowAllConstructorInitializersOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortEnumsOnASingleLine: true +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortLambdasOnASingleLine: All +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DeriveLineEnding: true +DerivePointerAlignment: false +DisableFormat: false +FixNamespaceComments: true +ForEachMacros: + - FOR_EACH + - FOR_EACH_R + - FOR_EACH_RANGE +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentCaseBlocks: false +IndentGotoLabels: 
true +IndentPPDirectives: None +IndentExternBlock: AfterExternBlock +IndentWidth: 2 +IndentWrappedFunctionNames: false +InsertTrailingCommas: None +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCBreakBeforeNestedBlockParam: true +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInConditionalStatement: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +SpaceBeforeSquareBrackets: false +Standard: Latest +TabWidth: 8 +UseCRLF: false +UseTab: Never +... 
diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 7a05fab8..01ef8cea 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -10,7 +10,7 @@ jobs: build-docs: if: github.repository_owner == 'meta-pytorch' name: Build Documentation - runs-on: linux.g5.12xlarge.nvidia.gpu + runs-on: linux.12xlarge container: image: nvidia/cuda:12.8.1-devel-ubuntu24.04 timeout-minutes: 30 @@ -72,7 +72,7 @@ jobs: path: docs/build/html/ upload: - runs-on: linux.g5.4xlarge.nvidia.gpu + runs-on: linux.12xlarge permissions: # Grant write permission here so that the doc can be pushed to gh-pages branch contents: write diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 00000000..a3229411 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,77 @@ +name: Lint + +on: + push: + branches: + - main + pull_request: + +jobs: + lint: + if: github.repository_owner == 'meta-pytorch' + name: lintrunner + runs-on: linux.12xlarge + container: + image: nvidia/cuda:12.8.1-devel-ubuntu24.04 + timeout-minutes: 30 + steps: + - name: Setup git + shell: bash -l {0} + run: | + set -eux + + apt-get update + apt-get install -y git + + # git doesn't like mixed ownership, override it + chown -R root:root . 
+ - name: Checkout + uses: actions/checkout@v4 + - name: Setup conda env + uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + miniconda-version: "latest" + activate-environment: test + python-version: '3.14' + auto-activate: false + - name: Verify conda environment + shell: bash -l {0} + run: | + conda info + which python + which conda + - name: Update pip + shell: bash -l {0} + run: python -m pip install --upgrade pip + - name: Install pytorch + shell: bash -l {0} + run: pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 + - name: Install Dependencies + shell: bash -l {0} + run: | + set -eux + + conda install -y git + conda install -y -c conda-forge glog==0.4.0 gflags fmt + pip install cmake + export USE_NCCL=0 + export USE_NCCLX=0 + export USE_GLOO=0 + export USE_SYSTEM_LIBS=1 + pip install --no-build-isolation .[dev] -v + - name: Install Lint Dependencies + shell: bash -l {0} + run: | + set -eux + + lintrunner init + - name: Lint + shell: bash -l {0} + run: | + set -eux + + # for lintrunner debugging + export RUST_BACKTRACE=1 + + lintrunner --force-color --all-files -v diff --git a/.gitignore b/.gitignore index 9063c02c..75ad09a8 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ third-party/ __pycache__/ *.so dist/ +.lintbin/ +.pyre/ diff --git a/.lintrunner.toml b/.lintrunner.toml new file mode 100644 index 00000000..0c832c7d --- /dev/null +++ b/.lintrunner.toml @@ -0,0 +1,39 @@ +[[linter]] +code = "CLANGFORMAT" +include_patterns = [ + "comms/torchcomms/**/*.hpp", + "comms/torchcomms/**/*.cpp", +] +exclude_patterns = [] +init_command = [ + "python3", + "tools/linter/adapters/pip_init.py", + "--dry-run={{DRYRUN}}", + "clang-format==21.1.2", +] +command = [ + "python3", + "tools/linter/adapters/clangformat_linter.py", + "--", + "@{{PATHSFILE}}", +] +is_formatter = true + +[[linter]] +code = 'PYRE' +include_patterns = [ + "comms/torchcomms/**/*.py", + 
"comms/torchcomms/**/*.pyi", +] +command = [ + 'python3', + 'tools/linter/adapters/pyre_linter.py', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'bash', + 'scripts/setup_pyre.sh', + '--dry-run={{DRYRUN}}', +] +is_formatter = false diff --git a/.pyre_configuration b/.pyre_configuration new file mode 100644 index 00000000..1a109fda --- /dev/null +++ b/.pyre_configuration @@ -0,0 +1,20 @@ +{ + "exclude": [ + ".*/build/.*", + ".*/docs/.*", + ".*/setup.py", + ".*/third-party/.*", + ".*/comms/rcclx/.*", + ".*/comms/ncclx/.*", + ".*/comms/utils/.*" + ], + "ignore_all_errors": [ + ], + "site_package_search_strategy": "all", + "source_directories": [ + "scripts", + "comms" + ], + "strict": false, + "version": "0.0.101749035478" +} diff --git a/comms/torchcomms/_comms_gloo.pyi b/comms/torchcomms/_comms_gloo.pyi new file mode 120000 index 00000000..4683c8b8 --- /dev/null +++ b/comms/torchcomms/_comms_gloo.pyi @@ -0,0 +1 @@ +gloo/_comms_gloo.pyi \ No newline at end of file diff --git a/comms/torchcomms/_comms_nccl.pyi b/comms/torchcomms/_comms_nccl.pyi new file mode 120000 index 00000000..71335c27 --- /dev/null +++ b/comms/torchcomms/_comms_nccl.pyi @@ -0,0 +1 @@ +nccl/_comms_nccl.pyi \ No newline at end of file diff --git a/comms/torchcomms/_comms_ncclx.pyi b/comms/torchcomms/_comms_ncclx.pyi new file mode 120000 index 00000000..d84682d2 --- /dev/null +++ b/comms/torchcomms/_comms_ncclx.pyi @@ -0,0 +1 @@ +ncclx/_comms_ncclx.pyi \ No newline at end of file diff --git a/comms/torchcomms/_comms_rccl.pyi b/comms/torchcomms/_comms_rccl.pyi new file mode 120000 index 00000000..23fb52e6 --- /dev/null +++ b/comms/torchcomms/_comms_rccl.pyi @@ -0,0 +1 @@ +rccl/_comms_rccl.pyi \ No newline at end of file diff --git a/comms/torchcomms/_comms_rcclx.pyi b/comms/torchcomms/_comms_rcclx.pyi new file mode 120000 index 00000000..50b90f4b --- /dev/null +++ b/comms/torchcomms/_comms_rcclx.pyi @@ -0,0 +1 @@ +rcclx/_comms_rcclx.pyi \ No newline at end of file diff --git 
# True when running on Windows; used to normalise paths and to let
# subprocess find batch-script wrappers.
IS_WINDOWS: bool = os.name == "nt"


class LintSeverity(str, Enum):
    """Severity values emitted in the JSON lint messages."""

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    """One lint result; ``main`` prints each as a JSON object on stdout."""

    path: str | None
    line: int | None
    char: int | None
    code: str
    severity: LintSeverity
    name: str
    original: str | None
    replacement: str | None
    description: str | None


def as_posix(name: str) -> str:
    """Return *name* with forward slashes on Windows, unchanged elsewhere."""
    return name.replace("\\", "/") if IS_WINDOWS else name


def _run_command(
    args: list[str],
    *,
    timeout: int,
) -> subprocess.CompletedProcess[bytes]:
    """Run *args* once with captured output.

    Raises:
        subprocess.TimeoutExpired: the command ran longer than *timeout* seconds.
        subprocess.CalledProcessError: the command exited non-zero.
        OSError: the executable could not be started.
    """
    logging.debug("$ %s", " ".join(args))
    start_time = time.monotonic()
    try:
        return subprocess.run(
            args,
            capture_output=True,
            shell=IS_WINDOWS,  # So batch scripts are found.
            timeout=timeout,
            check=True,
        )
    finally:
        end_time = time.monotonic()
        logging.debug("took %dms", (end_time - start_time) * 1000)


def run_command(
    args: list[str],
    *,
    retries: int,
    timeout: int,
) -> subprocess.CompletedProcess[bytes]:
    """Run *args*, retrying up to *retries* times when the command times out.

    Only timeouts are retried; any other failure propagates immediately.
    A zero or negative *retries* means a single attempt.
    """
    remaining_retries = retries
    while True:
        try:
            return _run_command(args, timeout=timeout)
        except subprocess.TimeoutExpired as err:
            # `<=` rather than `==`: a negative retry count must not turn a
            # persistent timeout into an endless loop.
            if remaining_retries <= 0:
                raise
            remaining_retries -= 1
            logging.warning(  # noqa: G200
                "(%s/%s) Retrying because command failed with: %r",
                retries - remaining_retries,
                retries,
                err,
            )
            time.sleep(1)


def _describe_failure(err: OSError | subprocess.CalledProcessError) -> str:
    """Build the human-readable description for a failed formatter run."""
    if not isinstance(err, subprocess.CalledProcessError):
        return f"Failed due to {err.__class__.__name__}:\n{err}"
    # Tool output is only displayed, never written back, so undecodable bytes
    # are replaced instead of raising a second error while reporting the first.
    stderr = (err.stderr or b"").decode("utf-8", errors="replace").strip()
    stdout = (err.stdout or b"").decode("utf-8", errors="replace").strip()
    return (
        "COMMAND (exit code {returncode})\n"
        "{command}\n\n"
        "STDERR\n{stderr}\n\n"
        "STDOUT\n{stdout}"
    ).format(
        returncode=err.returncode,
        command=" ".join(as_posix(x) for x in err.cmd),
        stderr=stderr or "(empty)",
        stdout=stdout or "(empty)",
    )


def check_file(
    filename: str,
    binary: str,
    retries: int,
    timeout: int,
) -> list[LintMessage]:
    """Run the formatter on *filename* and report any formatting difference.

    The formatter is invoked as ``[binary, filename]`` and its stdout is taken
    as the formatted file content.

    Returns:
        An empty list when the file is already formatted; otherwise a single
        message: ``format`` (with original/replacement text), ``timeout``,
        ``command-failed`` or ``decode-failed``. Failures are reported as
        messages rather than raised so one bad file does not abort the run.
    """
    try:
        with open(filename, "rb") as f:
            original = f.read()
        proc = run_command(
            [binary, filename],
            retries=retries,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return [
            LintMessage(
                path=filename,
                line=None,
                char=None,
                code="CLANGFORMAT",
                severity=LintSeverity.ERROR,
                name="timeout",
                original=None,
                replacement=None,
                description=(
                    "clang-format timed out while trying to process a file. "
                    "Please report an issue to the maintainers of this "
                    "repository."
                ),
            )
        ]
    except (OSError, subprocess.CalledProcessError) as err:
        return [
            LintMessage(
                path=filename,
                line=None,
                char=None,
                code="CLANGFORMAT",
                severity=LintSeverity.ADVICE,
                name="command-failed",
                original=None,
                replacement=None,
                description=_describe_failure(err),
            )
        ]

    replacement = proc.stdout
    if original == replacement:
        return []

    # The patch is carried as text. Decode strictly: a lossy decode would
    # corrupt the file if the replacement were applied.
    try:
        original_text = original.decode("utf-8")
        replacement_text = replacement.decode("utf-8")
    except UnicodeDecodeError as err:
        return [
            LintMessage(
                path=filename,
                line=None,
                char=None,
                code="CLANGFORMAT",
                severity=LintSeverity.ADVICE,
                name="decode-failed",
                original=None,
                replacement=None,
                description=(
                    "Could not produce a formatting patch because the file or "
                    f"the formatter output is not valid UTF-8:\n{err}"
                ),
            )
        ]

    return [
        LintMessage(
            path=filename,
            line=None,
            char=None,
            code="CLANGFORMAT",
            severity=LintSeverity.WARNING,
            name="format",
            original=original_text,
            replacement=replacement_text,
            description="See https://clang.llvm.org/docs/ClangFormat.html.\nRun `lintrunner -a` to apply this patch.",
        )
    ]


def main() -> None:
    """Check every given file in a thread pool and print one JSON object per message.

    Arguments may be supplied through an ``@file`` (one argument per line).
    """
    parser = argparse.ArgumentParser(
        description="Format files with clang-format.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--binary",
        type=str,
        help="clang-format binary path",
        default="clang-format",
    )
    parser.add_argument(
        "--retries",
        default=3,
        type=int,
        help="times to retry timed out clang-format",
    )
    parser.add_argument(
        "--timeout",
        default=90,
        type=int,
        help="seconds to wait for clang-format",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )
    args = parser.parse_args()

    # --verbose selects NOTSET; otherwise DEBUG for fewer than 1000 files and
    # INFO beyond that, so very large runs do not log every command.
    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=(
            logging.NOTSET
            if args.verbose
            else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
        ),
        stream=sys.stderr,
    )

    binary = args.binary

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count(),
        thread_name_prefix="Thread",
    ) as executor:
        # Map each future back to its filename so an unexpected crash can be
        # attributed to the file that caused it.
        futures = {
            executor.submit(check_file, x, binary, args.retries, args.timeout): x
            for x in args.filenames
        }
        for future in concurrent.futures.as_completed(futures):
            try:
                for lint_message in future.result():
                    print(json.dumps(lint_message._asdict()), flush=True)
            except Exception:
                logging.critical('Failed at "%s".', futures[future])
                raise


if __name__ == "__main__":
    main()
+ ) + + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.NOTSET if args.verbose else logging.DEBUG, + stream=sys.stderr, + ) + + env: dict[str, str] = { + **os.environ, + "UV_PYTHON": sys.executable, + "UV_PYTHON_DOWNLOADS": "never", + "FORCE_COLOR": "1", + "CLICOLOR_FORCE": "1", + } + uv_index = env.get("UV_INDEX", env.get("PIP_EXTRA_INDEX_URL")) + if uv_index: + env["UV_INDEX"] = uv_index + + # If we are in a global install, use `--user` to install so that you do not + # need root access in order to initialize linters. + # + # However, `pip install --user` interacts poorly with virtualenvs (see: + # https://bit.ly/3vD4kvl) and conda (see: https://bit.ly/3KG7ZfU). So in + # these cases perform a regular installation. + in_conda = env.get("CONDA_PREFIX") is not None + in_virtualenv = env.get("VIRTUAL_ENV") is not None + need_user_flag = not in_conda and not in_virtualenv + + uv: str | None = shutil.which("uv") + is_uv_managed_python = "uv/python" in sys.base_prefix.replace("\\", "/") + if uv and (is_uv_managed_python or not need_user_flag): + pip_args = [uv, "pip", "install"] + elif sys.executable: + pip_args = [sys.executable, "-mpip", "install"] + else: + pip_args = ["pip3", "install"] + + if need_user_flag: + pip_args.append("--user") + + pip_args.extend(args.packages) + + for package in args.packages: + package_name, _, version = package.partition("=") + if version == "": + raise RuntimeError( + f"Package {package_name} did not have a version specified. " + "Please specify a version to produce a consistent linting experience." 
logger: logging.Logger = logging.getLogger(__name__)


class LintSeverity(str, Enum):
    """Severity values emitted in the JSON lint messages."""

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    """One lint result; ``main`` prints each as a JSON object on stdout."""

    path: Optional[str]
    line: Optional[int]
    char: Optional[int]
    code: str
    severity: LintSeverity
    name: str
    original: Optional[str]
    replacement: Optional[str]
    description: Optional[str]


class PyreResult(TypedDict):
    """Fields read from one entry of pyre's ``--output=json`` error list."""

    line: int
    column: int
    stop_line: int
    stop_column: int
    path: str
    code: int
    name: str
    description: str
    concise_description: str


def parse_pyre_output(
    stdout: bytes,
    stderr: bytes,
    returncode: int,
) -> List[PyreResult]:
    """Decode pyre's JSON error list, surfacing pyre's own diagnostics on failure.

    Raises:
        RuntimeError: *stdout* is not JSON or is not a JSON list. The message
            carries the exit code and stderr so the underlying pyre failure is
            visible instead of a bare JSON decoding error.
    """
    try:
        results = json.loads(stdout)
    except ValueError as err:  # JSONDecodeError and UnicodeDecodeError
        detail = stderr.decode("utf-8", errors="replace").strip() or "(empty)"
        raise RuntimeError(
            f"pyre exited with code {returncode} and did not produce JSON "
            f"output.\nSTDERR\n{detail}"
        ) from err
    if not isinstance(results, list):
        raise RuntimeError(
            f"pyre exited with code {returncode} and produced unexpected "
            f"JSON of type {type(results).__name__}; expected a list of errors."
        )
    return results


def run_pyre() -> List[PyreResult]:
    """Run ``pyre --output=json incremental`` and return the reported errors.

    The exit status is deliberately not enforced with ``check=True``; whether
    the run was usable is decided from its output by ``parse_pyre_output``.
    """
    proc = subprocess.run(
        ["pyre", "--output=json", "incremental"],
        capture_output=True,
    )
    return parse_pyre_output(proc.stdout, proc.stderr, proc.returncode)


def check_pyre(
    filenames: Set[str],
) -> List[LintMessage]:
    """Run pyre and convert each reported error into a ``LintMessage``.

    NOTE(review): *filenames* is currently unused, so every error pyre reports
    is returned regardless of which paths were requested. Confirm that is
    intended before adding a filter, since pyre's reported paths and the
    requested paths may not be spelled the same way.
    """
    try:
        results = run_pyre()

        return [
            LintMessage(
                path=result["path"],
                line=result["line"],
                char=result["column"],
                code="pyre",
                severity=LintSeverity.WARNING,
                name=result["name"],
                description=result["description"],
                original=None,
                replacement=None,
            )
            for result in results
        ]
    except Exception as err:
        # Top-level boundary: any failure to run or interpret pyre is reported
        # as a lint message instead of a traceback.
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code="pyre",
                severity=LintSeverity.ADVICE,
                name="command-failed",
                original=None,
                replacement=None,
                description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
            )
        ]


def main() -> None:
    """Parse arguments, run pyre once, and print one JSON object per message.

    Arguments may be supplied through an ``@file`` (one argument per line).
    """
    parser = argparse.ArgumentParser(
        description="Checks files with pyre",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )
    args = parser.parse_args()

    # --verbose selects NOTSET; otherwise DEBUG for fewer than 1000 files and
    # INFO beyond that.
    logging.basicConfig(
        format="<%(processName)s:%(levelname)s> %(message)s",
        level=(
            logging.NOTSET
            if args.verbose
            else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
        ),
        stream=sys.stderr,
    )

    lint_messages = check_pyre(set(args.filenames))

    for lint_message in lint_messages:
        print(json.dumps(lint_message._asdict()), flush=True)


if __name__ == "__main__":
    main()