Skip to content

Commit 3d69189

Browse files
authored
Update linter/formatter (#576)
1 parent 791af13 commit 3d69189

File tree

31 files changed

+113
-113
lines changed

31 files changed

+113
-113
lines changed

.flake8

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ max-line-length = 120
88
# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
99
# E731 allow usage of assigning lambda expressions
1010
# N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style.
11+
# E704 ignored to allow black's formatting of Protocol stub methods (def method(self) -> None: ...)
1112
ignore =
12-
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806
13+
E203,E305,E402,E501,E704,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806
1314
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
1415
# to line this up with executable bit
1516
EXE001,

.meta/mast/hydrate_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
python .meta/mast/hydrate_cache.py --model-id Qwen/Qwen3-32B
1515
1616
"""
17+
1718
import argparse
1819
import os
1920
import sys

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ repos:
3939
hooks:
4040
- id: ufmt
4141
additional_dependencies:
42-
- black == 22.12.0
43-
- usort == 1.0.5
42+
- black == 24.4.2
43+
- usort == 1.0.8.post1
4444

4545
- repo: https://github.com/jsh9/pydoclint
4646
rev: 0.5.12

apps/grpo/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def simple_grpo_loss(
140140

141141
@dataclass
142142
class RewardActor(ForgeActor):
143-
144143
reward_functions: list[Callable]
145144

146145
@endpoint

apps/sft/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,6 @@ def __repr__(self) -> str:
287287

288288

289289
async def run(cfg: DictConfig) -> None:
290-
291290
logging.info("Spawning recipe...")
292291
process_cfg = cfg.pop("processes")
293292

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,9 @@ prerelease = "allow"
100100
environments = [
101101
"sys_platform == 'linux'",
102102
]
103-
# override-dependencies = ["torch>2.7.1", "torchaudio>=2.7.1", "torchvision>=0.22.0"]
103+
104+
[tool.black]
105+
target-version = ["py310"] # match the minium supported python version
106+
107+
[tool.usort]
108+
first_party_detection = false

src/forge/actors/coder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
import tempfile
1111
from pathlib import Path
1212

13-
from monarch.actor import endpoint
14-
1513
from forge.controller import ForgeActor
1614

15+
from monarch.actor import endpoint
16+
1717
logger = logging.getLogger(__name__)
1818
logger.setLevel(logging.DEBUG)
1919

src/forge/actors/generator.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,28 @@
1717

1818
import torch
1919
import torchstore as ts
20+
21+
from forge.actors._torchstore_utils import (
22+
extract_param_name,
23+
get_dcp_whole_state_dict_key,
24+
get_param_key,
25+
get_param_prefix,
26+
load_tensor_from_dcp,
27+
rdma_available,
28+
)
29+
30+
from forge.controller import (
31+
ForgeActor,
32+
get_proc_mesh,
33+
host_mesh_from_proc,
34+
stop_proc_mesh,
35+
)
36+
from forge.data_models.completion import Completion
37+
from forge.data_models.prompt import to_prompt
38+
from forge.observability.metrics import record_metric, Reduce
39+
from forge.observability.perf_tracker import Tracer
40+
from forge.types import ProcessConfig
41+
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle
2042
from monarch.actor import current_rank, endpoint, ProcMesh, this_host
2143

2244
from vllm.config import VllmConfig
@@ -42,28 +64,6 @@
4264
from vllm.v1.structured_output import StructuredOutputManager
4365
from vllm.worker.worker_base import WorkerWrapperBase
4466

45-
from forge.actors._torchstore_utils import (
46-
extract_param_name,
47-
get_dcp_whole_state_dict_key,
48-
get_param_key,
49-
get_param_prefix,
50-
load_tensor_from_dcp,
51-
rdma_available,
52-
)
53-
54-
from forge.controller import (
55-
ForgeActor,
56-
get_proc_mesh,
57-
host_mesh_from_proc,
58-
stop_proc_mesh,
59-
)
60-
from forge.data_models.completion import Completion
61-
from forge.data_models.prompt import to_prompt
62-
from forge.observability.metrics import record_metric, Reduce
63-
from forge.observability.perf_tracker import Tracer
64-
from forge.types import ProcessConfig
65-
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle
66-
6767
logger = logging.getLogger(__name__)
6868
logger.setLevel(logging.INFO)
6969

src/forge/actors/reference_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from dataclasses import dataclass, field, fields
1414

1515
import torch
16+
17+
from forge.controller import ForgeActor
18+
from forge.observability.metrics import record_metric, Reduce
19+
from forge.observability.perf_tracker import Tracer
20+
from forge.util.ops import compute_logprobs
1621
from monarch.actor import current_rank, current_size, endpoint
1722
from torch.distributed.tensor import DTensor
1823

@@ -27,11 +32,6 @@
2732
from torchtitan.experiments.forge.engine import ForgeEngine
2833
from torchtitan.experiments.forge.job_config import ForgeJobConfig
2934

30-
from forge.controller import ForgeActor
31-
from forge.observability.metrics import record_metric, Reduce
32-
from forge.observability.perf_tracker import Tracer
33-
from forge.util.ops import compute_logprobs
34-
3535
logger = logging.getLogger(__name__)
3636
logger.setLevel(logging.INFO)
3737

src/forge/actors/replay_buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
from operator import itemgetter
1212
from typing import Any, Callable
1313

14-
from monarch.actor import endpoint
15-
1614
from forge.controller import ForgeActor
1715
from forge.observability.metrics import record_metric, Reduce
1816
from forge.observability.perf_tracker import trace
1917

18+
from monarch.actor import endpoint
19+
2020
logger = logging.getLogger(__name__)
2121
logger.setLevel(logging.INFO)
2222

0 commit comments

Comments
 (0)