diff --git a/.config/pyproject_template/settings.toml b/.config/pyproject_template/settings.toml index b0f05661..cb167440 100644 --- a/.config/pyproject_template/settings.toml +++ b/.config/pyproject_template/settings.toml @@ -6,7 +6,7 @@ description = "A MUD proxy with plugin support for Python 3.12+" author_name = "Bast" author_email = "bast@bastproxy.com" github_user = "endavis" -github_repo = "bastproxy-py3" +github_repo = "bastproxy" [template] commit = "2c1171b97183a76fe8415d408432bf344081507c" diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index f30dc64a..00000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,45 +0,0 @@ ---- -name: Bug report -about: Create a report to help us improve -title: '[BUG] ' -labels: bug -assignees: '' ---- - -## Bug Description - -A clear and concise description of what the bug is. - -## Steps to Reproduce - -Steps to reproduce the behavior: - -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -## Expected Behavior - -A clear and concise description of what you expected to happen. - -## Actual Behavior - -A clear and concise description of what actually happened. - -## Environment - -- OS: [e.g. Ubuntu 22.04, macOS 14.0, Windows 11] -- Python version: [e.g. 3.12.0] -- Package version: [e.g. 1.0.0] - -## Additional Context - -Add any other context about the problem here. Include: -- Error messages or stack traces -- Screenshots if applicable -- Related issues or pull requests - -## Possible Solution - -If you have suggestions on how to fix the bug, please describe them here. diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 00000000..ce402075 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,80 @@ +name: 🐛 Bug Report +description: Report a bug or unexpected behavior +labels: ["bug", "needs-triage"] +body: + - type: markdown + attributes: + value: | + ## Report a Bug + + Help us fix the issue by providing clear reproduction steps. + + - type: textarea + id: bug-description + attributes: + label: Bug Description + description: A clear and concise description of what the bug is + placeholder: | + Describe the bug and its impact. + Example: "Application crashes when processing files larger than 10MB..." + validations: + required: true + + - type: textarea + id: steps-to-reproduce + attributes: + label: Steps to Reproduce + description: Step-by-step instructions to reproduce the issue + placeholder: | + 1. Run command `...` + 2. With input `...` + 3. Observe error `...` + validations: + required: true + + - type: textarea + id: expected-vs-actual + attributes: + label: Expected vs Actual Behavior + description: What should happen vs what actually happens + placeholder: | + **Expected:** The function should return a valid result + **Actual:** The function raises ValueError + validations: + required: true + + - type: textarea + id: environment + attributes: + label: Environment + description: System information (Python version, OS, package version) + placeholder: | + - Python: 3.12 + - OS: Ubuntu 22.04 + - Package version: 1.2.3 + render: markdown + validations: + required: false + + - type: textarea + id: error-output + attributes: + label: Error Output + description: Paste any error messages, stack traces, or relevant logs + placeholder: Paste error output here... + render: shell + validations: + required: false + + - type: textarea + id: additional-context + attributes: + label: Additional Context + description: Any other relevant information + placeholder: | + - Screenshots + - Related issues + - Workarounds attempted + - Possible solutions + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/chore.yml b/.github/ISSUE_TEMPLATE/chore.yml new file mode 100644 index 00000000..3bb18667 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/chore.yml @@ -0,0 +1,72 @@ +name: 🔨 Chore / Maintenance +description: Maintenance tasks, tooling, CI/CD, or dependency updates +labels: ["chore", "needs-triage"] +body: + - type: markdown + attributes: + value: | + ## Chore / Maintenance Task + + Describe the maintenance or tooling task that needs to be done. + + - type: dropdown + id: chore-type + attributes: + label: Chore Type + description: What kind of maintenance task is this? + options: + - CI/CD configuration + - Dependency updates + - Tooling improvements + - Code cleanup + - Configuration changes + - Other + validations: + required: true + + - type: textarea + id: description + attributes: + label: Description + description: Describe the maintenance task + placeholder: | + What needs to be done and why? + Example: "Update GitHub Actions to use Node 20 before Node 16 EOL..." + validations: + required: true + + - type: textarea + id: proposed-changes + attributes: + label: Proposed Changes + description: What specific changes need to be made? + placeholder: | + - Update file X + - Modify configuration Y + - Add/remove dependency Z + validations: + required: false + + - type: textarea + id: success-criteria + attributes: + label: Success Criteria + description: How will we know this task is complete? + placeholder: | + - [ ] CI passes + - [ ] No breaking changes + - [ ] Documentation updated if needed + validations: + required: false + + - type: textarea + id: additional-context + attributes: + label: Additional Context + description: Any other relevant information + placeholder: | + - Related issues + - Urgency/timeline + - Dependencies on other tasks + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..bb48500d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: 💬 Discussions + url: https://github.com/{owner}/{repo}/discussions + about: For questions, ideas, and general discussions + - name: 📖 Documentation + url: https://github.com/{owner}/{repo}/blob/main/README.md + about: Check the documentation for usage guides and examples diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml new file mode 100644 index 00000000..3d33be00 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -0,0 +1,74 @@ +name: 📚 Documentation Request +description: Request new documentation or improvements to existing docs +labels: ["documentation", "needs-triage"] +body: + - type: markdown + attributes: + value: | + ## Documentation Request + + Help us improve our documentation by describing what's needed. + + - type: dropdown + id: doc-type + attributes: + label: Documentation Type + description: What kind of documentation change is needed? + options: + - New guide or tutorial + - Update existing documentation + - Fix incorrect information + - Add code examples + - API documentation + - Other + validations: + required: true + + - type: textarea + id: description + attributes: + label: Description + description: Describe the documentation that is needed + placeholder: | + What documentation is missing or needs improvement? + Example: "There's no guide explaining how to configure logging..." + validations: + required: true + + - type: textarea + id: location + attributes: + label: Suggested Location + description: Where should this documentation live? + placeholder: | + - docs/getting-started/ + - docs/examples/ + - README.md + - Inline code comments + validations: + required: false + + - type: textarea + id: success-criteria + attributes: + label: Success Criteria + description: How will we know the documentation is complete? + placeholder: | + - [ ] Topic is fully explained + - [ ] Code examples included + - [ ] Added to navigation/index + - [ ] Reviewed for accuracy + validations: + required: false + + - type: textarea + id: additional-context + attributes: + label: Additional Context + description: Any other relevant information + placeholder: | + - Links to related documentation + - Examples from other projects + - Target audience + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index ef8e9e6b..00000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,44 +0,0 @@ ---- -name: Feature request -about: Suggest an idea for this project -title: '[FEATURE] ' -labels: enhancement -assignees: '' ---- - -## Problem Statement - -Is your feature request related to a problem? Please describe. -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -## Proposed Solution - -A clear and concise description of what you want to happen. - -## Alternatives Considered - -A clear and concise description of any alternative solutions or features you've considered. - -## Use Cases - -Describe specific use cases where this feature would be beneficial: - -1. Use case 1 -2. Use case 2 -3. Use case 3 - -## Implementation Ideas - -If you have ideas about how this could be implemented, please share them here. - -## Additional Context - -Add any other context, screenshots, or examples about the feature request here. - -## Benefits - -What are the main benefits of implementing this feature? - -- Benefit 1 -- Benefit 2 -- Benefit 3 diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 00000000..b601f454 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,58 @@ +name: ✨ Feature Request +description: Propose a new feature or enhancement +labels: ["enhancement", "needs-triage"] +body: + - type: markdown + attributes: + value: | + ## Propose a New Feature + + Describe the problem you're trying to solve and your proposed solution. + + - type: textarea + id: problem + attributes: + label: Problem + description: What problem does this feature solve? What need does it address? + placeholder: | + Describe the problem or limitation you're experiencing. + Example: "Currently, there's no way to validate custom data types..." + validations: + required: true + + - type: textarea + id: proposed-solution + attributes: + label: Proposed Solution + description: Describe how you envision this feature working + placeholder: | + Clear description of what you want to happen. + Include specific implementation ideas if you have them. + validations: + required: true + + - type: textarea + id: success-criteria + attributes: + label: Success Criteria + description: How will we know this feature is complete and working correctly? + placeholder: | + - [ ] Feature implements X functionality + - [ ] Tests added and passing + - [ ] Documentation updated + - [ ] Works with existing code without breaking changes + validations: + required: false + + - type: textarea + id: additional-context + attributes: + label: Additional Context + description: Add any other context, examples, links, or screenshots + placeholder: | + - Similar features in other projects + - Code examples + - Use cases + - Breaking change considerations + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/refactor.yml b/.github/ISSUE_TEMPLATE/refactor.yml new file mode 100644 index 00000000..8a6b44df --- /dev/null +++ b/.github/ISSUE_TEMPLATE/refactor.yml @@ -0,0 +1,60 @@ +name: 🔧 Refactor Request +description: Propose code refactoring or improvement +labels: ["refactor", "needs-triage"] +body: + - type: markdown + attributes: + value: | + ## Propose a Refactoring + + Describe the current code issue and how you propose to improve it. + + - type: textarea + id: current-code-issue + attributes: + label: Current Code Issue + description: Describe the code that needs refactoring and why + placeholder: | + Describe what code currently exists, where it's located, and what problems it causes. + Example: "The validation logic in src/module.py is duplicated across 5 functions, + making it difficult to maintain and test..." + validations: + required: true + + - type: textarea + id: proposed-improvement + attributes: + label: Proposed Improvement + description: Describe how you propose to refactor the code + placeholder: | + Describe the refactoring approach and what the code will look like after. + Include specific files/modules that will be affected. + validations: + required: true + + - type: textarea + id: success-criteria + attributes: + label: Success Criteria + description: How will we know the refactoring is complete and successful? + placeholder: | + - [ ] Code duplication eliminated + - [ ] All existing tests still pass + - [ ] New tests added for refactored code + - [ ] No breaking changes to public API + - [ ] Code is more maintainable/readable + validations: + required: false + + - type: textarea + id: additional-context + attributes: + label: Additional Context + description: Any other relevant information + placeholder: | + - Performance impact + - Breaking change considerations + - Migration steps needed + - Related issues or PRs + validations: + required: false diff --git a/.github/python-versions.json b/.github/python-versions.json index 0baee70f..aacf89c2 100644 --- a/.github/python-versions.json +++ b/.github/python-versions.json @@ -1,4 +1,4 @@ { "oldest": "3.12", - "newest": "3.13" + "newest": "3.14" } diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f5b7e661..b9e7ee3a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,14 +29,12 @@ repos: entry: uv run ruff check --fix language: system types: [python] - exclude: ^tools/pyproject_template/ - id: ruff-format name: ruff-format entry: uv run ruff format language: system types: [python] - exclude: ^tools/pyproject_template/ # Type checking (uses project's mypy version) # Only checks src/ to match doit type_check behavior @@ -54,7 +52,6 @@ repos: entry: bash -c 'command -v bandit &>/dev/null || uv run python -c "import bandit" 2>/dev/null || exit 0; uv run bandit -c pyproject.toml "$@"' -- language: system types: [python] - exclude: ^tools/pyproject_template/ # Spell checking (uses project's codespell version) - id: codespell @@ -71,7 +68,7 @@ repos: bash -c ' BRANCH=$(git branch --show-current); if [[ "$BRANCH" != "main" && "$BRANCH" != "develop" && ! "$BRANCH" =~ ^(issue|feat|fix|docs|test|refactor|chore|ci|perf|hotfix)/[0-9]+-[a-z0-9\-]+$ && ! "$BRANCH" =~ ^release/.+ ]]; then - echo "Branch name must follow convention:"; + echo "❌ Branch name must follow convention:"; echo " - issue/-"; echo " - feat/-"; echo " - fix/-"; @@ -91,6 +88,14 @@ repos: pass_filenames: false always_run: true + # Generate documentation TOC from frontmatter + - id: generate-doc-toc + name: Generate documentation TOC + entry: doit docs_toc + language: system + pass_filenames: false + files: ^docs/.*\.md$ + # Prevent direct commits to main branch - id: no-commit-to-main name: Prevent commits to main branch @@ -98,9 +103,9 @@ repos: bash -c ' BRANCH=$(git branch --show-current); if [[ "$BRANCH" == "main" ]]; then - echo "ERROR: Direct commits to main branch are not allowed!"; + echo "❌ ERROR: Direct commits to main branch are not allowed!"; echo ""; - echo "The mandatory workflow is: Issue -> Branch -> Commit -> PR -> Merge"; + echo "The mandatory workflow is: Issue → Branch → Commit → PR → Merge"; echo ""; echo "Please follow these steps:"; echo " 1. Ensure a GitHub issue exists for your change"; @@ -124,7 +129,7 @@ repos: # Find staged files with .local or .local. in name (as distinct segment, not substring) LOCAL_FILES=$(git diff --cached --name-only | grep -E "\.local$|\.local\." | grep -v "\.local\.example" || true) if [[ -n "$LOCAL_FILES" ]]; then - echo "ERROR: Local config files should not be committed!" + echo "❌ ERROR: Local config files should not be committed!" echo "" echo "The following files appear to be user-specific configs:" echo "$LOCAL_FILES" | sed "s/^/ - /" @@ -146,7 +151,7 @@ repos: bash -c ' if git diff --cached --name-only | grep -q "^pyproject.toml$"; then if git diff --cached pyproject.toml | grep -E "^[-+]dynamic\s*=" | grep -q .; then - echo "ERROR: Changes to the dynamic field in pyproject.toml are not allowed!" + echo "❌ ERROR: Changes to the dynamic field in pyproject.toml are not allowed!" echo "" echo "Version is managed dynamically via git tags." echo "The dynamic = [\"version\"] setting should not be modified." diff --git a/pyproject.toml b/pyproject.toml index 755a00ab..72a2c53d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,110 +120,62 @@ fallback-version = "0.0.0" version-file = "src/bastproxy/_version.py" [tool.ruff] -line-length = 88 +line-length = 100 target-version = "py312" extend-exclude = ["**/_version.py"] exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".git-rewrite", - ".hg", - ".mypy_cache", - ".nox", - ".pants.d", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "venv", "data", "evennia", - "tools/pyproject_template", ] [tool.ruff.lint] select = [ - "E", - "W", - "F", - "I", - "D", - "UP", - "B", - "C4", - "DTZ", - "T10", - "EM", - "ISC", - "ICN", - "G", - "PIE", - "T20", - "PT", - "Q", - "RET", - "SIM", - "TID", - "ARG", - "PTH", - "ERA", - "PL", - "TRY", - "RUF", + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "W", # pycodestyle warnings + "UP", # pyupgrade + "ANN", # flake8-annotations (type hints) + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "RUF", # Ruff-specific rules + "SIM", # flake8-simplify + "ASYNC", # flake8-async ] ignore = [ - "D203", - "D213", - "E501", - "TRY003", - "PLR0913", - "PLR0912", - "PLR0915", - "PLR2004", - "PLR0911", - "ISC001", - "ERA001", - "B019", - "ARG001", - "ARG002", - "UP038", - "D205", - "D415", - "T201", - "PLC0415", + "ANN", # TODO: Enable after adding type annotations to codebase + "E501", # Line too long - existing code uses longer lines + "N", # TODO: Enable after fixing naming conventions + "B019", # lru_cache on methods - intentional for API caching + "SIM108", # TODO: Refactor if/else to ternary where appropriate ] -fixable = ["ALL"] -unfixable = [] -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" [tool.ruff.lint.per-file-ignores] -"tests/**/*.py" = ["D", "ARG", "PLR2004"] -"__init__.py" = ["F401", "D104", "E402"] -"src/bastproxy/__init__.py" = ["I001"] -"tools/doit/**/*.py" = ["PTH", "EM101", "EM102", "EM103"] - -[tool.ruff.lint.pydocstyle] -convention = "google" +"tools/pyproject_template/*.py" = [ + "ANN401", # Any return type for dynamic GitHub API responses + "RUF022", # __all__ intentionally grouped by category, not sorted +] +"tools/doit/*.py" = [ + "ANN401", # Any return type for dynamic doit task responses + "ANN001", # Missing type annotation for function argument + "ANN201", # Missing return type annotation for public function +] +"tests/**/*.py" = [ + "ANN", # No type annotations required in tests +] +"**/__init__.py" = [ + "F401", # Unused imports - intentional re-exports + "E402", # Module level import not at top - needed for path setup +] [tool.ruff.lint.isort] known-first-party = ["bastproxy"] -required-imports = [] - -[tool.ruff.lint.mccabe] -max-complexity = 15 [tool.ruff.format] -docstring-code-format = true -docstring-code-line-length = 88 +quote-style = "double" +indent-style = "space" +line-ending = "lf" [tool.mypy] mypy_path = ["src"] @@ -247,6 +199,7 @@ disallow_incomplete_defs = false [[tool.mypy.overrides]] module = [ + "doit.*", "dumper.*", "telnetlib3.*", "rapidfuzz.*", @@ -254,6 +207,7 @@ module = [ "psutil", "psutil._common", "pydatatracker", + "yaml", ] ignore_missing_imports = true diff --git a/src/bastproxy/__init__.py b/src/bastproxy/__init__.py index a0a5048d..f9c7da57 100644 --- a/src/bastproxy/__init__.py +++ b/src/bastproxy/__init__.py @@ -37,9 +37,9 @@ # Standard Library import datetime import logging +import os import sys from pathlib import Path -import os # The modules below are imported to add their functions to the API from bastproxy.libs import argp, timing @@ -158,9 +158,7 @@ def run(self, args: dict) -> None: # load plugins on startup plugin_loader.load_plugins_on_startup() - LogRecord( - "Plugin Manager - all plugins loaded", level="info", sources=["mudproxy"] - )() + LogRecord("Plugin Manager - all plugins loaded", level="info", sources=["mudproxy"])() # do any post plugin loaded actions self.post_plugins_loaded() @@ -170,9 +168,7 @@ def run(self, args: dict) -> None: self.api("plugins.core.events:add.event")( "ev_bastproxy_proxy_ready", "mudproxy", - description=[ - "An event raised when the proxy is ready to accept connections" - ], + description=["An event raised when the proxy is ready to accept connections"], arg_descriptions={"None": None}, ) @@ -210,9 +206,7 @@ def run(self, args: dict) -> None: Listeners().create_listeners() - LogRecord( - "__main__ - Launching async loop", level="info", sources=["mudproxy"] - )() + LogRecord("__main__ - Launching async loop", level="info", sources=["mudproxy"])() run_asynch() @@ -264,9 +258,7 @@ def main() -> None: default=-1, ) - parser.add_argument( - "-pf", "--profile", help="profile code", action="store_true", default=False - ) + parser.add_argument("-pf", "--profile", help="profile code", action="store_true", default=False) parser.add_argument( "--IPv4-address", @@ -292,9 +284,7 @@ def main() -> None: log_file = BASEAPI.BASEDATALOGPATH / "bastproxy.log" file_handler = logging.FileHandler(log_file, mode="a") file_handler.setFormatter( - logging.Formatter( - "%(asctime)s : %(levelname)-9s - %(name)-22s - %(message)s" - ) + logging.Formatter("%(asctime)s : %(levelname)-9s - %(name)-22s - %(message)s") ) logging.basicConfig( level="INFO", diff --git a/src/bastproxy/libs/api/_api.py b/src/bastproxy/libs/api/_api.py index e8f2d71d..ba5dd177 100644 --- a/src/bastproxy/libs/api/_api.py +++ b/src/bastproxy/libs/api/_api.py @@ -260,8 +260,7 @@ def _api_add_apis_for_object(self, toplevel, item) -> None: )() api_functions = self.get_api_functions_in_object(item) LogRecord( - f"_api_add_apis_for_object: {toplevel}:{item} has {len(api_functions)} " - "api functions", + f"_api_add_apis_for_object: {toplevel}:{item} has {len(api_functions)} api functions", level=self.log_level, sources=[__name__, toplevel], )() @@ -362,22 +361,14 @@ def add_events(self) -> None: self("plugins.core.events:add.event")( "ev_libs.api_character_active", APILOCATION, - description=( - "An event for when the character is active and ready for commands" - ), - arg_descriptions={ - "is_character_active": "The state of the is_character_active flag" - }, + description=("An event for when the character is active and ready for commands"), + arg_descriptions={"is_character_active": "The state of the is_character_active flag"}, ) self("plugins.core.events:add.event")( "ev_libs.api_character_inactive", APILOCATION, - description=( - "An event for when the character is inactive and not ready for commands" - ), - arg_descriptions={ - "is_character_active": "The state of the is_character_active flag" - }, + description=("An event for when the character is inactive and not ready for commands"), + arg_descriptions={"is_character_active": "The state of the is_character_active flag"}, ) def _api_is_character_active_get(self) -> bool: @@ -488,9 +479,7 @@ def add( """ full_api_name: str = f"{top_level_api}:{name}" - api_item = APIItem( - full_api_name, tfunction, self.owner_id, description=description - ) + api_item = APIItem(full_api_name, tfunction, self.owner_id, description=description) if instance: return self._api_instance(api_item, force) @@ -540,10 +529,7 @@ def _api_instance(self, api_item: APIItem, force: bool = False) -> bool: """ if api_item.full_api_name in self._instance_api: - if ( - api_item.tfunction - == self._instance_api[api_item.full_api_name].tfunction - ): + if api_item.tfunction == self._instance_api[api_item.full_api_name].tfunction: return True if force: @@ -561,9 +547,7 @@ def _api_instance(self, api_item: APIItem, force: bool = False) -> bool: sources=[__name__, api_item.owner_id], )() except ImportError: - print( - f"libs.api:instance - {api_item.full_api_name} already exists" - ) + print(f"libs.api:instance - {api_item.full_api_name} already exists") return False @@ -655,9 +639,7 @@ def _api_remove(self, top_level_api: str) -> None: self._class_api[i].tfunction.api["addedin"][top_level_api].remove(api_name) # type: ignore del self._class_api[i] - instance_keys = [ - item for item in self._instance_api if item.startswith(api_toplevel) - ] + instance_keys = [item for item in self._instance_api if item.startswith(api_toplevel)] LogRecord( f"libs.api:remove instance api - {instance_keys =}", level="debug", @@ -832,9 +814,7 @@ def _api_detail( tmsg.extend(api_instance.detail(show_function_code=show_function_code)) if stats_by_plugin or stats_by_caller: - tmsg.extend( - self.format_stats(api_location, stats_by_plugin, stats_by_caller) - ) + tmsg.extend(self.format_stats(api_location, stats_by_plugin, stats_by_caller)) else: tmsg.append(f"{api_location} is not in the api") @@ -899,9 +879,7 @@ def _stats_for_specific_caller( None """ - stats_keys = [ - k for k in api_data.stats.detailed_calls if k.startswith(stats_by_caller) - ] + stats_keys = [k for k in api_data.stats.detailed_calls if k.startswith(stats_by_caller)] stats_keys = sorted(stats_keys) stats_caller_data = [ {"caller": i, "count": api_data.stats.detailed_calls[i]} for i in stats_keys @@ -942,8 +920,7 @@ def _stats_overall(self, tmsg: list[str], api_data: APIItem) -> None: stats_keys = api_data.stats.calls_by_caller.keys() stats_keys = sorted(stats_keys) stats_caller_data = [ - {"caller": i, "count": api_data.stats.calls_by_caller[i]} - for i in stats_keys + {"caller": i, "count": api_data.stats.calls_by_caller[i]} for i in stats_keys ] stats_caller_columns = [ {"name": "Caller", "key": "caller", "width": 20}, @@ -977,9 +954,7 @@ def get_top_level_api_list(self, top_level_api: str) -> list[str]: None """ - api_data: list[str] = [ - i for i in self._class_api if i.startswith(top_level_api) - ] + api_data: list[str] = [i for i in self._class_api if i.startswith(top_level_api)] for i in self._instance_api: if i.startswith(top_level_api): api_data.append(i) diff --git a/src/bastproxy/libs/api/_apiitem.py b/src/bastproxy/libs/api/_apiitem.py index b9d34c0c..1d928b84 100644 --- a/src/bastproxy/libs/api/_apiitem.py +++ b/src/bastproxy/libs/api/_apiitem.py @@ -209,9 +209,7 @@ def detail(self, show_function_code: bool = False) -> list[str]: from ._api import API tmsg.append("") - tmsg.append( - f"function defined in {sourcefile.replace(str(API.BASEPATH), '')}" - ) + tmsg.append(f"function defined in {sourcefile.replace(str(API.BASEPATH), '')}") if show_function_code: tmsg.append("") diff --git a/src/bastproxy/libs/api/_functools.py b/src/bastproxy/libs/api/_functools.py index 38aac8aa..3b2f3b40 100644 --- a/src/bastproxy/libs/api/_functools.py +++ b/src/bastproxy/libs/api/_functools.py @@ -107,15 +107,9 @@ def get_caller_owner_id(ignore_owner_list: list[str] | None = None) -> str: if frame := inspect.currentframe(): while frame := frame.f_back: - if "self" in frame.f_locals and not isinstance( - frame.f_locals["self"], APIItem - ): + if "self" in frame.f_locals and not isinstance(frame.f_locals["self"], APIItem): tcs = frame.f_locals["self"] - if ( - hasattr(tcs, "owner_id") - and tcs.owner_id - and tcs.owner_id not in ignore_list - ): + if hasattr(tcs, "owner_id") and tcs.owner_id and tcs.owner_id not in ignore_list: caller_id = tcs.owner_id break if ( diff --git a/src/bastproxy/libs/argp.py b/src/bastproxy/libs/argp.py index 23c619fa..8363e2c3 100644 --- a/src/bastproxy/libs/argp.py +++ b/src/bastproxy/libs/argp.py @@ -128,9 +128,7 @@ def _get_help_string(self, action: argparse.Action) -> str | None: and action.default is not SUPPRESS ): defaulting_nargs = [OPTIONAL, ZERO_OR_MORE] - if ( - action.option_strings or action.nargs in defaulting_nargs - ) and action.default != "": + if (action.option_strings or action.nargs in defaulting_nargs) and action.default != "": temp_help += " (default: %(default)s)" return temp_help diff --git a/src/bastproxy/libs/asynch/__init__.py b/src/bastproxy/libs/asynch/__init__.py index 5a139ef1..e7ac1e48 100644 --- a/src/bastproxy/libs/asynch/__init__.py +++ b/src/bastproxy/libs/asynch/__init__.py @@ -97,9 +97,7 @@ class TaskItem: manage the task, check its completion status, and retrieve its result. """ - def __init__( - self, func: Awaitable | Callable, name: str, startstring: str = "" - ) -> None: + def __init__(self, func: Awaitable | Callable, name: str, startstring: str = "") -> None: """Initialize a TaskItem object. This constructor initializes the task item with the provided coroutine or @@ -193,12 +191,8 @@ def create( self.task = loop.create_task(self.coroutine, name=self.name) else: self.task = loop.create_task(self.coroutine) - self.task.add_done_callback( - functools.partial(_handle_task_result, message=message) - ) - LogRecord( - f"(Task) {self.name} : {self.task}", level="debug", sources=[__name__] - )() + self.task.add_done_callback(functools.partial(_handle_task_result, message=message)) + LogRecord(f"(Task) {self.name} : {self.task}", level="debug", sources=[__name__])() if self.startstring: LogRecord( f"(Task) {self.name} : Created - {self.startstring}", @@ -206,9 +200,7 @@ def create( sources=[__name__], )() else: - LogRecord( - f"(Task) {self.name} : Created", level="debug", sources=[__name__] - )() + LogRecord(f"(Task) {self.name} : Created", level="debug", sources=[__name__])() return self.task @@ -287,9 +279,7 @@ async def task_check_for_new_tasks(self) -> None: task.create() - LogRecord( - f"Tasks - {asyncio.all_tasks()}", level="debug", sources=[__name__] - )() + LogRecord(f"Tasks - {asyncio.all_tasks()}", level="debug", sources=[__name__])() await asyncio.sleep(0.1) @@ -329,25 +319,19 @@ async def shutdown(signal_: signal.Signals, loop_: asyncio.AbstractEventLoop) -> sources=["mudproxy"], )() for item in tasks: - LogRecord( - f"shutdown - {item.get_name()}", level="warning", sources=["mudproxy"] - )() + LogRecord(f"shutdown - {item.get_name()}", level="warning", sources=["mudproxy"])() [task.cancel() for task in tasks] exceptions = await asyncio.gather(*tasks, return_exceptions=True) - if new_exceptions := [ - exc for exc in exceptions if not isinstance(exc, asyncio.CancelledError) - ]: + if new_exceptions := [exc for exc in exceptions if not isinstance(exc, asyncio.CancelledError)]: LogRecord( f"shutdown - Tasks had Exceptions: {new_exceptions}", level="warning", sources=["mudproxy"], )() else: - LogRecord( - "shutdown - All tasks cancelled", level="warning", sources=["mudproxy"] - )() + LogRecord("shutdown - All tasks cancelled", level="warning", sources=["mudproxy"])() loop_.stop() @@ -372,9 +356,7 @@ def run_asynch() -> None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - LogRecord( - "__main__ - setting up signal handlers", level="info", sources=["mudproxy"] - )() + LogRecord("__main__ - setting up signal handlers", level="info", sources=["mudproxy"])() # for sig in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT): for sig in [signal.SIGHUP, signal.SIGTERM, signal.SIGINT]: LogRecord( diff --git a/src/bastproxy/libs/callback.py b/src/bastproxy/libs/callback.py index 0323f3cf..51ba7851 100644 --- a/src/bastproxy/libs/callback.py +++ b/src/bastproxy/libs/callback.py @@ -49,9 +49,7 @@ class Callback: and the last execution time of the callback. """ - def __init__( - self, name: str, owner_id: str, func: Callable, enabled: bool = True - ) -> None: + def __init__(self, name: str, owner_id: str, func: Callable, enabled: bool = True) -> None: """Initialize the callback with the given parameters. Args: @@ -80,12 +78,7 @@ def __hash__(self) -> int: int: The generated hash value. """ - return ( - hash(self.func) - + hash(self.owner_id) - + hash(self.name) - + hash(self.created_time) - ) + return hash(self.func) + hash(self.owner_id) + hash(self.name) + hash(self.created_time) def __eq__(self, other_function: Any) -> bool: """Check equality between this callback and another function or callback. diff --git a/src/bastproxy/libs/net/__init__.py b/src/bastproxy/libs/net/__init__.py index b5a511c2..f605e926 100644 --- a/src/bastproxy/libs/net/__init__.py +++ b/src/bastproxy/libs/net/__init__.py @@ -42,8 +42,6 @@ for item, value in mud_protocols.items(): if not hasattr(telnetlib3.telopt, item): - LogRecord( - f"Adding {item} to telnetlib3.telopt", level="debug", sources=[__name__] - )() + LogRecord(f"Adding {item} to telnetlib3.telopt", level="debug", sources=[__name__])() setattr(telnetlib3.telopt, item, value) telnetlib3.telopt._DEBUG_OPTS[value] = item # type: ignore[assignment] diff --git a/src/bastproxy/libs/net/client.py b/src/bastproxy/libs/net/client.py index 0ea29fa6..e1f9f5c1 100644 --- a/src/bastproxy/libs/net/client.py +++ b/src/bastproxy/libs/net/client.py @@ -141,8 +141,7 @@ def send_to(self, data: NetworkDataLine) -> None: """ if not self.connected: LogRecord( - f"send_to - {self.uuid} [{self.addr}:{self.port}] is not connected. " - "Cannot send", + f"send_to - {self.uuid} [{self.addr}:{self.port}] is not connected. Cannot send", level="debug", sources=[__name__], )() @@ -183,11 +182,7 @@ async def setup_client(self) -> None: # We send an IAC+WILL+ECHO to the client so that # it won't locally echo the password. networkdata = NetworkData( - [ - NetworkDataLine( - telnet.echo_on(), line_type="COMMAND-TELNET", prelogin=True - ) - ], + [NetworkDataLine(telnet.echo_on(), line_type="COMMAND-TELNET", prelogin=True)], owner_id=f"client:{self.uuid}", ) SendDataDirectlyToClient(networkdata, clients=[self.uuid])() @@ -220,9 +215,7 @@ async def setup_client(self) -> None: [NetworkDataLine("Welcome to Bastproxy.", prelogin=True)], owner_id=f"client:{self.uuid}", ) - networkdata.append( - NetworkDataLine("Please enter your password.", prelogin=True) - ) + networkdata.append(NetworkDataLine("Please enter your password.", prelogin=True)) SendDataDirectlyToClient(networkdata, clients=[self.uuid])() self.login_attempts += 1 LogRecord( @@ -255,11 +248,7 @@ def process_data_from_not_logged_in_client(self, inp) -> None: vpw = self.api("plugins.core.proxy:ssc.proxypwview")() if inp.strip() == dpw: networkdata = NetworkData( - [ - NetworkDataLine( - telnet.echo_off(), line_type="COMMAND-TELNET", prelogin=True - ) - ], + [NetworkDataLine(telnet.echo_off(), line_type="COMMAND-TELNET", prelogin=True)], owner_id=f"client:{self.uuid}", ) networkdata.append(NetworkDataLine("You are now logged in.", prelogin=True)) @@ -267,17 +256,11 @@ def process_data_from_not_logged_in_client(self, inp) -> None: self.api("plugins.core.clients:client.logged.in")(self.uuid) elif inp.strip() == vpw: networkdata = NetworkData( - [ - NetworkDataLine( - telnet.echo_off(), line_type="COMMAND-TELNET", prelogin=True - ) - ], + [NetworkDataLine(telnet.echo_off(), line_type="COMMAND-TELNET", prelogin=True)], owner_id=f"client:{self.uuid}", ) networkdata.append( - NetworkDataLine( - "You are now logged in as view only user.", prelogin=True - ) + NetworkDataLine("You are now logged in as view only user.", prelogin=True) ) SendDataDirectlyToClient(networkdata, clients=[self.uuid])() self.api("plugins.core.clients:client.logged.in.view.only")(self.uuid) @@ -425,8 +408,7 @@ async def client_write(self) -> None: if msg_obj.is_io: if msg_obj.line: LogRecord( - f"client_write - Writing message to client {self.uuid}: " - f"{msg_obj.line}", + f"client_write - Writing message to client {self.uuid}: {msg_obj.line}", level="debug", sources=[__name__], )() @@ -446,9 +428,7 @@ async def client_write(self) -> None: )() if msg_obj.is_prompt: self.writer.write(telnet.go_ahead()) - self.data_logger.info( - "%-12s : %s", "client_write", telnet.go_ahead() - ) + self.data_logger.info("%-12s : %s", "client_write", telnet.go_ahead()) elif msg_obj.is_command_telnet: LogRecord( f"client_write - type of msg_obj.msg = {type(msg_obj.line)}", @@ -456,8 +436,7 @@ async def client_write(self) -> None: sources=[__name__], )() LogRecord( - f"client_write - Writing telnet option to client {self.uuid}: " - f"{msg_obj.line!r}", + f"client_write - Writing telnet option to client {self.uuid}: {msg_obj.line!r}", level="debug", sources=[__name__], )() @@ -558,9 +537,7 @@ async def unregister_client(connection) -> None: )() -async def client_telnet_handler( - reader: TelnetReaderUnicode, writer: TelnetWriterUnicode -) -> None: +async def client_telnet_handler(reader: TelnetReaderUnicode, writer: TelnetWriterUnicode) -> None: """Handle new telnet client connections. This coroutine handles new telnet client connections by creating a @@ -582,24 +559,17 @@ async def client_telnet_handler( client_details: str = writer.get_extra_info("peername") addr, port, *rest = client_details - connection: ClientConnection = ClientConnection( - addr, port, "telnet", reader, writer - ) + connection: ClientConnection = ClientConnection(addr, port, "telnet", reader, writer) LogRecord( - f"Connection established with {addr} : {port} : {rest} : uuid - " - f"{connection.uuid}", + f"Connection established with {addr} : {port} : {rest} : uuid - {connection.uuid}", level="warning", sources=[__name__], )() if await register_client(connection): tasks: list[asyncio.Task] = [ - TaskItem( - connection.client_read(), name=f"{connection.uuid} telnet read" - ).create(), - TaskItem( - connection.client_write(), name=f"{connection.uuid} telnet write" - ).create(), + TaskItem(connection.client_read(), name=f"{connection.uuid} telnet read").create(), + TaskItem(connection.client_write(), name=f"{connection.uuid} telnet write").create(), ] if current_task := asyncio.current_task(): diff --git a/src/bastproxy/libs/net/listeners.py b/src/bastproxy/libs/net/listeners.py index 7f761ddd..84e4e377 100644 --- a/src/bastproxy/libs/net/listeners.py +++ b/src/bastproxy/libs/net/listeners.py @@ -125,9 +125,7 @@ async def check_listeners_available(self) -> None: _ = self.ipv6_task.result self.ipv6_start = True - listen_port = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "listenport" - ) + listen_port = self.api("plugins.core.settings:get")("plugins.core.proxy", "listenport") if ipv4 and not self.ipv4_start: ipv4_address = self.api("plugins.core.settings:get")( "plugins.core.proxy", "ipv4address" @@ -174,11 +172,7 @@ async def check_listeners_available(self) -> None: msg + " and ".join(tlist) + " port " - + str( - self.api("plugins.core.settings:get")( - "plugins.core.proxy", "listenport" - ) - ) + + str(self.api("plugins.core.settings:get")("plugins.core.proxy", "listenport")) ) LogRecord(msg, level="info", sources=["mudproxy"])() @@ -200,21 +194,11 @@ def reset_listener_settings(self) -> None: None """ - self.api("plugins.core.settings:change")( - "plugins.core.proxy", "ipv4", "default" - ) - self.api("plugins.core.settings:change")( - "plugins.core.proxy", "ipv6", "default" - ) - self.api("plugins.core.settings:change")( - "plugins.core.proxy", "ipv4address", "default" - ) - self.api("plugins.core.settings:change")( - "plugins.core.proxy", "ipv6address", "default" - ) - self.api("plugins.core.settings:change")( - "plugins.core.proxy", "listenport", "default" - ) + self.api("plugins.core.settings:change")("plugins.core.proxy", "ipv4", "default") + self.api("plugins.core.settings:change")("plugins.core.proxy", "ipv6", "default") + self.api("plugins.core.settings:change")("plugins.core.proxy", "ipv4address", "default") + self.api("plugins.core.settings:change")("plugins.core.proxy", "ipv6address", "default") + self.api("plugins.core.settings:change")("plugins.core.proxy", "listenport", "default") def _create_listeners(self) -> None: """Create listeners for both IPv4 and IPv6 addresses. @@ -234,19 +218,13 @@ def _create_listeners(self) -> None: None """ - listen_port = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "listenport" - ) + listen_port = self.api("plugins.core.settings:get")("plugins.core.proxy", "listenport") ipv4 = self.api("plugins.core.settings:get")("plugins.core.proxy", "ipv4") - ipv4_address = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "ipv4address" - ) + ipv4_address = self.api("plugins.core.settings:get")("plugins.core.proxy", "ipv4address") ipv6 = self.api("plugins.core.settings:get")("plugins.core.proxy", "ipv6") - ipv6_address = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "ipv6address" - ) + ipv6_address = self.api("plugins.core.settings:get")("plugins.core.proxy", "ipv6address") if not ipv4 and not ipv6: LogRecord( diff --git a/src/bastproxy/libs/net/mud.py b/src/bastproxy/libs/net/mud.py index 2263713e..2498df37 100644 --- a/src/bastproxy/libs/net/mud.py +++ b/src/bastproxy/libs/net/mud.py @@ -155,8 +155,7 @@ def send_to(self, data: NetworkDataLine) -> None: loop = asyncio.get_event_loop() if not isinstance(data, NetworkDataLine): LogRecord( - "client: send_to - got a type that is not NetworkDataLine: " - f"{type(data)}", + f"client: send_to - got a type that is not NetworkDataLine: {type(data)}", level="error", stack_info=True, sources=[__name__], @@ -186,9 +185,7 @@ async def setup_mud(self) -> None: )() networkdata = NetworkData([], owner_id="mud:setup_mud") networkdata.append( - NetworkDataLine( - features, originated="internal", line_type="COMMAND-TELNET" - ) + NetworkDataLine(features, originated="internal", line_type="COMMAND-TELNET") ) SendDataDirectlyToMud(networkdata)() LogRecord( @@ -261,9 +258,7 @@ async def mud_read(self) -> None: level="debug", sources=[__name__], )() - data.append( - NetworkDataLine(inp, originated="mud", had_line_endings=False) - ) + data.append(NetworkDataLine(inp, originated="mud", had_line_endings=False)) logging.getLogger("data.mud").info("%-12s : %s", "from_mud", inp) if self.reader.at_eof(): # This is an EOF. Hard disconnect. @@ -318,9 +313,7 @@ async def mud_write(self) -> None: )() self.writer.write(msg_obj.line) msg_obj.was_sent = True - logging.getLogger("data.mud").info( - "%-12s : %s", "to_mud", msg_obj.line - ) + logging.getLogger("data.mud").info("%-12s : %s", "to_mud", msg_obj.line) else: LogRecord( "client_write - No message to write to client.", @@ -340,9 +333,7 @@ async def mud_write(self) -> None: )() self.writer.send_iac(msg_obj.line) msg_obj.was_sent = True - logging.getLogger("data.mud").info( - "%-12s : %s", "to_client", msg_obj.line - ) + logging.getLogger("data.mud").info("%-12s : %s", "to_client", msg_obj.line) if count >= self.max_lines_to_process: await asyncio.sleep(0) @@ -405,9 +396,7 @@ async def mud_telnet_handler( level="warning", sources=[__name__], )() - SendDataDirectlyToClient( - NetworkData(["Connection to the mud has been closed."]) - )() + SendDataDirectlyToClient(NetworkData(["Connection to the mud has been closed."]))() await asyncio.sleep(1) diff --git a/src/bastproxy/libs/net/telnet.py b/src/bastproxy/libs/net/telnet.py index 62aa42e2..0c7e9ea7 100644 --- a/src/bastproxy/libs/net/telnet.py +++ b/src/bastproxy/libs/net/telnet.py @@ -176,9 +176,7 @@ def split_opcode_from_input(data: bytes) -> tuple[bytes, str]: None """ - logging.getLogger(__name__).debug( - "Received raw data (len=%d of: %s", len(data), data - ) + logging.getLogger(__name__).debug("Received raw data (len=%d of: %s", len(data), data) opcodes = b"" inp = "" for position, _ in enumerate(data): @@ -307,9 +305,7 @@ async def handle(opcodes: bytes, writer: "TelnetWriterUnicode") -> None: for each_code in opcodes.split(IAC): if each_code and each_code in opcode_match: result = iac_sb(opcode_match[each_code]()) - logging.getLogger(__name__).debug( - "Responding to previous opcode with: %s", result - ) + logging.getLogger(__name__).debug("Responding to previous opcode with: %s", result) writer.write(result) await writer.drain() logging.getLogger(__name__).debug("Finished handling opcodes.") diff --git a/src/bastproxy/libs/persistentdict.py b/src/bastproxy/libs/persistentdict.py index 6627c979..54890a49 100644 --- a/src/bastproxy/libs/persistentdict.py +++ b/src/bastproxy/libs/persistentdict.py @@ -302,11 +302,7 @@ def pload(self) -> None: """ # try formats from most restrictive to least restrictive - if ( - self.file_name.exists() - and self.flag != "n" - and os.access(self.file_name, os.R_OK) - ): + if self.file_name.exists() and self.flag != "n" and os.access(self.file_name, os.R_OK): self.load() def load(self) -> None: diff --git a/src/bastproxy/libs/plugins/dependency.py b/src/bastproxy/libs/plugins/dependency.py index abc51b69..620ae808 100644 --- a/src/bastproxy/libs/plugins/dependency.py +++ b/src/bastproxy/libs/plugins/dependency.py @@ -152,10 +152,7 @@ def resolve_helper(self, plugin) -> None: and edge_plugin.plugin_id not in self.resolved ): if edge_plugin.plugin_id in self.unresolved: - msg = ( - f"Circular reference detected: {plugin.plugin_id} -> " - f"{plugin.plugin_id}" - ) + msg = f"Circular reference detected: {plugin.plugin_id} -> {plugin.plugin_id}" raise CircularDependencyError(msg) self.resolve_helper(edge_plugin) self.resolved.append(plugin.plugin_id) diff --git a/src/bastproxy/libs/plugins/imputils.py b/src/bastproxy/libs/plugins/imputils.py index 7a762e90..97375849 100644 --- a/src/bastproxy/libs/plugins/imputils.py +++ b/src/bastproxy/libs/plugins/imputils.py @@ -127,9 +127,7 @@ def on_error(package: str) -> None: """ errors[package] = sys.exc_info() - for module_info in pkgutil.walk_packages( - [directory.as_posix()], prefix, onerror=on_error - ): + for module_info in pkgutil.walk_packages([directory.as_posix()], prefix, onerror=on_error): if module_info.ispkg and (tspec := find_spec(module_info.name)): loader_path: str = ( tspec.loader.path # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] diff --git a/src/bastproxy/libs/plugins/loader.py b/src/bastproxy/libs/plugins/loader.py index 3cf78f51..a625a1d0 100644 --- a/src/bastproxy/libs/plugins/loader.py +++ b/src/bastproxy/libs/plugins/loader.py @@ -239,9 +239,7 @@ def _api_does_plugin_exist(self, plugin_id: str) -> bool: @AddAPI( "plugin.get.changed.files", - description=( - "get the list of files that have changed since loading for a plugin" - ), + description=("get the list of files that have changed since loading for a plugin"), ) def _api_plugin_get_changed_files(self, plugin: str) -> list[dict]: """Get the list of files that have changed since loading for a plugin. @@ -263,9 +261,7 @@ def _api_plugin_get_changed_files(self, plugin: str) -> list[dict]: @AddAPI( "plugin.get.invalid.python.files", - description=( - "get the list of files that have invalid python syntax for a plugin" - ), + description=("get the list of files that have invalid python syntax for a plugin"), ) def _api_plugin_get_invalid_python_files(self, plugin: str) -> list[dict]: """Get the list of files that have invalid Python syntax for a plugin. @@ -334,9 +330,7 @@ def _api_get_packages_list(self, active_only: bool = False) -> list[str]: if plugin_info.runtime_info.is_loaded ] else: - packages = [ - plugin_info.package for plugin_info in self.plugins_info.values() - ] + packages = [plugin_info.package for plugin_info in self.plugins_info.values()] return list(set(packages)) @@ -392,9 +386,7 @@ def _api_get_plugin_instance(self, plugin_name: str) -> BasePlugin | None: plugin_name in self.plugins_info and self.plugins_info[plugin_name].runtime_info.is_loaded ): - plugin_instance = self.plugins_info[ - plugin_name - ].runtime_info.plugin_instance + plugin_instance = self.plugins_info[plugin_name].runtime_info.plugin_instance elif isinstance(plugin_name, BasePlugin): plugin_instance = plugin_name @@ -542,12 +534,8 @@ def update_all_plugin_information(self) -> None: plugin_info = PluginInfo(plugin_id=found_plugin["plugin_id"]) plugin_info.package_init_file_path = found_plugin["package_init_file_path"] plugin_info.package_path = found_plugin["package_path"] - plugin_info.package_import_location = found_plugin[ - "package_import_location" - ] - plugin_info.data_directory = ( - self.api.BASEDATAPLUGINPATH / plugin_info.plugin_id - ) + plugin_info.package_import_location = found_plugin["package_import_location"] + plugin_info.data_directory = self.api.BASEDATAPLUGINPATH / plugin_info.plugin_id plugin_info.update_from_init() @@ -555,9 +543,7 @@ def update_all_plugin_information(self) -> None: plugin_info.is_required = True if plugin_info.package_import_location in errors: - plugin_info.import_errors.append( - errors[plugin_info.package_import_location] - ) + plugin_info.import_errors.append(errors[plugin_info.package_import_location]) plugin_info.get_file_data() @@ -568,10 +554,7 @@ def update_all_plugin_information(self) -> None: # warn about plugins whose path is no longer valid removed_plugins = set(old_plugins_info.keys()) - set(self.plugins_info.keys()) for plugin_id in removed_plugins: - if ( - plugin_id in old_plugins_info - and old_plugins_info[plugin_id].runtime_info - ): + if plugin_id in old_plugins_info and old_plugins_info[plugin_id].runtime_info: LogRecord( [ f"Loaded Plugin {plugin_id}'s path is no longer valid: " @@ -582,9 +565,7 @@ def update_all_plugin_information(self) -> None: sources=[__name__], )() - def _import_single_plugin( - self, plugin_id: str, exit_on_error: bool = False - ) -> bool: + def _import_single_plugin(self, plugin_id: str, exit_on_error: bool = False) -> bool: """Import a single plugin. This method imports a single plugin based on the provided plugin ID. It @@ -604,9 +585,7 @@ def _import_single_plugin( """ # import the plugin - LogRecord( - f"{plugin_id:<30} : attempting import", level="info", sources=[__name__] - )() + LogRecord(f"{plugin_id:<30} : attempting import", level="info", sources=[__name__])() plugin_info = self.plugins_info[plugin_id] plugin_info.update_from_init() return_info = imputils.importmodule(plugin_info.plugin_class_import_location) @@ -634,9 +613,7 @@ def _import_single_plugin( plugin_info.runtime_info.is_imported = True plugin_info.runtime_info.imported_time = datetime.datetime.now(datetime.UTC) - LogRecord( - f"{plugin_id:<30} : imported successfully", level="info", sources=[__name__] - )() + LogRecord(f"{plugin_id:<30} : imported successfully", level="info", sources=[__name__])() # check for patches to the base plugin if ( @@ -670,9 +647,7 @@ def _import_single_plugin( return True - def _instantiate_single_plugin( - self, plugin_id: str, exit_on_error: bool = False - ) -> bool: + def _instantiate_single_plugin(self, plugin_id: str, exit_on_error: bool = False) -> bool: """Instantiate a single plugin. This method creates an instance of a plugin based on the provided plugin ID. @@ -691,9 +666,7 @@ def _instantiate_single_plugin( """ plugin_info = self.plugins_info[plugin_id] - LogRecord( - f"{plugin_id:<30} : creating instance", level="info", sources=[__name__] - )() + LogRecord(f"{plugin_id:<30} : creating instance", level="info", sources=[__name__])() if not plugin_info.plugin_class_import_location: LogRecord( @@ -734,9 +707,7 @@ def _instantiate_single_plugin( return True # run the initialize method for a plugin - def _run_initialize_single_plugin( - self, plugin_id: str, exit_on_error: bool = False - ) -> bool: + def _run_initialize_single_plugin(self, plugin_id: str, exit_on_error: bool = False) -> bool: """Run the initialize method for a single plugin. This method runs the initialize method for a single plugin based on the @@ -768,8 +739,7 @@ def _run_initialize_single_plugin( if not plugin_info.runtime_info.plugin_instance: LogRecord( - f"{plugin_info.plugin_id:<30} : plugin instance is None, not " - f"initializing", + f"{plugin_info.plugin_id:<30} : plugin instance is None, not initializing", level="error", sources=[__name__, plugin_info.plugin_id], )() @@ -788,8 +758,7 @@ def _run_initialize_single_plugin( )() if exit_on_error: LogRecord( - f"{plugin_info.plugin_id:<30} : INITIALIZE METHOD WAS NOT " - "SUCCESSFUL", + f"{plugin_info.plugin_id:<30} : INITIALIZE METHOD WAS NOT SUCCESSFUL", level="error", sources=[__name__, plugin_info.plugin_id], )() @@ -860,22 +829,16 @@ def _api_load_plugins( bad_plugins.append(plugin_id) plugins_not_loaded = [ - plugin_id - for plugin_id in plugins_not_loaded - if plugin_id not in bad_plugins + plugin_id for plugin_id in plugins_not_loaded if plugin_id not in bad_plugins ] # instantiate plugins for plugin_id in plugins_not_loaded: - if not self._instantiate_single_plugin( - plugin_id, exit_on_error=exit_on_error - ): + if not self._instantiate_single_plugin(plugin_id, exit_on_error=exit_on_error): bad_plugins.append(plugin_id) plugins_not_loaded = [ - plugin_id - for plugin_id in plugins_not_loaded - if plugin_id not in bad_plugins + plugin_id for plugin_id in plugins_not_loaded if plugin_id not in bad_plugins ] # check dependencies @@ -897,22 +860,16 @@ def _api_load_plugins( # plugins_not_loaded = [ - plugin_id - for plugin_id in plugins_not_loaded - if plugin_id not in bad_plugins + plugin_id for plugin_id in plugins_not_loaded if plugin_id not in bad_plugins ] # run the initialize method for each plugin for plugin_id in plugins_not_loaded: - if not self._run_initialize_single_plugin( - plugin_id, exit_on_error=exit_on_error - ): + if not self._run_initialize_single_plugin(plugin_id, exit_on_error=exit_on_error): bad_plugins.append(plugin_id) loaded_plugins = [ - plugin_id - for plugin_id in plugins_not_loaded - if plugin_id not in bad_plugins + plugin_id for plugin_id in plugins_not_loaded if plugin_id not in bad_plugins ] # clean up plugins that @@ -1035,16 +992,14 @@ def _api_unload_plugin(self, plugin_id: str) -> bool: )() else: LogRecord( - f"{plugin_info.plugin_id:<30} : plugin instance not found " - f"({plugin_info.name})", + f"{plugin_info.plugin_id:<30} : plugin instance not found ({plugin_info.name})", level="info", sources=[__name__, plugin_info.plugin_id], )() except Exception: # pylint: disable=broad-except LogRecord( - f"unload: error running the uninitialize method for " - f"{plugin_info.plugin_id}", + f"unload: error running the uninitialize method for {plugin_info.plugin_id}", level="error", sources=[__name__, plugin_info.plugin_id], exc_info=True, @@ -1076,9 +1031,7 @@ def _api_unload_plugin(self, plugin_id: str) -> bool: for item in modules_to_delete: cb_weakref = partial(self.remove_weakref, module_import_path=item) - self.weak_references_to_modules[item] = weakref.ref( - sys.modules[item], cb_weakref - ) + self.weak_references_to_modules[item] = weakref.ref(sys.modules[item], cb_weakref) if imputils.deletemodule(item): LogRecord( f"{plugin_info.plugin_id:<30} : deleting imported module {item} " @@ -1169,9 +1122,7 @@ def load_plugins_on_startup(self) -> None: """ LogRecord("Loading core and client plugins", level="info", sources=[__name__])() self._load_core_and_client_plugins_on_startup() - LogRecord( - "Finished Loading core and client plugins", level="info", sources=[__name__] - )() + LogRecord("Finished Loading core and client plugins", level="info", sources=[__name__])() LogRecord( f"ev_{__name__}_post_startup_plugins_loaded: Started", @@ -1179,9 +1130,7 @@ def load_plugins_on_startup(self) -> None: sources=[__name__], )() - self.api("plugins.core.events:raise.event")( - f"ev_{__name__}_post_startup_plugins_loaded" - ) + self.api("plugins.core.events:raise.event")(f"ev_{__name__}_post_startup_plugins_loaded") LogRecord( f"ev_{__name__}_post_startup_plugins_loaded: Finish", @@ -1195,9 +1144,7 @@ def load_plugins_on_startup(self) -> None: for error in plugin_info.import_errors: traceback_message = traceback.format_exception(error[1]) traceback_message = [ - item.strip() - for item in traceback_message - if item and item != "\n" + item.strip() for item in traceback_message if item and item != "\n" ] LogRecord( [ diff --git a/src/bastproxy/libs/plugins/plugininfo.py b/src/bastproxy/libs/plugins/plugininfo.py index d973f1fe..56345dfe 100644 --- a/src/bastproxy/libs/plugins/plugininfo.py +++ b/src/bastproxy/libs/plugins/plugininfo.py @@ -73,9 +73,7 @@ def __init__(self) -> None: # The plugin instance self.plugin_instance: None | BasePlugin = None # The imported time - self.imported_time: datetime.datetime = datetime.datetime( - 1970, 1, 1, tzinfo=datetime.UTC - ) + self.imported_time: datetime.datetime = datetime.datetime(1970, 1, 1, tzinfo=datetime.UTC) class PluginInfo: @@ -115,9 +113,7 @@ def __init__(self, plugin_id: str) -> None: self.runtime_info: PluginRuntimeInfo = PluginRuntimeInfo() self.import_errors: list = [] - def check_file_is_valid_python_code( - self, file: Path - ) -> tuple[bool, Exception | None]: + def check_file_is_valid_python_code(self, file: Path) -> tuple[bool, Exception | None]: """Check if a file contains valid Python code. This method attempts to parse the provided file to determine if it contains @@ -164,9 +160,7 @@ def _get_files_by_flag_helper(self, files: dict, flag: str) -> list[dict[str, An changed_files = [] if "files" in files: changed_files.extend( - files["files"][file] - for file in files["files"] - if files["files"][file][flag] + files["files"][file] for file in files["files"] if files["files"][file][flag] ) for item, value in files.items(): @@ -245,8 +239,7 @@ def get_file_data(self) -> dict: if ( parent_dir in oldfiles and file.name in oldfiles[parent_dir] - and file_modified_time - == oldfiles[parent_dir][file.name]["modified_time"] + and file_modified_time == oldfiles[parent_dir][file.name]["modified_time"] ): self.files[parent_dir][file.name] = oldfiles[parent_dir][file.name] continue diff --git a/src/bastproxy/libs/records/managers/records.py b/src/bastproxy/libs/records/managers/records.py index 012d73af..6f1edcd0 100644 --- a/src/bastproxy/libs/records/managers/records.py +++ b/src/bastproxy/libs/records/managers/records.py @@ -195,9 +195,7 @@ def format_all_children_helper( output.append(f"{pre_string} |-> {child.one_line_summary()}") if not all_children: emptybar[indent] = True - self.format_all_children_helper( - children[child], indent + 1, emptybar, output, rfilter - ) + self.format_all_children_helper(children[child], indent + 1, emptybar, output, rfilter) return output def add(self, record): diff --git a/src/bastproxy/libs/records/rtypes/base.py b/src/bastproxy/libs/records/rtypes/base.py index ccc5be73..07f77153 100644 --- a/src/bastproxy/libs/records/rtypes/base.py +++ b/src/bastproxy/libs/records/rtypes/base.py @@ -43,9 +43,7 @@ def __init__(self, owner_id: str = "", track_record=True, parent=None): self.uuid = uuid4().hex self.owner_id = owner_id or f"{self.__class__.__name__}:{self.uuid}" # Add an API - self.api = API( - owner_id=self.owner_id or f"{self.__class__.__name__}:{self.uuid}" - ) + self.api = API(owner_id=self.owner_id or f"{self.__class__.__name__}:{self.uuid}") self.created = datetime.datetime.now(datetime.UTC) self.updates = UpdateManager() self.execute_time_taken = -1 @@ -231,12 +229,8 @@ def format_children(self, full_children_records=False, update_filter=None): """ msg = [] if full_children_records: - children_records = RMANAGER.get_all_children_dict( - self, record_filter=update_filter - ) - msg.extend( - ["Children Records :", "---------------------------------------"] - ) + children_records = RMANAGER.get_all_children_dict(self, record_filter=update_filter) + msg.extend(["Children Records :", "---------------------------------------"]) for record in children_records: msg.extend( f" {line}" @@ -274,11 +268,7 @@ def get_formatted_details( for level in attributes: for item_string, item_attr in attributes[level]: if isinstance(item_attr, str): - attr = ( - getattr(self, item_attr) - if hasattr(self, item_attr) - else item_attr - ) + attr = getattr(self, item_attr) if hasattr(self, item_attr) else item_attr else: attr = item_attr if isinstance(attr, (list, dict)): @@ -314,10 +304,7 @@ def get_formatted_details( def check_for_change(self, flag: str, action: str): """Check if there is a change with the given flag and action.""" - return any( - update["flag"] == flag and update["action"] == action - for update in self.updates - ) + return any(update["flag"] == flag and update["action"] == action for update in self.updates) # def __str__(self): # return f"{self.__class__.__name__}:{self.uuid})" @@ -398,9 +385,7 @@ def __setitem__(self, index, item): ) return super().__setitem__(index, item) - self.addupdate( - "Modify", f"set item at position {index}", extra={"item": f"{item!r}"} - ) + self.addupdate("Modify", f"set item at position {index}", extra={"item": f"{item!r}"}) def insert(self, index, item): """Insert an item.""" @@ -500,9 +485,7 @@ def get_attributes_to_format(self): """ attributes = super().get_attributes_to_format() - attributes[0].extend( - [("Internal", "internal"), ("Message Type", "message_type")] - ) + attributes[0].extend([("Internal", "internal"), ("Message Type", "message_type")]) attributes[2].append(("Data", "data")) if self.original_data != self.data: attributes[2].append(("Original Data", "original_data")) @@ -560,9 +543,7 @@ def color_lines(self, color: str, actor=""): colored_line = f"@w{color}".join(new_line_list) if colored_line: colored_line = f"{color}{colored_line}@w" - new_message.append( - self.api("plugins.core.colors:colorcode.to.ansicode")(colored_line) - ) + new_message.append(self.api("plugins.core.colors:colorcode.to.ansicode")(colored_line)) self.replace( new_message, @@ -594,13 +575,9 @@ def clean(self, actor: str = ""): level="error", sources=[__name__], )() - self.replace( - new_message, actor=f"{actor}:clean", extra={"msg": "clean each item"} - ) + self.replace(new_message, actor=f"{actor}:clean", extra={"msg": "clean each item"}) - def addupdate( - self, flag: str, action: str, extra: dict | None = None, savedata: bool = True - ): + def addupdate(self, flag: str, action: str, extra: dict | None = None, savedata: bool = True): """Add a change event for this record. flag: one of 'Modify', 'Set Flag', 'Info' @@ -657,9 +634,7 @@ def get_attributes_to_format(self): return attributes - def addupdate( - self, flag: str, action: str, extra: dict | None = None, savedata: bool = True - ): + def addupdate(self, flag: str, action: str, extra: dict | None = None, savedata: bool = True): """Add a change event for this record. flag: one of 'Modify', 'Set Flag', 'Info' diff --git a/src/bastproxy/libs/records/rtypes/clientdata.py b/src/bastproxy/libs/records/rtypes/clientdata.py index 57f0958d..63bfbe22 100644 --- a/src/bastproxy/libs/records/rtypes/clientdata.py +++ b/src/bastproxy/libs/records/rtypes/clientdata.py @@ -93,9 +93,7 @@ def setup_events(self): self.modify_data_event_name, __name__, description=["An event to modify data before it is sent to the client"], - arg_descriptions={ - "line": "The line to modify, a NetworkDataLine object" - }, + arg_descriptions={"line": "The line to modify, a NetworkDataLine object"}, ) # @property @@ -144,16 +142,12 @@ def can_send_to_client(self, client_uuid, internal): return False # If the client is a view client and this is an internal message, we don't send it # This way view clients don't see the output of commands entered by other clients - if ( - self.api("plugins.core.clients:client.is.view.client")(client_uuid) - and internal - ): + if self.api("plugins.core.clients:client.is.view.client")(client_uuid) and internal: return False # If the client is in the list of clients or self.clients is empty, # then we can check to make sure the client is logged in or the prelogin flag is set if (not self.clients or client_uuid in self.clients) and ( - self.api("plugins.core.clients:client.is.logged.in")(client_uuid) - or self.prelogin + self.api("plugins.core.clients:client.is.logged.in")(client_uuid) or self.prelogin ): # All checks passed, we can send to this client return True @@ -163,9 +157,7 @@ def _exec_(self): """Send the message.""" # If a line came from the mud and it is not a telnet command, # pass each line through the event system to allow plugins to modify it - if data_for_event := [ - line for line in self.message if line.frommud and line.is_io - ]: + if data_for_event := [line for line in self.message if line.frommud and line.is_io]: self.api("plugins.core.events:raise.event")( self.modify_data_event_name, data_list=data_for_event, key_name="line" ) @@ -217,9 +209,7 @@ def setup_events(self): self.read_data_event_name, __name__, description=["An event to see data that was sent to the client"], - arg_descriptions={ - "line": "The line to modify, a NetworkDataLine object" - }, + arg_descriptions={"line": "The line to modify, a NetworkDataLine object"}, ) def one_line_summary(self): @@ -243,8 +233,7 @@ def can_send_to_client(self, client_uuid, line): # If the client is in the list of clients or self.clients is empty, # then we can check to make sure the client is logged in or the prelogin flag is set if (not self.clients or client_uuid in self.clients) and ( - self.api("plugins.core.clients:client.is.logged.in")(client_uuid) - or line.prelogin + self.api("plugins.core.clients:client.is.logged.in")(client_uuid) or line.prelogin ): # All checks passed, we can send to this client return True @@ -258,14 +247,12 @@ def _exec_(self): line.format() line.lock() - clients = self.clients or self.api( - "plugins.core.clients:get.all.clients" - )(uuid_only=True) + clients = self.clients or self.api("plugins.core.clients:get.all.clients")( + uuid_only=True + ) for client_uuid in clients: if self.can_send_to_client(client_uuid, line): - self.api("plugins.core.clients:send.to.client")( - client_uuid, line - ) + self.api("plugins.core.clients:send.to.client")(client_uuid, line) else: LogRecord( f"## NOTE: Client {client_uuid} cannot receive message {self.uuid!s}", diff --git a/src/bastproxy/libs/records/rtypes/muddata.py b/src/bastproxy/libs/records/rtypes/muddata.py index 71062691..88d63c84 100644 --- a/src/bastproxy/libs/records/rtypes/muddata.py +++ b/src/bastproxy/libs/records/rtypes/muddata.py @@ -56,9 +56,7 @@ def get_attributes_to_format(self): """ attributes = super().get_attributes_to_format() - attributes[0].extend( - [("Show in History", "show_in_history"), ("Client ID", "client_id")] - ) + attributes[0].extend([("Show in History", "show_in_history"), ("Client ID", "client_id")]) return attributes def one_line_summary(self): @@ -126,9 +124,7 @@ def _exec_(self): # pass each line through the event system to allow plugins to modify it self.seperate_commands() - if data_for_event := [ - line for line in self.message if line.fromclient and line.is_io - ]: + if data_for_event := [line for line in self.message if line.fromclient and line.is_io]: self.api("plugins.core.events:raise.event")( self.modify_data_event_name, event_args={ @@ -177,9 +173,7 @@ def setup_events(self): self.read_data_event_name, __name__, description=["An event to see data that was sent to the mud"], - arg_descriptions={ - "line": "The line to modify, a NetworkDataLine instance" - }, + arg_descriptions={"line": "The line to modify, a NetworkDataLine instance"}, ) def _exec_(self): diff --git a/src/bastproxy/libs/records/rtypes/networkdata.py b/src/bastproxy/libs/records/rtypes/networkdata.py index a113e51f..99488663 100644 --- a/src/bastproxy/libs/records/rtypes/networkdata.py +++ b/src/bastproxy/libs/records/rtypes/networkdata.py @@ -53,10 +53,7 @@ def __init__( self._attributes_to_monitor.append("was_sent") if originated != "internal" and ( (isinstance(line, str) and ("\n" in line or "\r" in line)) - or ( - isinstance(line, (bytes, bytearray)) - and (b"\n" in line or b"\r" in line) - ) + or (isinstance(line, (bytes, bytearray)) and (b"\n" in line or b"\r" in line)) ): LogRecord( f"NetworkDataLine: {self.uuid} {line} is multi line with \\n and/or \\r", @@ -268,9 +265,7 @@ def one_line_summary(self): A formatted summary string. """ - return ( - f"{self.__class__.__name__:<20} {self.uuid} {self.originated} {self.line!r}" - ) + return f"{self.__class__.__name__:<20} {self.uuid} {self.originated} {self.line!r}" def __str__(self): """Return a string representation of the line. @@ -293,9 +288,7 @@ def __repr__(self): def add_preamble(self, error: bool = False): """Add the preamble to the line only if it is from internal and is an IO message.""" if self.internal and self.is_io: - preamblecolor = self.api("plugins.core.proxy:preamble.color.get")( - error=error - ) + preamblecolor = self.api("plugins.core.proxy:preamble.color.get")(error=error) preambletext = self.api("plugins.core.proxy:preamble.get")() self.line = f"{preamblecolor}{preambletext}@w: {self.line}" @@ -306,13 +299,7 @@ class NetworkData(TrackedUserList): def __init__( self, message: ( - NetworkDataLine - | str - | bytes - | list[NetworkDataLine] - | list[str] - | list[bytes] - | None + NetworkDataLine | str | bytes | list[NetworkDataLine] | list[str] | list[bytes] | None ) = None, owner_id: str = "", ): @@ -351,8 +338,7 @@ def get_first_line(self): ( networkline.original_line for networkline in self - if networkline.original_line - not in ["#BP", b"#BP", "", b"", "''", b"''"] + if networkline.original_line not in ["#BP", b"#BP", "", b"", "''", b"''"] ), "", ) @@ -400,7 +386,9 @@ def extend(self, items: list[NetworkDataLine | str | bytes | bytearray]): new_list = [] for item in items: if not (isinstance(item, (NetworkDataLine, str, bytes, bytearray))): - msg = f"item must be a NetworkDataLine object or a string, not {type(item)} {item!r}" + msg = ( + f"item must be a NetworkDataLine object or a string, not {type(item)} {item!r}" + ) raise TypeError(msg) if isinstance(item, (str, bytes, bytearray)): converted_item = NetworkDataLine(item) diff --git a/src/bastproxy/libs/records/rtypes/update.py b/src/bastproxy/libs/records/rtypes/update.py index b98d1420..16c8fbb9 100644 --- a/src/bastproxy/libs/records/rtypes/update.py +++ b/src/bastproxy/libs/records/rtypes/update.py @@ -30,9 +30,7 @@ class UpdateRecord: will automatically add the time and last 5 stack frames """ - def __init__( - self, parent, flag: str, action: str, extra: dict | None = None, data=None - ): + def __init__(self, parent, flag: str, action: str, extra: dict | None = None, data=None): """Initialize an update record. Args: @@ -193,8 +191,7 @@ def format_detailed( data = self.data tmsg.append(f"{'Data':<15} :") tmsg.extend( - f"{'':<15} : {line}" - for line in pprint.pformat(data, width=120).splitlines() + f"{'':<15} : {line}" for line in pprint.pformat(data, width=120).splitlines() ) if show_stack and self.stack: tmsg.append(f"{'Stack':<15} :") diff --git a/src/bastproxy/libs/timing.py b/src/bastproxy/libs/timing.py index b59e9070..abfbc6ff 100644 --- a/src/bastproxy/libs/timing.py +++ b/src/bastproxy/libs/timing.py @@ -172,8 +172,7 @@ def _api_start(self, timername: str = "", args: Any | None = None) -> str | None "args": args, } LogRecord( - f"starttimer - {uid} {timername:<20} : started - from {owner_id} " - f"with args {args}", + f"starttimer - {uid} {timername:<20} : started - from {owner_id} with args {args}", level="debug", sources=[__name__, owner_id], )() @@ -213,8 +212,7 @@ def _api_finish(self, uid: str) -> float | None: )() else: LogRecord( - f"finishtimer - {uid} {timername:<20} : finished in " - f"{time_taken} ms", + f"finishtimer - {uid} {timername:<20} : finished in {time_taken} ms", level="debug", sources=[__name__, self.timing[uid]["owner_id"]], )() diff --git a/src/bastproxy/libs/tracking/utils/attributes.py b/src/bastproxy/libs/tracking/utils/attributes.py index 36951679..873239be 100644 --- a/src/bastproxy/libs/tracking/utils/attributes.py +++ b/src/bastproxy/libs/tracking/utils/attributes.py @@ -57,10 +57,7 @@ def __setattr__(self, name, value): return super().__setattr__(name, value) - if ( - hasattr(self, "_attributes_to_monitor") - and name in self._attributes_to_monitor - ): + if hasattr(self, "_attributes_to_monitor") and name in self._attributes_to_monitor: self._attribute_set(name, original_value, value) def _attribute_set(self, name, original_value, new_value): diff --git a/src/bastproxy/plugins/_baseplugin/_base.py b/src/bastproxy/plugins/_baseplugin/_base.py index b238d578..f7fe8cd8 100644 --- a/src/bastproxy/plugins/_baseplugin/_base.py +++ b/src/bastproxy/plugins/_baseplugin/_base.py @@ -77,9 +77,7 @@ def _phook_base_init(self): # load anything in the reload cache cache = self.api("libs.plugins.reloadutils:get.plugin.cache")(self.plugin_id) for item in cache: - LogRecord( - f"loading {item} from cache", level="debug", sources=[self.plugin_id] - )() + LogRecord(f"loading {item} from cache", level="debug", sources=[self.plugin_id])() self.__setattr__(item, cache[item]) self.api("libs.plugins.reloadutils:remove.plugin.cache")(self.plugin_id) @@ -126,9 +124,7 @@ def _get_all_plugin_hook_functions(self) -> dict: function_list[plugin_hook] = {} if item_plugin_hooks[plugin_hook] not in function_list[plugin_hook]: function_list[plugin_hook][item_plugin_hooks[plugin_hook]] = [] - function_list[plugin_hook][ - item_plugin_hooks[plugin_hook] - ].append(attr) + function_list[plugin_hook][item_plugin_hooks[plugin_hook]].append(attr) return function_list @@ -310,13 +306,9 @@ def _api_dump(self, attribute_name, detailed=False): with contextlib.suppress(TypeError): if callable(attr): - message.extend( - (header, f"Defined in {inspect.getfile(attr)}", header, "") - ) + message.extend((header, f"Defined in {inspect.getfile(attr)}", header, "")) text_list, _ = inspect.getsourcelines(attr) - message.extend( - [i.replace("@", "@@").rstrip("\n") for i in text_list] - ) + message.extend([i.replace("@", "@@").rstrip("\n") for i in text_list]) return True, message @@ -329,16 +321,11 @@ def _phook_base_post_initialize_update_version(self): old_plugin_version - the version in the savestate file new_plugin_version - the latest version from the module """ - old_plugin_version = self.api("plugins.core.settings:get")( - self.plugin_id, "_version" - ) + old_plugin_version = self.api("plugins.core.settings:get")(self.plugin_id, "_version") new_plugin_version = self.plugin_info.version - if ( - old_plugin_version != new_plugin_version - and new_plugin_version > old_plugin_version - ): + if old_plugin_version != new_plugin_version and new_plugin_version > old_plugin_version: for version in range(old_plugin_version + 1, new_plugin_version + 1): LogRecord( f"_update_version: upgrading to version {version}", @@ -354,9 +341,7 @@ def _phook_base_post_initialize_update_version(self): sources=[self.plugin_id, "plugin_upgrade"], )() - self.api("plugins.core.settings:change")( - self.plugin_id, "_version", new_plugin_version - ) + self.api("plugins.core.settings:change")(self.plugin_id, "_version", new_plugin_version) self.api(f"{self.plugin_id}:save.state")() diff --git a/src/bastproxy/plugins/_baseplugin/_commands.py b/src/bastproxy/plugins/_baseplugin/_commands.py index 8e1f0b4e..4e9b0ab7 100644 --- a/src/bastproxy/plugins/_baseplugin/_commands.py +++ b/src/bastproxy/plugins/_baseplugin/_commands.py @@ -41,9 +41,7 @@ def _phook_base_post_initialize_add_reset_command(self: "Plugin"): help="show functions this plugin has in the api", action="store_true", ) - @AddArgument( - "-c", "--commands", help="show commands in this plugin", action="store_true" - ) + @AddArgument("-c", "--commands", help="show commands in this plugin", action="store_true") def _command_help(self: "Plugin"): """@G%(name)s@w - @B%(cmdname)s@w. @@ -73,9 +71,7 @@ def _command_help(self: "Plugin"): msg.pop() file_header = False - for file in self.api("libs.plugins.loader:plugin.get.changed.files")( - self.plugin_id - ): + for file in self.api("libs.plugins.loader:plugin.get.changed.files")(self.plugin_id): if not file_header: file_header = True if msg[-1] != "": @@ -93,9 +89,7 @@ def _command_help(self: "Plugin"): msg.append("@B" + "-" * 60 + "@w") file_header = False - for file in self.api("libs.plugins.loader:plugin.get.invalid.python.files")( - self.plugin_id - ): + for file in self.api("libs.plugins.loader:plugin.get.invalid.python.files")(self.plugin_id): if not file_header: file_header = True if msg[-1] != "": @@ -116,9 +110,7 @@ def _command_help(self: "Plugin"): msg.append("") if args["commands"]: - cmd_output = self.api("plugins.core.commands:list.commands.formatted")( - self.plugin_id - ) + cmd_output = self.api("plugins.core.commands:list.commands.formatted")(self.plugin_id) msg.extend(cmd_output) msg.extend(("@G" + "-" * 60 + "@w", "")) if args["api"] and (api_list := self.api("libs.api:list")(self.plugin_id)): diff --git a/src/bastproxy/plugins/_baseplugin/_patch.py b/src/bastproxy/plugins/_baseplugin/_patch.py index dd590718..1689154b 100644 --- a/src/bastproxy/plugins/_baseplugin/_patch.py +++ b/src/bastproxy/plugins/_baseplugin/_patch.py @@ -63,9 +63,7 @@ def patch(full_import_location, override=False): )() continue - LogRecord( - f"adding {itemo.__name__}", level="info", sources=["baseplugin"] - )() + LogRecord(f"adding {itemo.__name__}", level="info", sources=["baseplugin"])() setattr(Plugin, itemo.__name__, itemo) added = True diff --git a/src/bastproxy/plugins/core/clients/plugin/_clients.py b/src/bastproxy/plugins/core/clients/plugin/_clients.py index 3cbe68b0..c2ee23a0 100644 --- a/src/bastproxy/plugins/core/clients/plugin/_clients.py +++ b/src/bastproxy/plugins/core/clients/plugin/_clients.py @@ -19,9 +19,7 @@ class BanRecord: - def __init__( - self, plugin_id: str, ip_addr: str, how_long: int = 600, copy: bool = False - ): + def __init__(self, plugin_id: str, ip_addr: str, how_long: int = 600, copy: bool = False): self.api = API(owner_id=f"{plugin_id}:Ban:{ip_addr}") self.plugin_id: str = plugin_id self.ip_addr: str = ip_addr @@ -42,9 +40,7 @@ def __init__( @property def expires(self) -> str: - if next_fire := self.api("plugins.core.timers:get.timer.next.fire")( - self.timer_name - ): + if next_fire := self.api("plugins.core.timers:get.timer.next.fire")(self.timer_name): return next_fire.strftime(self.api.time_format) return "Permanent" @@ -90,9 +86,7 @@ def _phook_initialize(self): self.api("plugins.core.events:add.event")( f"ev_{self.plugin_id}_client_logged_in_view_only", self.plugin_id, - description=[ - "An event that is raised when a client logs in as a view client" - ], + description=["An event that is raised when a client logs in as a view client"], arg_descriptions={"client_uuid": "the uuid of the client"}, ) self.api("plugins.core.events:add.event")( @@ -148,14 +142,10 @@ def _api_client_banned_add(self, client_uuid, how_long=600): def _api_client_banned_add_by_ip(self, ip_address, how_long): """Add a banned ip.""" if how_long == -1: - permbanips = self.api("plugins.core.settings:get")( - self.plugin_id, "permbanips" - ) + permbanips = self.api("plugins.core.settings:get")(self.plugin_id, "permbanips") if ip_address not in permbanips: permbanips.append(ip_address) - self.api("plugins.core.settings:change")( - self.plugin_id, "permbanips", permbanips - ) + self.api("plugins.core.settings:change")(self.plugin_id, "permbanips", permbanips) LogRecord( f"{ip_address} has been banned with no expiration", level="error", @@ -200,9 +190,7 @@ def _api_client_banned_remove(self, addr, auto=False): permbanips = self.api("plugins.core.settings:get")(self.plugin_id, "permbanips") if addr in permbanips: permbanips.remove(addr) - self.api("plugins.core.settings:change")( - self.plugin_id, "permbanips", permbanips - ) + self.api("plugins.core.settings:change")(self.plugin_id, "permbanips", permbanips) LogRecord(msg, level="error", sources=[self.plugin_id])() return True @@ -219,9 +207,7 @@ def _api_is_client_view_client(self, client_uuid): @AddAPI("client.is.logged.in", description="check if a client is logged in") def _api_client_is_logged_in(self, client_uuid): """Check if a client is logged in.""" - return bool( - client_uuid in self.clients and self.clients[client_uuid].state["logged in"] - ) + return bool(client_uuid in self.clients and self.clients[client_uuid].state["logged in"]) @AddAPI("client.logged.in", description="set a client as logged in") def _api_client_logged_in(self, client_uuid): @@ -338,13 +324,10 @@ def _command_show(self): ) banned_clients = [ - {"address": item, "until": self.banned[item].expires} - for item in self.banned + {"address": item, "until": self.banned[item].expires} for item in self.banned ] permbanips = self.api("plugins.core.settings:get")(self.plugin_id, "permbanips") - banned_clients.extend( - {"address": item, "until": "Permanent"} for item in permbanips - ) + banned_clients.extend({"address": item, "until": "Permanent"} for item in permbanips) if banned_clients: banned_clients_columns = [ @@ -396,15 +379,11 @@ def _command_ban(self): tmsg = [] if removed: - tmsg.extend( - (f"Removed {len(removed)} IPs from the ban list", ", ".join(removed)) - ) + tmsg.extend((f"Removed {len(removed)} IPs from the ban list", ", ".join(removed))) if added: tmsg.extend((f"Added {len(added)} IPs to the ban list", ", ".join(added))) if not_found: - tmsg.extend( - (f"{len(not_found)} IPs were not acted on", ", ".join(not_found)) - ) + tmsg.extend((f"{len(not_found)} IPs were not acted on", ", ".join(not_found))) if not tmsg: tmsg = ["No changes made"] diff --git a/src/bastproxy/plugins/core/cmdq/plugin/_cmdq.py b/src/bastproxy/plugins/core/cmdq/plugin/_cmdq.py index d44f3a85..06ff8806 100644 --- a/src/bastproxy/plugins/core/cmdq/plugin/_cmdq.py +++ b/src/bastproxy/plugins/core/cmdq/plugin/_cmdq.py @@ -33,9 +33,7 @@ def _phook_init_plugin(self): def _eventcb_plugin_unloaded(self): """A plugin was unloaded.""" if event_record := self.api("plugins.core.events:get.current.event.record")(): - self.api(f"{self.plugin_id}:remove.mud.commands.for.plugin")( - event_record["plugin_id"] - ) + self.api(f"{self.plugin_id}:remove.mud.commands.for.plugin")(event_record["plugin_id"]) @AddAPI( "remove.mud.commands.for.plugin", @@ -85,9 +83,7 @@ def _api_start(self, cmdtype): @AddAPI("type.add", description="add a command type") def _api_type_add(self, cmdtype, cmd, regex, **kwargs): """Add a command type.""" - owner = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + owner = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) beforef = kwargs.get("beforef") afterf = kwargs.get("afterf") if "plugin" in kwargs: @@ -106,25 +102,19 @@ def _api_type_add(self, cmdtype, cmd, regex, **kwargs): self.api("plugins.core.events:add.event")( f"cmd_{self.current_command['ctype']}_send", self.cmds[cmdtype]["owner"], - description=[ - f"event for the command {self.cmds[cmdtype]['ctype']} being sent" - ], + description=[f"event for the command {self.cmds[cmdtype]['ctype']} being sent"], arg_descriptions={"None": None}, ) self.api("plugins.core.events:add.event")( f"cmd_{self.current_command['ctype']}_completed", self.cmds[cmdtype]["owner"], - description=[ - f"event for the command {self.cmds[cmdtype]['ctype']} completing" - ], + description=[f"event for the command {self.cmds[cmdtype]['ctype']} completing"], arg_descriptions={"None": None}, ) def sendnext(self): """Send the next command.""" - LogRecord( - "sendnext - checking queue", level="debug", sources=[self.plugin_id] - )() + LogRecord("sendnext - checking queue", level="debug", sources=[self.plugin_id])() if not self.queue or self.current_command: return @@ -141,9 +131,7 @@ def sendnext(self): self.cmds[cmdtype]["beforef"]() self.current_command = cmdt - self.api("plugins.core.events:raise.event")( - f"cmd_{self.current_command['ctype']}_send" - ) + self.api("plugins.core.events:raise.event")(f"cmd_{self.current_command['ctype']}_send") SendDataDirectlyToMud(NetworkData(cmd), show_in_history=False)() def checkinqueue(self, cmd): @@ -169,9 +157,7 @@ def _api_finish(self, cmdtype): )() self.cmds[cmdtype]["afterf"]() - self.api("libs.timing:timing.finish")( - f"cmd_{self.current_command['ctype']}" - ) + self.api("libs.timing:timing.finish")(f"cmd_{self.current_command['ctype']}") self.api("plugins.core.events:raise.event")( f"cmd_{self.current_command['ctype']}_completed" ) @@ -181,9 +167,7 @@ def _api_finish(self, cmdtype): @AddAPI("queue.add.command", description="add a command to the plugin") def _api_queue_add_command(self, cmdtype, arguments=""): """Add a command to the queue.""" - plugin = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + plugin = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) cmd = self.cmds[cmdtype]["cmd"] if arguments: cmd = f"{cmd} {arguments!s}" @@ -208,9 +192,7 @@ def resetqueue(self, _=None): def _command_fixqueue(self): """Finish the last command.""" if self.current_command: - self.api("libs.timing:timing.finish")( - f"cmd_{self.current_command['ctype']}" - ) + self.api("libs.timing:timing.finish")(f"cmd_{self.current_command['ctype']}") self.current_command = {} self.sendnext() diff --git a/src/bastproxy/plugins/core/colors/plugin/_colors.py b/src/bastproxy/plugins/core/colors/plugin/_colors.py index cb5f095e..f2024e41 100644 --- a/src/bastproxy/plugins/core/colors/plugin/_colors.py +++ b/src/bastproxy/plugins/core/colors/plugin/_colors.py @@ -108,9 +108,7 @@ def _api_colorcode_to_html(self, sinput): # line = fixstring(line) if "@@" in stripped_line: stripped_line = stripped_line.replace("@@", "\0") - tlist = re.split( - r"(@[cmyrgbwCMYRGBWD]|@[xz]\d\d\d|@[xz]\d\d|@[xz]\d)", stripped_line - ) + tlist = re.split(r"(@[cmyrgbwCMYRGBWD]|@[xz]\d\d\d|@[xz]\d\d|@[xz]\d)", stripped_line) nlist = [] color = "w" @@ -193,15 +191,11 @@ def _api_colorcode_to_ansicode(self, tstr): if color == "x": tcolor, newtext = re.findall(r"^(\d\d?\d?)(.*)$", text)[0] color = f"38;5;{tcolor}" - tstr2 = tstr2 + self.api(f"{self.plugin_id}:ansicode.to.string")( - color, newtext - ) + tstr2 = tstr2 + self.api(f"{self.plugin_id}:ansicode.to.string")(color, newtext) elif color == "z": tcolor, newtext = re.findall(r"^(\d\d?\d?)(.*)$", text)[0] color = f"48;5;{tcolor}" - tstr2 = tstr2 + self.api(f"{self.plugin_id}:ansicode.to.string")( - color, newtext - ) + tstr2 = tstr2 + self.api(f"{self.plugin_id}:ansicode.to.string")(color, newtext) else: tstr2 = tstr2 + self.api(f"{self.plugin_id}:ansicode.to.string")( CONVERTCOLORS[color], text @@ -211,9 +205,7 @@ def _api_colorcode_to_ansicode(self, tstr): tstr = tstr2 + f"{chr(27)}[0m" return re.sub("\0", "@", tstr) # put @ back in - @AddAPI( - "colorcode.escape", description="escape colorcodes so they are not interpreted" - ) + @AddAPI("colorcode.escape", description="escape colorcodes so they are not interpreted") def _api_colorcode_escape(self, tstr): """Escape colorcodes.""" tinput = tstr.splitlines() diff --git a/src/bastproxy/plugins/core/commands/libs/_command.py b/src/bastproxy/plugins/core/commands/libs/_command.py index 6c256e82..f34fe4af 100644 --- a/src/bastproxy/plugins/core/commands/libs/_command.py +++ b/src/bastproxy/plugins/core/commands/libs/_command.py @@ -43,9 +43,7 @@ def __init__( self.group = group or "" self.preamble = preamble self.show_in_history = show_in_history - self.full_cmd = self.api("plugins.core.commands:get.command.format")( - self.plugin_id, name - ) + self.full_cmd = self.api("plugins.core.commands:get.command.format")(self.plugin_id, name) self.short_help = shelp self.count = 0 self.current_args: CmdArgsRecord | dict = {} @@ -53,9 +51,7 @@ def __init__( self.last_run_start_time: datetime.datetime | None = None self.last_run_end_time: datetime.datetime | None = None - def run( - self, arg_string: str = "", format=False - ) -> tuple[bool | None, list[str], str]: + def run(self, arg_string: str = "", format=False) -> tuple[bool | None, list[str], str]: """Run the command.""" self.last_run_start_time = datetime.datetime.now(datetime.UTC) self.current_arg_string = arg_string @@ -69,9 +65,7 @@ def run( if not success: message.extend(fail_message) - return self.run_finish( - False, message, "could not parse args", format=format - ) + return self.run_finish(False, message, "could not parse args", format=format) args = CmdArgsRecord( f"{self.plugin_id}:{self.name}", @@ -99,9 +93,7 @@ def run( sources=[self.plugin_id, "plugins.core.commands"], exc_info=True, )(actor) - return self.run_finish( - False, message, "function returned False", format=format - ) + return self.run_finish(False, message, "function returned False", format=format) if isinstance(return_value, tuple): retval = return_value[0] @@ -115,9 +107,7 @@ def run( actor = f"{self.plugin_id}:run_command:returned_False" message.append("") message.extend(self.arg_parser.format_help().splitlines()) - return self.run_finish( - False, message, "function returned False", format=format - ) + return self.run_finish(False, message, "function returned False", format=format) return self.run_finish(True, message, "command ran successfully", format=format) @@ -183,30 +173,22 @@ def format_return_message(self, message): The formatted message with appropriate colors and formatting. """ - simple = self.api("plugins.core.settings:get")( - "plugins.core.commands", "simple_output" - ) + simple = self.api("plugins.core.settings:get")("plugins.core.commands", "simple_output") include_date = self.api("plugins.core.settings:get")( "plugins.core.commands", "include_date" ) - line_length = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "linelen" - ) + line_length = self.api("plugins.core.settings:get")("plugins.core.proxy", "linelen") preamble_color = self.api("plugins.core.proxy:preamble.color.get")() header_color = self.api("plugins.core.settings:get")( "plugins.core.commands", "header_color" ) command_indent = self.api("plugins.core.commands:get.command.indent")() command_indent_string = " " * command_indent - command_line_length = self.api( - "plugins.core.commands:get.command.line.length" - )() + command_line_length = self.api("plugins.core.commands:get.command.line.length")() output_indent = self.api("plugins.core.commands:get.output.indent")() output_indent_string = " " * output_indent - command = self.api("plugins.core.commands:get.command.format")( - self.plugin_id, self.name - ) + command = self.api("plugins.core.commands:get.command.format")(self.plugin_id, self.name) newmessage = [ "", @@ -251,9 +233,7 @@ def format_return_message(self, message): "Arguments", "-", command_line_length, filler_color=header_color ) ) - newmessage.extend( - [command_indent_string + line for line in arg_message] - ) + newmessage.extend([command_indent_string + line for line in arg_message]) newmessage.append( command_indent_string diff --git a/src/bastproxy/plugins/core/commands/libs/_utils.py b/src/bastproxy/plugins/core/commands/libs/_utils.py index 09ac807d..6ffd00a8 100644 --- a/src/bastproxy/plugins/core/commands/libs/_utils.py +++ b/src/bastproxy/plugins/core/commands/libs/_utils.py @@ -51,9 +51,7 @@ def __call__(self, func): func.command_data = CommandFuncData() func.command_data.command["kwargs"].update(self.command_kwargs) if "name" not in func.command_data.command: - func.command_data.command["name"] = self.name or func.__name__.replace( - "_command_", " " - ) + func.command_data.command["name"] = self.name or func.__name__.replace("_command_", " ") func.command_data.command["autoadd"] = self.autoadd return func @@ -89,8 +87,6 @@ def __call__(self, func): # insert at 0 because decorators are applied in bottom->top order, # so the last decorator applied will be the first # make it so the order can be exactly like using an argparse object - func.command_data.arguments.insert( - 0, {"args": self.args, "kwargs": self.kwargs} - ) + func.command_data.arguments.insert(0, {"args": self.args, "kwargs": self.kwargs}) return func diff --git a/src/bastproxy/plugins/core/commands/plugin/_commands.py b/src/bastproxy/plugins/core/commands/plugin/_commands.py index 26239b37..b31e574f 100644 --- a/src/bastproxy/plugins/core/commands/plugin/_commands.py +++ b/src/bastproxy/plugins/core/commands/plugin/_commands.py @@ -50,9 +50,7 @@ def _phook_init_plugin(self): # load the history self.history_save_file = self.plugin_info.data_directory / "history.txt" - self.command_history_dict = PersistentDict( - self.plugin_id, self.history_save_file, "c" - ) + self.command_history_dict = PersistentDict(self.plugin_id, self.history_save_file, "c") if "history" not in self.command_history_dict: self.command_history_dict["history"] = [] self.command_history_data = self.command_history_dict["history"] @@ -152,34 +150,24 @@ def _api_get_command_indent(self): """Return the command indent.""" return self.api("plugins.core.settings:get")(self.plugin_id, "command_indent") - @AddAPI( - "get.command.count", description="get a command count for a specific plugin" - ) + @AddAPI("get.command.count", description="get a command count for a specific plugin") def _api_get_command_count(self, plugin_id): """Get the command count for a specific plugin.""" - return ( - len(self.command_data[plugin_id]) if plugin_id in self.command_data else 0 - ) + return len(self.command_data[plugin_id]) if plugin_id in self.command_data else 0 @AddAPI("get.output.indent", description="indent for command output") def _api_get_output_indent(self): """Return the output indent.""" if self.api("plugins.core.settings:get")(self.plugin_id, "simple_output"): - return self.api("plugins.core.settings:get")( - self.plugin_id, "command_indent" - ) + return self.api("plugins.core.settings:get")(self.plugin_id, "command_indent") - return ( - self.api("plugins.core.settings:get")(self.plugin_id, "command_indent") * 2 - ) + return self.api("plugins.core.settings:get")(self.plugin_id, "command_indent") * 2 @AddAPI("format.output.header", description="format a header with the header color") def _api_format_output_header(self, header_text, line_length=None): """Format an output header.""" if line_length is None: - line_length = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "linelen" - ) + line_length = self.api("plugins.core.settings:get")("plugins.core.proxy", "linelen") color = self.api("plugins.core.settings:get")( "plugins.core.commands", "output_header_color" ) @@ -193,9 +181,7 @@ def _api_format_output_header(self, header_text, line_length=None): def _api_format_output_subheader(self, header_text, line_length=None): """Format an output header.""" if line_length is None: - line_length = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "linelen" - ) + line_length = self.api("plugins.core.settings:get")("plugins.core.proxy", "linelen") color = self.api("plugins.core.settings:get")( "plugins.core.commands", "output_subheader_color" ) @@ -206,27 +192,21 @@ def _api_format_output_subheader(self, header_text, line_length=None): def _api_format_header(self, header_text, color, line_length=None): """Format an output header.""" if line_length is None: - line_length = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "linelen" - ) + line_length = self.api("plugins.core.settings:get")("plugins.core.proxy", "linelen") return self._format_output_header(header_text, color, line_length) def _format_output_header(self, header_text, color, line_length=None): """Format an output header.""" if not line_length: - line_length = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "linelen" - ) + line_length = self.api("plugins.core.settings:get")("plugins.core.proxy", "linelen") if not color: color = self.api("plugins.core.settings:get")( "plugins.core.commands", "output_header_color" ) - if _ := self.api("plugins.core.settings:get")( - "plugins.core.commands", "multiline_headers" - ): + if _ := self.api("plugins.core.settings:get")("plugins.core.commands", "multiline_headers"): return [ self.api("plugins.core.utils:cap.line")( f"{'-' * (line_length - 2)}", @@ -257,18 +237,14 @@ def _format_output_header(self, header_text, color, line_length=None): @AddAPI("get.command.line.length", description="get line length for command") def _api_get_command_line_length(self): """Return the line length for output.""" - line_length = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "linelen" - ) + line_length = self.api("plugins.core.settings:get")("plugins.core.proxy", "linelen") command_indent = self.api(f"{self.plugin_id}:get.command.indent")() return line_length - 2 * command_indent @AddAPI("get.output.line.length", description="get line length for command output") def _api_get_output_line_length(self): """Return the line length for output.""" - line_length = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "linelen" - ) + line_length = self.api("plugins.core.settings:get")("plugins.core.proxy", "linelen") output_indent = self.api(f"{self.plugin_id}:get.output.indent")() return line_length - 2 * output_indent @@ -305,9 +281,7 @@ def _add_commands_for_all_plugins(self): @RegisterToEvent(event_name="ev_plugin_loaded") def _eventcb_plugin_loaded(self): """Handle the plugin loaded event.""" - if not ( - event_record := self.api("plugins.core.events:get.current.event.record")() - ): + if not (event_record := self.api("plugins.core.events:get.current.event.record")()): return if not self.api.startup and event_record["plugin_id"] == self.plugin_id: @@ -355,15 +329,11 @@ def get_command_functions_in_object(self, base, recurse=True): ): function_list.append(attr) elif recurse: - function_list.extend( - self.get_command_functions_in_object(attr, recurse=False) - ) + function_list.extend(self.get_command_functions_in_object(attr, recurse=False)) return function_list - @AddAPI( - "add.command.by.func", description="add a command based on a decorated function" - ) + @AddAPI("add.command.by.func", description="add a command based on a decorated function") def _api_add_command_by_func(self, func, force=False): """Add a command based on the new decorator stuff.""" LogRecord( @@ -373,7 +343,9 @@ def _api_add_command_by_func(self, func, force=False): )() if hasattr(func, "__self__"): if hasattr(func.__self__, "name"): - msg = f"func is from plugin {func.__self__.plugin_id} with name {func.__self__.name}" + msg = ( + f"func is from plugin {func.__self__.plugin_id} with name {func.__self__.name}" + ) else: msg = f"func is from plugin {func.__self__.plugin_id}" LogRecord(msg, level="debug", sources=[self.plugin_id])() @@ -421,15 +393,13 @@ def _api_add_command_by_func(self, func, force=False): del command_data.command["kwargs"]["name"] if "description" in command_data.argparse["kwargs"]: - command_data.argparse["kwargs"]["description"] = command_data.argparse[ - "kwargs" - ]["description"].format(**func.__self__.__dict__) + command_data.argparse["kwargs"]["description"] = command_data.argparse["kwargs"][ + "description" + ].format(**func.__self__.__dict__) parser = argp.ArgumentParser(**command_data.argparse["kwargs"]) for arg in command_data.arguments: - arg["kwargs"]["help"] = arg["kwargs"]["help"].format( - **func.__self__.__dict__ - ) + arg["kwargs"]["help"] = arg["kwargs"]["help"].format(**func.__self__.__dict__) parser.add_argument(*arg["args"], **arg["kwargs"]) parser.add_argument("-h", "--help", help="show help", action="store_true") @@ -454,9 +424,7 @@ def _api_add_command_by_func(self, func, force=False): command_kwargs["show_in_history"] = True try: - command = CommandClass( - plugin_id, command_name, func, parser, **command_kwargs - ) + command = CommandClass(plugin_id, command_name, func, parser, **command_kwargs) except Exception: LogRecord( f"Error creating command {func.__name__}: {command_kwargs}", @@ -484,9 +452,7 @@ def _eventcb_plugin_unloaded(self): """ if event_record := self.api("plugins.core.events:get.current.event.record")(): if event_record["plugin_id"] != self.plugin_id: - self.api(f"{self.plugin_id}:remove.data.for.plugin")( - event_record["plugin_id"] - ) + self.api(f"{self.plugin_id}:remove.data.for.plugin")(event_record["plugin_id"]) else: LogRecord( f"{self.plugin_id}_plugin_unloaded: {self.plugin_id} is me!", @@ -494,9 +460,7 @@ def _eventcb_plugin_unloaded(self): sources=[self.plugin_id], )() - @AddAPI( - "get.command.format", description="get a command string formatted for a plugin" - ) + @AddAPI("get.command.format", description="get a command string formatted for a plugin") def _api_get_command_format(self, plugin_id, command): """Return a command string formatted for the plugin. @@ -520,9 +484,7 @@ def _api_set_current_command(self, command): """Set the current command.""" self.current_command = command - @AddAPI( - "remove.data.for.plugin", description="remove all command data for a plugin" - ) + @AddAPI("remove.data.for.plugin", description="remove all command data for a plugin") def _api_remove_data_for_plugin(self, plugin_id): """Remove all command data for a plugin. @@ -538,9 +500,7 @@ def _api_remove_data_for_plugin(self, plugin_id): if self.api("libs.plugins.loader:is.plugin.id")(plugin_id): # remove commands from _command_list that start with plugin_instance.plugin_id new_commands = [ - command - for command in self.commands_list - if not command.startswith(plugin_id) + command for command in self.commands_list if not command.startswith(plugin_id) ] self.commands_list = new_commands @@ -570,9 +530,7 @@ def _api_command_help_format(self, plugin_id, command_name): return "" - @AddAPI( - "get.commands.for.plugin.data", description="get the command data for a plugin" - ) + @AddAPI("get.commands.for.plugin.data", description="get the command data for a plugin") def _api_get_commands_for_plugin_data(self, plugin_id): """Get the data for commands for the specified plugin. @@ -733,17 +691,13 @@ def pass_through_command_from_event(self) -> None: returns the updated event """ - if not ( - event_record := self.api("plugins.core.events:get.current.event.record")() - ): + if not (event_record := self.api("plugins.core.events:get.current.event.record")()): return original_command = event_record["line"].line # if the command is the same as the last command, do antispam checks - if original_command == self.api("plugins.core.settings:get")( - self.plugin_id, "lastcmd" - ): + if original_command == self.api("plugins.core.settings:get")(self.plugin_id, "lastcmd"): self.api("plugins.core.settings:change")( self.plugin_id, "cmdcount", @@ -752,12 +706,10 @@ def pass_through_command_from_event(self) -> None: # if the command has been sent spamcount times, then we send an antispam # command in between - if self.api("plugins.core.settings:get")( - self.plugin_id, "cmdcount" - ) == self.api("plugins.core.settings:get")(self.plugin_id, "spamcount"): - event_record.addupdate( - "Modify", "Antispam Command sent", savedata=False - ) + if self.api("plugins.core.settings:get")(self.plugin_id, "cmdcount") == self.api( + "plugins.core.settings:get" + )(self.plugin_id, "spamcount"): + event_record.addupdate("Modify", "Antispam Command sent", savedata=False) LogRecord( f"sending antspam command: {self.api('plugins.core.settings:get')('plugins.core.commands', 'antispamcommand')}", level="debug", @@ -765,9 +717,7 @@ def pass_through_command_from_event(self) -> None: )() SendDataDirectlyToMud( NetworkData( - self.api("plugins.core.settings:get")( - self.plugin_id, "antispamcommand" - ) + self.api("plugins.core.settings:get")(self.plugin_id, "antispamcommand") ), show_in_history=False, )() @@ -794,9 +744,7 @@ def pass_through_command_from_event(self) -> None: level="debug", sources=[self.plugin_id], )() - self.api("plugins.core.settings:change")( - self.plugin_id, "lastcmd", original_command - ) + self.api("plugins.core.settings:change")(self.plugin_id, "lastcmd", original_command) def proxy_help(self, header, header2, data): """Print the proxy help. @@ -857,9 +805,7 @@ def error_command_not_found_in_plugin(self, plugin_id, command=None): [ *[ line.replace("plugins.", "") - for line in self.api(f"{self.plugin_id}:list.commands.formatted")( - plugin_id - ) + for line in self.api(f"{self.plugin_id}:list.commands.formatted")(plugin_id) ], ] ) @@ -926,9 +872,7 @@ def error_package_not_found(self, package=None): *self.api(f"{self.plugin_id}:format.output.header")("Available Packages"), *[ line.replace("plugins.", "") - for line in self.api("libs.plugins.loader:get.packages.list")( - active_only=True - ) + for line in self.api("libs.plugins.loader:get.packages.list")(active_only=True) ], ] return self.proxy_help("Proxy Help", "", message) @@ -941,9 +885,7 @@ def find_command_only_prefix(self): *self.api(f"{self.plugin_id}:format.output.header")("Available Packages"), *[ line.replace("plugins.", "") - for line in self.api("libs.plugins.loader:get.packages.list")( - active_only=True - ) + for line in self.api("libs.plugins.loader:get.packages.list")(active_only=True) ], ] return self.proxy_help("Proxy Help", "", message) @@ -953,16 +895,14 @@ def find_command( ) -> tuple[CommandClass | None, str, bool, str, list[str]]: """Find a command from the client.""" message: list[str] = [] - LogRecord( - f"find_command: {command_line}", level="debug", sources=[self.plugin_id] - )(actor=f"{self.plugin_id}find_command") + LogRecord(f"find_command: {command_line}", level="debug", sources=[self.plugin_id])( + actor=f"{self.plugin_id}find_command" + ) # copy the command command = command_line - commandprefix = self.api("plugins.core.settings:get")( - self.plugin_id, "cmdprefix" - ) + commandprefix = self.api("plugins.core.settings:get")(self.plugin_id, "cmdprefix") command_str = command if command_str in [ @@ -1001,9 +941,7 @@ def find_command( ) temp_command = command_split[2] if len(command_split) > 2 else "" # try and find the command - command_data = self.api(f"{self.plugin_id}:get.commands.for.plugin.data")( - new_plugin - ) + command_data = self.api(f"{self.plugin_id}:get.commands.for.plugin.data")(new_plugin) command_list = list(command_data.keys()) LogRecord(f"{command_list=}", level="debug", sources=[self.plugin_id])( actor=f"{self.plugin_id}:find_command" @@ -1036,9 +974,7 @@ def find_command_split_command_string( self, command_str: str ) -> tuple[bool, str, str, str, list, str]: """Split a command string into its parts.""" - commandprefix = self.api("plugins.core.settings:get")( - self.plugin_id, "cmdprefix" - ) + commandprefix = self.api("plugins.core.settings:get")(self.plugin_id, "cmdprefix") cmd_args_split = command_str.split(" ", 1) command_str = cmd_args_split[0] @@ -1092,9 +1028,7 @@ def find_command_split_command_string( def run_internal_command_from_event(self): """Run the internal command from the client event.""" - if not ( - event_record := self.api("plugins.core.events:get.current.event.record")() - ): + if not (event_record := self.api("plugins.core.events:get.current.event.record")()): return clients = [event_record["client_id"]] if event_record["client_id"] else None @@ -1106,9 +1040,7 @@ def run_internal_command_from_event(self): if message: if isinstance(message, list): - message = NetworkData( - message, owner_id="run_internal_command_from_event" - ) + message = NetworkData(message, owner_id="run_internal_command_from_event") SendDataDirectlyToClient(message, clients=clients)() if event_record["showinhistory"] != show_in_history: @@ -1145,9 +1077,7 @@ def run_internal_command_from_event(self): if message: if isinstance(message, list): - message = NetworkData( - message, owner_id="run_internal_command_from_event" - ) + message = NetworkData(message, owner_id="run_internal_command_from_event") SendDataDirectlyToClient(message, clients=clients)() @AddAPI("get.summary.data.for.plugin", description="get summary data for a plugin") @@ -1169,9 +1099,7 @@ def _api_get_summary_data_for_plugin(self, plugin_id): ) ) - command_data = self.api(f"{self.plugin_id}:get.commands.for.plugin.data")( - plugin_id - ) + command_data = self.api(f"{self.plugin_id}:get.commands.for.plugin.data")(plugin_id) commands = [command.name for command in command_data.values()] @@ -1187,14 +1115,10 @@ def _api_get_summary_data_for_plugin(self, plugin_id): return msg - @AddAPI( - "get.detailed.data.for.plugin", description="get detailed data for a plugin" - ) + @AddAPI("get.detailed.data.for.plugin", description="get detailed data for a plugin") def _api_get_detailed_data_for_plugin(self, plugin_id): """Get detailed data for a plugin.""" - command_data = self.api(f"{self.plugin_id}:get.commands.for.plugin.data")( - plugin_id - ) + command_data = self.api(f"{self.plugin_id}:get.commands.for.plugin.data")(plugin_id) if not command_data: return ["None"] @@ -1228,13 +1152,9 @@ def _eventcb_check_for_command(self) -> None: if it is, the command is parsed and executed and the output sent to the client. """ - commandprefix = self.api("plugins.core.settings:get")( - self.plugin_id, "cmdprefix" - ) + commandprefix = self.api("plugins.core.settings:get")(self.plugin_id, "cmdprefix") - if not ( - event_record := self.api("plugins.core.events:get.current.event.record")() - ): + if not (event_record := self.api("plugins.core.events:get.current.event.record")()): return if self.current_input_event == event_record: @@ -1381,25 +1301,19 @@ def _api_list_commands_formatted(self, plugin_id): message.extend(self.format_command_list(groups[group])) message.append("") - message.append( - output_subheader_color + "-" * 5 + " " + "Base" + " " + "-" * 5 + "@w" - ) + message.append(output_subheader_color + "-" * 5 + " " + "Base" + " " + "-" * 5 + "@w") message.extend(self.format_command_list(groups["Base"])) message.append("") return message - @AddAPI( - "list.plugins.formatted", description="return a formatted list of all plugins" - ) + @AddAPI("list.plugins.formatted", description="return a formatted list of all plugins") def _api_list_plugins_formatted(self, package=None): """List all plugins.""" if package: plugin_id_list = [ item.replace("plugins.", "") - for item in self.api("libs.plugins.loader:get.plugins.in.package")( - package - ) + for item in self.api("libs.plugins.loader:get.plugins.in.package")(package) ] else: plugin_id_list = [ @@ -1437,20 +1351,14 @@ def _command_list(self, _=None): plugin_id = args["plugin"] if not self.api("libs.plugins.loader:is.plugin.id")(plugin_id): - return True, self.api(f"{self.plugin_id}:list.plugins.formatted")( - package=plugin_id - ) + return True, self.api(f"{self.plugin_id}:list.plugins.formatted")(package=plugin_id) if plugin_commands := self.command_data[plugin_id]: if command and command in plugin_commands: - help_message = ( - plugin_commands[command].arg_parser.format_help().splitlines() - ) + help_message = plugin_commands[command].arg_parser.format_help().splitlines() message.extend(help_message) else: - message.extend( - self.api(f"{self.plugin_id}:list.commands.formatted")(plugin_id) - ) + message.extend(self.api(f"{self.plugin_id}:list.commands.formatted")(plugin_id)) else: message.append(f"There are no commands in plugin {plugin_id}") @@ -1485,9 +1393,7 @@ def _command_run_history(self): command = self.command_history_data[args["number"]] SendDataDirectlyToClient( - NetworkData( - [f"Commands: rerunning command {command}"], owner_id="run_history" - ) + NetworkData([f"Commands: rerunning command {command}"], owner_id="run_history") )(f"{self.plugin_id}:_command_run_history") SendDataDirectlyToClient(NetworkData([f"{command}"], owner_id="run_history"))() @@ -1512,8 +1418,7 @@ def _command_history(self): message.append("Command history cleared") else: message.extend( - f"{self.command_history_data.index(i)} : {i}" - for i in self.command_history_data + f"{self.command_history_data.index(i)} : {i}" for i in self.command_history_data ) return True, message diff --git a/src/bastproxy/plugins/core/errors/plugin/_errors.py b/src/bastproxy/plugins/core/errors/plugin/_errors.py index 7d32b93c..0b72d52e 100644 --- a/src/bastproxy/plugins/core/errors/plugin/_errors.py +++ b/src/bastproxy/plugins/core/errors/plugin/_errors.py @@ -63,9 +63,7 @@ def _api_clear_all_errors(self): self.errors = [] @AddParser(description="show errors") - @AddArgument( - "number", help="list the last errors", default="-1", nargs="?" - ) + @AddArgument("number", help="list the last errors", default="-1", nargs="?") def _command_show(self): """@G%(name)s@w - @B%(cmdname)s@w. @@ -86,9 +84,7 @@ def _command_show(self): msg.extend(("", f"Time : {i['timestamp']}", f"Error : {i['msg']}")) else: for i in errors: - msg.extend( - ("", f"Time : {i['timestamp']}", f"Error : {i['msg']}") - ) + msg.extend(("", f"Time : {i['timestamp']}", f"Error : {i['msg']}")) else: msg.append("There are no errors") return True, msg diff --git a/src/bastproxy/plugins/core/events/libs/_event.py b/src/bastproxy/plugins/core/events/libs/_event.py index 2f816aa2..38b70ce8 100644 --- a/src/bastproxy/plugins/core/events/libs/_event.py +++ b/src/bastproxy/plugins/core/events/libs/_event.py @@ -67,15 +67,12 @@ def count(self) -> int: def isregistered(self, func) -> bool: """Check if a function is registered to this event.""" return any( - func in self.priority_dictionary[priority] - for priority in self.priority_dictionary + func in self.priority_dictionary[priority] for priority in self.priority_dictionary ) def isempty(self) -> bool: """Check if an event has no functions registered.""" - return not any( - self.priority_dictionary[priority] for priority in self.priority_dictionary - ) + return not any(self.priority_dictionary[priority] for priority in self.priority_dictionary) def register(self, func: Callable, func_owner_id: str, prio: int = 50) -> bool: """Register a function to this event container.""" @@ -141,9 +138,7 @@ def removeowner(self, owner_id): if call_back.owner_id == owner_id ) for event_function in plugins_to_unregister: - self.api("plugins.core.events:unregister.from.event")( - self.name, event_function.func - ) + self.api("plugins.core.events:unregister.from.event")(self.name, event_function.func) def detail(self) -> list[str]: """Format a detail of the event.""" @@ -195,8 +190,7 @@ def detail(self) -> list[str]: ) if self.arg_descriptions and "None" not in self.arg_descriptions: message.extend( - f"@C{arg:<13}@w : {self.arg_descriptions[arg]}" - for arg in self.arg_descriptions + f"@C{arg:<13}@w : {self.arg_descriptions[arg]}" for arg in self.arg_descriptions ) elif "None" in self.arg_descriptions: message.append("None") diff --git a/src/bastproxy/plugins/core/events/libs/process/_raisedevent.py b/src/bastproxy/plugins/core/events/libs/process/_raisedevent.py index f31d3a51..b42b92d3 100644 --- a/src/bastproxy/plugins/core/events/libs/process/_raisedevent.py +++ b/src/bastproxy/plugins/core/events/libs/process/_raisedevent.py @@ -57,9 +57,7 @@ def _exec_once(self, actor, **kwargs): log_savestate = self.api("plugins.core.settings:get")( "plugins.core.events", "log_savestate" ) - log: bool = ( - True if log_savestate else not self.event_name.endswith("_savestate") - ) + log: bool = True if log_savestate else not self.event_name.endswith("_savestate") if log: LogRecord( diff --git a/src/bastproxy/plugins/core/events/plugin/_event.py b/src/bastproxy/plugins/core/events/plugin/_event.py index 7baef816..7ac55863 100644 --- a/src/bastproxy/plugins/core/events/plugin/_event.py +++ b/src/bastproxy/plugins/core/events/plugin/_event.py @@ -66,15 +66,12 @@ def count(self) -> int: def isregistered(self, func) -> bool: """Check if a function is registered to this event.""" return any( - func in self.priority_dictionary[priority] - for priority in self.priority_dictionary + func in self.priority_dictionary[priority] for priority in self.priority_dictionary ) def isempty(self) -> bool: """Check if an event has no functions registered.""" - return not any( - self.priority_dictionary[priority] for priority in self.priority_dictionary - ) + return not any(self.priority_dictionary[priority] for priority in self.priority_dictionary) def register(self, func: Callable, func_owner_id: str, prio: int = 50) -> bool: """Register a function to this event container.""" @@ -140,9 +137,7 @@ def removeowner(self, owner_id): if call_back.owner_id == owner_id ) for event_function in plugins_to_unregister: - self.api("plugins.core.events:unregister.from.event")( - self.name, event_function.func - ) + self.api("plugins.core.events:unregister.from.event")(self.name, event_function.func) def detail(self) -> list[str]: """Format a detail of the event.""" @@ -194,8 +189,7 @@ def detail(self) -> list[str]: ) if self.arg_descriptions and "None" not in self.arg_descriptions: message.extend( - f"@C{arg:<13}@w : {self.arg_descriptions[arg]}" - for arg in self.arg_descriptions + f"@C{arg:<13}@w : {self.arg_descriptions[arg]}" for arg in self.arg_descriptions ) elif "None" in self.arg_descriptions: message.append("None") diff --git a/src/bastproxy/plugins/core/events/plugin/_events.py b/src/bastproxy/plugins/core/events/plugin/_events.py index 1f6eb312..96a33134 100644 --- a/src/bastproxy/plugins/core/events/plugin/_events.py +++ b/src/bastproxy/plugins/core/events/plugin/_events.py @@ -58,9 +58,7 @@ def _phook_initialize(self): def _eventcb_post_startup_plugins_loaded(self): """Register all events in all plugins.""" self._register_all_plugin_events() - self.api("plugins.core.events:raise.event")( - f"ev_{self.plugin_id}_all_events_registered" - ) + self.api("plugins.core.events:raise.event")(f"ev_{self.plugin_id}_all_events_registered") @RegisterToEvent(event_name="ev_baseplugin_patched") def _eventcb_baseplugin_patched(self): @@ -93,9 +91,7 @@ def _register_all_plugin_events(self): def _register_events_for_plugin(self, plugin_id): """Register all events in a plugin.""" plugin_instance = self.api("libs.plugins.loader:get.plugin.instance")(plugin_id) - event_functions = self.get_event_registration_functions_in_object( - plugin_instance - ) + event_functions = self.get_event_registration_functions_in_object(plugin_instance) LogRecord( f"_register_events_for_plugin: {plugin_id} has {len(event_functions)} registrations", level="debug", @@ -157,9 +153,7 @@ def _api_register_event_by_func(self, func): event_name = item["event_name"] event_name = event_name.format(**func.__self__.__dict__) prio = item["priority"] - self.api("plugins.core.events:register.to.event")( - event_name, func, priority=prio - ) + self.api("plugins.core.events:register.to.event")(event_name, func, priority=prio) @RegisterToEvent(event_name="ev_plugin_unloaded") def _eventcb_plugin_unloaded(self): @@ -170,9 +164,7 @@ def _eventcb_plugin_unloaded(self): level="info", sources=[self.plugin_id, event_record["plugin_id"]], )() - self.api(f"{self.plugin_id}:remove.events.for.owner")( - event_record["plugin_id"] - ) + self.api(f"{self.plugin_id}:remove.events.for.owner")(event_record["plugin_id"]) @AddAPI("get.current.event.name", description="return the current event name") def _api_get_current_event_name(self): @@ -314,9 +306,7 @@ def _api_remove_events_for_owner(self, owner_id): for event in self.events: self.events[event].removeowner(owner_id) - @AddAPI( - "get.detailed.data.for.plugin", description="get detailed data for a plugin" - ) + @AddAPI("get.detailed.data.for.plugin", description="get detailed data for a plugin") def _api_get_detailed_data_for_plugin(self, owner_name): """Return all events for an owner.""" owner_events = {} @@ -372,9 +362,7 @@ def _api_raise_event( event_args = {} if not calledfrom: - calledfrom = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + calledfrom = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) if not calledfrom: LogRecord( @@ -436,9 +424,7 @@ def _command_raise(self): return True, message @AddParser(description="get details of an event") - @AddArgument( - "event", help="the event name to get details for", default=[], nargs="*" - ) + @AddArgument("event", help="the event name to get details for", default=[], nargs="*") def _command_detail(self): """@G%(name)s@w - @B%(cmdname)s@w. @@ -450,9 +436,7 @@ def _command_detail(self): message = [] if args["event"]: for event_name in args["event"]: - message.extend( - self.api(f"{self.plugin_id}:get.event.detail")(event_name) - ) + message.extend(self.api(f"{self.plugin_id}:get.event.detail")(event_name)) message.append("") else: message.append("Please provide an event name") @@ -518,13 +502,10 @@ def _command_list(self): eventlist = [ name for name in eventnames - if not self.events[name].description - or not self.events[name].arg_descriptions + if not self.events[name].description or not self.events[name].arg_descriptions ] elif show_raised_only: - eventlist = [ - name for name in eventnames if self.events[name].raised_count > 0 - ] + eventlist = [name for name in eventnames if self.events[name].raised_count > 0] else: eventlist = eventnames diff --git a/src/bastproxy/plugins/core/events/plugin/_process_event.py b/src/bastproxy/plugins/core/events/plugin/_process_event.py index f31d3a51..b42b92d3 100644 --- a/src/bastproxy/plugins/core/events/plugin/_process_event.py +++ b/src/bastproxy/plugins/core/events/plugin/_process_event.py @@ -57,9 +57,7 @@ def _exec_once(self, actor, **kwargs): log_savestate = self.api("plugins.core.settings:get")( "plugins.core.events", "log_savestate" ) - log: bool = ( - True if log_savestate else not self.event_name.endswith("_savestate") - ) + log: bool = True if log_savestate else not self.event_name.endswith("_savestate") if log: LogRecord( diff --git a/src/bastproxy/plugins/core/fuzzy/plugin/_fuzzy.py b/src/bastproxy/plugins/core/fuzzy/plugin/_fuzzy.py index 49852945..dd79d3e0 100644 --- a/src/bastproxy/plugins/core/fuzzy/plugin/_fuzzy.py +++ b/src/bastproxy/plugins/core/fuzzy/plugin/_fuzzy.py @@ -91,9 +91,7 @@ def _api_get_best_match( )() else: sorted_extract = sort_fuzzy_result( - rapidfuzz.process.extract( - item_to_match, list_to_match, scorer=scorer_inst - ) + rapidfuzz.process.extract(item_to_match, list_to_match, scorer=scorer_inst) ) LogRecord( f"_api_get_best_match - extract for {item_to_match} - {sorted_extract}", diff --git a/src/bastproxy/plugins/core/log/libs/_custom_logger.py b/src/bastproxy/plugins/core/log/libs/_custom_logger.py index c734faf3..d251b08c 100644 --- a/src/bastproxy/plugins/core/log/libs/_custom_logger.py +++ b/src/bastproxy/plugins/core/log/libs/_custom_logger.py @@ -79,9 +79,7 @@ def format(self, record: logging.LogRecord): if record.name != "root": if "exc_info" in record.__dict__ and record.exc_info: formatted_exc = traceback.format_exception(record.exc_info[1]) - formatted_exc_no_newline = [ - line.rstrip() for line in formatted_exc if line - ] + formatted_exc_no_newline = [line.rstrip() for line in formatted_exc if line] if isinstance(record.msg, LogRecord): record.msg.extend(formatted_exc_no_newline) record.msg.addupdate("Modify", "add traceback") @@ -118,9 +116,7 @@ def emit(self, record): try: canlog = bool( not self.api("libs.api:has")("plugins.core.log:can.log.to.console") - or self.api("plugins.core.log:can.log.to.console")( - record.name, record.levelno - ) + or self.api("plugins.core.log:can.log.to.console")(record.name, record.levelno) ) if isinstance(record.msg, LogRecord): if canlog and not record.msg.wasemitted["console"]: @@ -144,9 +140,7 @@ def __init__( utc=False, atTime=None, ): - super().__init__( - filename, when, interval, backupCount, encoding, delay, utc, atTime - ) + super().__init__(filename, when, interval, backupCount, encoding, delay, utc, atTime) self.api = API(owner_id=f"{__name__}:CustomRotatingFileHandler") self.setLevel(logging.DEBUG) @@ -155,9 +149,7 @@ def emit(self, record): try: canlog = bool( not self.api("libs.api:has")("plugins.core.log:can.log.to.file") - or self.api("plugins.core.log:can.log.to.file")( - record.name, record.levelno - ) + or self.api("plugins.core.log:can.log.to.file")(record.name, record.levelno) ) if isinstance(record.msg, LogRecord): if canlog and not record.msg.wasemitted["file"]: @@ -184,9 +176,7 @@ def emit(self, record): update_type_counts(record.name, record.levelno) - canlog = self.api("plugins.core.log:can.log.to.client")( - record.name, record.levelno - ) + canlog = self.api("plugins.core.log:can.log.to.client")(record.name, record.levelno) if canlog or record.levelno >= logging.ERROR: formatted_message = self.format(record) if isinstance(record.msg, LogRecord): @@ -195,9 +185,7 @@ def emit(self, record): else: color = None if not record.msg.wasemitted["client"]: - new_message = NetworkData( - owner_id=f"{__name__}:CustomClientHandler:emit" - ) + new_message = NetworkData(owner_id=f"{__name__}:CustomClientHandler:emit") [ new_message.append(NetworkDataLine(line, color=color or "")) for line in formatted_message.splitlines() @@ -237,13 +225,9 @@ def setup_loggers(log_level: int): default_log_file_path = API.BASEDATALOGPATH / default_log_file (API.BASEDATALOGPATH / "networkdata").mkdir(parents=True, exist_ok=True) - data_logger_log_file_path = ( - API.BASEDATALOGPATH / "networkdata" / data_logger_log_file - ) + data_logger_log_file_path = API.BASEDATALOGPATH / "networkdata" / data_logger_log_file - file_handler = CustomRotatingFileHandler( - filename=default_log_file_path, when="midnight" - ) + file_handler = CustomRotatingFileHandler(filename=default_log_file_path, when="midnight") file_handler.formatter = logging.Formatter( "%(asctime)s : %(levelname)-9s - %(name)-22s - %(message)s" ) diff --git a/src/bastproxy/plugins/core/log/libs/tz.py b/src/bastproxy/plugins/core/log/libs/tz.py index 68688b5f..5198d0fe 100644 --- a/src/bastproxy/plugins/core/log/libs/tz.py +++ b/src/bastproxy/plugins/core/log/libs/tz.py @@ -26,11 +26,7 @@ def formatTime_RFC3339_UTC(self, record, datefmt=None): str: The formatted timestamp in RFC3339 format with UTC timezone. """ - return ( - datetime.datetime.fromtimestamp(record.created) - .astimezone(datetime.UTC) - .isoformat() - ) + return datetime.datetime.fromtimestamp(record.created).astimezone(datetime.UTC).isoformat() def formatTime_RFC3339(self, record, datefmt=None): diff --git a/src/bastproxy/plugins/core/log/plugin/_log.py b/src/bastproxy/plugins/core/log/plugin/_log.py index d025d3ed..065e6886 100644 --- a/src/bastproxy/plugins/core/log/plugin/_log.py +++ b/src/bastproxy/plugins/core/log/plugin/_log.py @@ -56,9 +56,7 @@ def _phook_log_post_init_custom_logging( self, ): # pyright: ignore[reportInvalidTypeVarUse] # setup file logging and network data logging - LogRecord( - "setting up custom logging", level="debug", sources=[self.plugin_id] - )() + LogRecord("setting up custom logging", level="debug", sources=[self.plugin_id])() setup_loggers(logging.DEBUG) @RegisterPluginHook("initialize") @@ -105,45 +103,33 @@ def _api_get_level_color(self, level): match level: case "error": try: - return self.api("plugins.core.settings:get")( - self.plugin_id, "color_error" - ) + return self.api("plugins.core.settings:get")(self.plugin_id, "color_error") except Exception: return "@x136" case "warning": try: - return self.api("plugins.core.settings:get")( - self.plugin_id, "color_warning" - ) + return self.api("plugins.core.settings:get")(self.plugin_id, "color_warning") except Exception: return "@y" case "info": try: - return self.api("plugins.core.settings:get")( - self.plugin_id, "color_info" - ) + return self.api("plugins.core.settings:get")(self.plugin_id, "color_info") except Exception: return "@w" case "debug": try: - return self.api("plugins.core.settings:get")( - self.plugin_id, "color_debug" - ) + return self.api("plugins.core.settings:get")(self.plugin_id, "color_debug") except Exception: return "@x246" case "critical": try: - return self.api("plugins.core.settings:get")( - self.plugin_id, "color_critical" - ) + return self.api("plugins.core.settings:get")(self.plugin_id, "color_critical") except Exception: return "@r" case _: return "" - @AddAPI( - "can.log.to.console", description="check if a logger can log to the console" - ) + @AddAPI("can.log.to.console", description="check if a logger can log to the console") def _api_can_log_to_console(self, logger, level): """Check if a logger can log to the console. @@ -154,9 +140,7 @@ def _api_can_log_to_console(self, logger, level): self.handlers["console"][logger_name] = "info" self.handlers["console"].sync() - convlevel = getattr( - logging, self.handlers["console"][logger_name].upper(), logging.INFO - ) + convlevel = getattr(logging, self.handlers["console"][logger_name].upper(), logging.INFO) return level >= convlevel @AddAPI("can.log.to.file", description="check if a logger can log to file") @@ -170,9 +154,7 @@ def _api_can_log_to_file(self, logger, level): self.handlers["file"][logger_name] = "info" self.handlers["file"].sync() - convlevel = getattr( - logging, self.handlers["file"][logger_name].upper(), logging.INFO - ) + convlevel = getattr(logging, self.handlers["file"][logger_name].upper(), logging.INFO) return level >= convlevel @AddAPI("can.log.to.client", description="check if a logger can log to the client") @@ -263,9 +245,7 @@ def _command_client(self): logger_name, level=level, flag=remove ) if logger_name in self.handlers["client"]: - tmsg.append( - f"setting {logger_name} to log to client at level {level}" - ) + tmsg.append(f"setting {logger_name} to log to client at level {level}") else: tmsg.append(f"no longer sending {logger_name} to client") @@ -337,19 +317,11 @@ def _command_console(self): if logger_name not in self.handlers["console"]: self.handlers["console"][logger_name] = "info" if self.handlers["console"][logger_name] == "info": - self.api(f"{self.plugin_id}:set.log.to.console")( - logger_name, level="debug" - ) - tmsg.append( - f"setting {logger_name} to log to console at level 'debug'" - ) + self.api(f"{self.plugin_id}:set.log.to.console")(logger_name, level="debug") + tmsg.append(f"setting {logger_name} to log to console at level 'debug'") else: - self.api(f"{self.plugin_id}:set.log.to.console")( - logger_name, level="info" - ) - tmsg.append( - f"setting {logger_name} to log to console at default level 'info'" - ) + self.api(f"{self.plugin_id}:set.log.to.console")(logger_name, level="info") + tmsg.append(f"setting {logger_name} to log to console at default level 'info'") self.handlers["console"].sync() return True, tmsg @@ -420,19 +392,11 @@ def _command_file(self): if logger_name not in self.handlers["file"]: self.handlers["file"][logger_name] = "info" if self.handlers["file"][logger_name] == "info": - self.api(f"{self.plugin_id}:set.log.to.file")( - logger_name, level="debug" - ) - tmsg.append( - f"setting {logger_name} to log to file at level 'debug'" - ) + self.api(f"{self.plugin_id}:set.log.to.file")(logger_name, level="debug") + tmsg.append(f"setting {logger_name} to log to file at level 'debug'") else: - self.api(f"{self.plugin_id}:set.log.to.file")( - logger_name, level="info" - ) - tmsg.append( - f"setting {logger_name} to log to file at default level 'info'" - ) + self.api(f"{self.plugin_id}:set.log.to.file")(logger_name, level="info") + tmsg.append(f"setting {logger_name} to log to file at default level 'info'") self.handlers["file"].sync() return True, tmsg diff --git a/src/bastproxy/plugins/core/pluginm/plugin/_pluginm.py b/src/bastproxy/plugins/core/pluginm/plugin/_pluginm.py index c3cd8044..6735f99a 100644 --- a/src/bastproxy/plugins/core/pluginm/plugin/_pluginm.py +++ b/src/bastproxy/plugins/core/pluginm/plugin/_pluginm.py @@ -82,9 +82,7 @@ def _command_helper_format_plugin_list( ) if foundrequired and required_color_line: - msg.extend( - ("", f"* {required_color}Required plugins appear in this color@w") - ) + msg.extend(("", f"* {required_color}Required plugins appear in this color@w")) return msg # get a message of plugins in a package @@ -109,8 +107,7 @@ def _get_package_plugins(self, package): if plist := [ plugin_id for plugin_id in loaded_plugins_by_id - if self.api("libs.plugins.loader:get.plugin.info")(plugin_id).package - == package + if self.api("libs.plugins.loader:get.plugin.info")(plugin_id).package == package ]: plugins = sorted(plist) mod = __import__(package) @@ -166,13 +163,9 @@ def _get_changed_plugins(self): if list_to_format := [ plugin_id for plugin_id in loaded_plugins_by_id - if self.api("libs.plugins.loader:get.plugin.info")( - plugin_id - ).get_changed_files() + if self.api("libs.plugins.loader:get.plugin.info")(plugin_id).get_changed_files() ]: - msg = self._command_helper_format_plugin_list( - list_to_format, "Changed Plugins" - ) + msg = self._command_helper_format_plugin_list(list_to_format, "Changed Plugins") return msg or ["No plugins are changed on disk."] @@ -183,9 +176,7 @@ def _get_invalid_plugins(self): if list_to_format := [ plugin_id for plugin_id in all_plugins_by_id - if self.api("libs.plugins.loader:get.plugin.info")( - plugin_id - ).get_invalid_python_files() + if self.api("libs.plugins.loader:get.plugin.info")(plugin_id).get_invalid_python_files() ]: msg = self._command_helper_format_plugin_list( list_to_format, "Plugins with invalid python code" @@ -198,9 +189,7 @@ def _get_not_loaded_plugins(self): """Create a message of all not loaded plugins.""" msg = [] not_loaded_plugins = self.api("libs.plugins.loader:get.not.loaded.plugins")() - msg = self._command_helper_format_plugin_list( - not_loaded_plugins, "Not Loaded Plugins" - ) + msg = self._command_helper_format_plugin_list(not_loaded_plugins, "Not Loaded Plugins") return msg or ["There are no plugins that are not loaded"] @@ -223,9 +212,7 @@ def _get_not_loaded_plugins(self): help="list plugins that have files with invalid python code", action="store_true", ) - @AddArgument( - "package", help="the package of the plugins to list", default="", nargs="?" - ) + @AddArgument("package", help="the package of the plugins to list", default="", nargs="?") def _command_list(self): """@G%(name)s@w - @B%(cmdname)s@w. @@ -295,9 +282,7 @@ def _load_other_plugins_after_core_and_client_plugins(self): if plugins_to_load_setting: LogRecord("Loading other plugins", level="info", sources=[self.plugin_id])() self.api("libs.plugins.loader:load.plugins")(plugins_to_load_setting) - LogRecord( - "Finished loading other plugins", level="info", sources=[self.plugin_id] - )() + LogRecord("Finished loading other plugins", level="info", sources=[self.plugin_id])() @RegisterToEvent(event_name="ev_plugins.core.proxy_shutdown") def _eventcb_shutdown(self, _=None): @@ -337,9 +322,7 @@ def _command_load(self): if self.api("libs.plugins.loader:is.plugin.loaded")(plugin_id): return True, [f"Plugin {plugin_id} is already loaded"] - not_loaded_plugins_by_id = self.api( - "libs.plugins.loader:get.not.loaded.plugins" - )() + not_loaded_plugins_by_id = self.api("libs.plugins.loader:get.not.loaded.plugins")() if plugin_id and plugin_id not in not_loaded_plugins_by_id: return True, [f"Plugin {plugin_id} not found"] @@ -448,9 +431,7 @@ def _command_reload(self): ) return True, [f"{args['plugin']} reloaded"] - @RegisterToEvent( - event_name="ev_plugins.core.events_all_events_registered", priority=1 - ) + @RegisterToEvent(event_name="ev_plugins.core.events_all_events_registered", priority=1) def _eventcb_all_events_registered(self): """This resends all the different plugin initialization events,. diff --git a/src/bastproxy/plugins/core/proxy/plugin/_proxy.py b/src/bastproxy/plugins/core/proxy/plugin/_proxy.py index bea08a2e..0a24e0a5 100644 --- a/src/bastproxy/plugins/core/proxy/plugin/_proxy.py +++ b/src/bastproxy/plugins/core/proxy/plugin/_proxy.py @@ -61,9 +61,7 @@ def _phook_init_plugin(self): @RegisterPluginHook("initialize") def _phook_initialize(self): """Initialize the plugin.""" - restartproxymessage = ( - "@RPlease restart the proxy for the changes to take effect.@w" - ) + restartproxymessage = "@RPlease restart the proxy for the changes to take effect.@w" # Network Settings self.api("plugins.core.settings:add")( @@ -195,9 +193,7 @@ def _api_preamble_get(self): def _api_preamble_color_get(self, error=False): """Get the preamble color.""" if error: - return self.api("plugins.core.settings:get")( - self.plugin_id, "preambleerrorcolor" - ) + return self.api("plugins.core.settings:get")(self.plugin_id, "preambleerrorcolor") return self.api("plugins.core.settings:get")(self.plugin_id, "preamblecolor") @RegisterToEvent(event_name="ev_libs.net.mud_mudconnect") @@ -208,9 +204,7 @@ def _eventcb_sendusernameandpw(self): """ if self.api("plugins.core.settings:get")(self.plugin_id, "username") != "": SendDataDirectlyToMud( - NetworkData( - self.api("plugins.core.settings:get")(self.plugin_id, "username") - ), + NetworkData(self.api("plugins.core.settings:get")(self.plugin_id, "username")), show_in_history=False, )() pasw = self.api(f"{self.plugin_id}:ssc.mudpw")() @@ -245,9 +239,7 @@ def _command_info(self): template % ( "Connected", - self.mud_connection.connected_time.strftime( - self.api.time_format - ), + self.mud_connection.connected_time.strftime(self.api.time_format), ), template % ( @@ -269,9 +261,7 @@ def _command_info(self): tmsg.append("") - _, nmsg = self.api("plugins.core.commands:run")( - "plugins.core.clients", "show", "" - ) + _, nmsg = self.api("plugins.core.commands:run")("plugins.core.clients", "show", "") tmsg.extend(nmsg) return True, tmsg @@ -297,9 +287,7 @@ def _command_connect(self): self.api("plugins.core.settings:get")(self.plugin_id, "mudport"), ) - self.api("libs.asynch:task.add")( - self.mud_connection.connect_to_mud, "Mud Connect Task" - ) + self.api("libs.asynch:task.add")(self.mud_connection.connect_to_mud, "Mud Connect Task") return True, ["Connecting to the mud"] @@ -342,13 +330,9 @@ def _command_shutdown(self): def _command_restart(self): """Restart the proxy.""" args = self.api("plugins.core.commands:get.current.command.args")() - listen_port = self.api("plugins.core.settings:get")( - self.plugin_id, "listenport" - ) + listen_port = self.api("plugins.core.settings:get")(self.plugin_id, "listenport") self.api(f"{self.plugin_id}:restart")(args["seconds"]) - return True, [ - f"Restarting bastproxy on port: {listen_port} in {args['seconds']} seconds" - ] + return True, [f"Restarting bastproxy on port: {listen_port} in {args['seconds']} seconds"] @RegisterToEvent(event_name="ev_plugins.core.clients_client_logged_in") def _eventcb_client_logged_in(self): @@ -396,10 +380,7 @@ def _eventcb_client_logged_in(self): f"{self.api('plugins.core.commands:get.command.format')(self.plugin_id, 'proxypw')} 'This is a password'", ) ) - if ( - self.api(f"{self.plugin_id}:ssc.proxypwview")(quiet=True) - == "defaultviewpass" - ): + if self.api(f"{self.plugin_id}:ssc.proxypwview")(quiet=True) == "defaultviewpass": tmsg.extend( ( divider, @@ -414,24 +395,18 @@ def _eventcb_client_logged_in(self): tmsg.insert(0, divider) if tmsg: - new_message = NetworkData( - tmsg, owner_id=f"{self.plugin_id}:_eventcb_client_logged_in" - ) + new_message = NetworkData(tmsg, owner_id=f"{self.plugin_id}:_eventcb_client_logged_in") SendDataDirectlyToClient(new_message)() @AddAPI("restart", description="restart the proxy") def _api_restart(self, restart_in=None): """Restart the proxy after 10 seconds.""" restart_in = restart_in or 10 - listen_port = self.api("plugins.core.settings:get")( - self.plugin_id, "listenport" - ) + listen_port = self.api("plugins.core.settings:get")(self.plugin_id, "listenport") SendDataDirectlyToClient( NetworkData( - [ - f"Restarting bastproxy on port: {listen_port} in {restart_in} seconds" - ], + [f"Restarting bastproxy on port: {listen_port} in {restart_in} seconds"], owner_id=f"{self.plugin_id}:_api_restart", ) )() @@ -482,9 +457,7 @@ def _eventcb_command_seperator_change(self): action="store_false", default=True, ) - @AddArgument( - "-p", "--ports", help="show network ports", action="store_false", default=True - ) + @AddArgument("-p", "--ports", help="show network ports", action="store_false", default=True) def _command_resource(self): """Output proxy resource usage.""" args = self.api("plugins.core.commands:get.current.command.args")() diff --git a/src/bastproxy/plugins/core/settings/_patch_base.py b/src/bastproxy/plugins/core/settings/_patch_base.py index 8e8fceda..f5e54f8d 100644 --- a/src/bastproxy/plugins/core/settings/_patch_base.py +++ b/src/bastproxy/plugins/core/settings/_patch_base.py @@ -43,9 +43,7 @@ def _command_settings_plugin_set(self): args = self.api("plugins.core.commands:get.current.command.args")() if not args["name"] or args["name"] == "list": - return True, self.api("plugins.core.settings:get.all.settings.formatted")( - self.plugin_id - ) + return True, self.api("plugins.core.settings:get.all.settings.formatted")(self.plugin_id) arg_string = f"-p {self.plugin_id} {args.arg_string}" _retval, return_string = self.api("plugins.core.commands:run")( diff --git a/src/bastproxy/plugins/core/settings/plugin/_settings.py b/src/bastproxy/plugins/core/settings/plugin/_settings.py index 84a99e08..b0bfd60e 100644 --- a/src/bastproxy/plugins/core/settings/plugin/_settings.py +++ b/src/bastproxy/plugins/core/settings/plugin/_settings.py @@ -92,9 +92,7 @@ def _api_add(self, plugin_id, setting_name, default, stype, help, **kwargs): if plugin_id not in self.settings_values: data_directory = self.api(f"{plugin_id}:get.data.directory")() settings_file: Path = data_directory / "settingvalues.txt" - self.settings_values[plugin_id] = PersistentDict( - plugin_id, settings_file, "c" - ) + self.settings_values[plugin_id] = PersistentDict(plugin_id, settings_file, "c") if setting_name not in self.settings_values[plugin_id]: self.settings_values[plugin_id][setting_name] = setting_info.default @@ -151,16 +149,12 @@ def _eventcb_plugin_reset(self): self.api(f"{self.plugin_id}:reset")(plugin_id) event_record["plugins_that_acted"].append(self.plugin_id) - @AddAPI( - "reset", description="reset all settings for a plugin to their default values" - ) + @AddAPI("reset", description="reset all settings for a plugin to their default values") def _api_reset(self, plugin_id): """Reset all settings for a plugin to their default values.""" self.settings_values[plugin_id].clear() for i in self.settings_info[plugin_id]: - self.settings_values[plugin_id][i] = self.settings_info[plugin_id][ - i - ].default + self.settings_values[plugin_id][i] = self.settings_info[plugin_id][i].default self.settings_values[plugin_id].sync() @AddAPI("change", description="change the value of a setting") @@ -229,9 +223,7 @@ def _api_get_setting_info(self, plugin_id, setting): return None return self.settings_info[plugin_id][setting] - @AddAPI( - "initialize.plugin.settings", description="initialize the settings for a plugin" - ) + @AddAPI("initialize.plugin.settings", description="initialize the settings for a plugin") def _api_initialize_plugin_settings(self, plugin_id): """Initialize the settings for a plugin.""" LogRecord( @@ -311,9 +303,7 @@ def _api_raise_event_all_settings(self, plugin_id): event_args={"var": i, "newvalue": new_value, "oldvalue": old_value}, ) - @AddAPI( - "is.setting.hidden", description="check if a plugin setting is flagged hidden" - ) + @AddAPI("is.setting.hidden", description="check if a plugin setting is flagged hidden") def _api_is_setting_hidden(self, plugin_id, setting): """Check if a plugin setting is hidden.""" return self.settings_info[plugin_id][setting].hidden @@ -332,9 +322,7 @@ def _api_add_setting_to_map(self, plugin_id, setting_name): def format_setting_for_print(self, plugin_id, setting_name): """Format a setting for printing.""" value = self.api(f"{self.plugin_id}:get")(plugin_id, setting_name) - setting_info = self.api(f"{self.plugin_id}:get.setting.info")( - plugin_id, setting_name - ) + setting_info = self.api(f"{self.plugin_id}:get.setting.info")(plugin_id, setting_name) if setting_info.nocolor: value = value.replace("@", "@@") elif setting_info.stype == "color": @@ -373,9 +361,7 @@ def _api_get_all_settings_formatted(self, plugin_id): return [f"There are no settings defined in {plugin_id}"] for i in self.settings_info[plugin_id]: - if formatted_setting := self.api("plugins.core.settings:format.setting")( - plugin_id, i - ): + if formatted_setting := self.api("plugins.core.settings:format.setting")(plugin_id, i): tmsg.append(formatted_setting) return tmsg or [f"There are no settings defined in {plugin_id}"] @@ -410,16 +396,14 @@ def _command_list(self): if args["plugin"]: default_message = [f"No settings found for {args['plugin']}"] if self.api("libs.plugins.loader:is.plugin.loaded")(args["plugin"]): - settings[args["plugin"]] = self.api( - f"{self.plugin_id}:get.all.for.plugin" - )(args["plugin"]) + settings[args["plugin"]] = self.api(f"{self.plugin_id}:get.all.for.plugin")( + args["plugin"] + ) else: return True, ["Plugin does not exist"] else: for plugin_id in loaded_plugins: - settings[plugin_id] = self.api(f"{self.plugin_id}:get.all.for.plugin")( - plugin_id - ) + settings[plugin_id] = self.api(f"{self.plugin_id}:get.all.for.plugin")(plugin_id) if not settings: return True, ["No settings found"] @@ -459,9 +443,7 @@ def _command_list(self): ) @AddArgument("name", help="the setting name", default="list", nargs="?") @AddArgument("value", help="the new value of the setting", default="", nargs="?") - @AddArgument( - "-p", "--plugin", help="the plugin of the setting", default="", nargs="?" - ) + @AddArgument("-p", "--plugin", help="the plugin of the setting", default="", nargs="?") def _command_settings_plugin_sets(self): """Command to set a plugin setting.""" args = self.api("plugins.core.commands:get.current.command.args")() diff --git a/src/bastproxy/plugins/core/sqldb/libs/_sqlite.py b/src/bastproxy/plugins/core/sqldb/libs/_sqlite.py index 7bb7bc4c..5c05e845 100644 --- a/src/bastproxy/plugins/core/sqldb/libs/_sqlite.py +++ b/src/bastproxy/plugins/core/sqldb/libs/_sqlite.py @@ -146,9 +146,7 @@ def _command_dbselect(self): message = [] if args: if sqlstmt := args["stmt"]: - results = self.api(f"{self.plugin_id}:{self.database_name}.select")( - sqlstmt - ) + results = self.api(f"{self.plugin_id}:{self.database_name}.select")(sqlstmt) message.extend(f"{i}" for i in results) else: message.append("Please enter a select statement") @@ -171,9 +169,7 @@ def _command_dbtables(self): desc = cursor.fetchall() cursor.close() message.extend((f"Fields in table {args['table']}:", "-" * 40)) - message.extend( - f"{item['name']:<25} : {item['type']}" for item in desc - ) + message.extend(f"{item['name']:<25} : {item['type']}" for item in desc) return True, message else: tables = [] @@ -188,9 +184,7 @@ def _command_dbtables(self): cursor.close() if tables: message.append(f"Tables in database {self.database_name}:") - message.extend( - f"{item}" for item in tables if item != "sqlite_sequence" - ) + message.extend(f"{item}" for item in tables if item != "sqlite_sequence") else: message.append(f"No tables in database {self.database_name}") return True, message @@ -228,9 +222,7 @@ def _command_dbclose(self): @AddCommand(group="DB") @AddParser(description="remove a row from a table") - @AddArgument( - "table", help="the table to remove the row from", default="", nargs="?" - ) + @AddArgument("table", help="the table to remove the row from", default="", nargs="?") @AddArgument("rownumber", help="the row number to remove", default=-1, nargs="?") def _command_dbremove(self): """Remove a table from the database.""" @@ -260,13 +252,9 @@ def _command_dbbackup(self): backup_file_name = self.backup_template % name + ".zip" if self.backupdb(name): - message.append( - f"backed up {self.database_name} with name {backup_file_name}" - ) + message.append(f"backed up {self.database_name} with name {backup_file_name}") else: - message.append( - f"could not back up {self.database_name} with name {backup_file_name}" - ) + message.append(f"could not back up {self.database_name} with name {backup_file_name}") return True, message @@ -338,10 +326,7 @@ def getcolumnumnsfromsql(self, tablename): column = sql_line_split_list[0] columnumns.append(column) columnumnsbykeys[column] = True - if ( - "default" in sql_line_split_list - or "Default" in sql_line_split_list - ): + if "default" in sql_line_split_list or "Default" in sql_line_split_list: columnumn_defaults[column] = sql_line_split_list[-1].strip(",") else: columnumn_defaults[column] = None @@ -371,23 +356,16 @@ def converttoinsert(self, tablename, keynull=False, replace=False): temp_list = [f":{i}" for i in columns] columnstring = ", ".join(temp_list) if replace: - sql_string = ( - f"INSERT OR REPLACE INTO {tablename} VALUES ({columnstring})" - ) + sql_string = f"INSERT OR REPLACE INTO {tablename} VALUES ({columnstring})" else: sql_string = f"INSERT INTO {tablename} VALUES ({columnstring})" if keynull and self.tables[tablename]["keyfield"]: - sql_string = sql_string.replace( - f":{self.tables[tablename]['keyfield']}", "NULL" - ) + sql_string = sql_string.replace(f":{self.tables[tablename]['keyfield']}", "NULL") return sql_string def checkcolumnumnexists(self, table, columnumnname): """Check if a columnumn exists.""" - return ( - table in self.tables - and columnumnname in self.tables[table]["columnumnsbykeys"] - ) + return table in self.tables and columnumnname in self.tables[table]["columnumnsbykeys"] def converttoupdate(self, tablename, wherekey="", nokey=None): """Create an update statement based on the columnumns of a table.""" @@ -402,9 +380,7 @@ def converttoupdate(self, tablename, wherekey="", nokey=None): if column != wherekey and (not nokey or column not in nokey) ] columnstring = ",".join(sql_string_list) - sql_string = ( - f"UPDATE {tablename} SET {columnstring} WHERE {wherekey} = :{wherekey};" - ) + sql_string = f"UPDATE {tablename} SET {columnstring} WHERE {wherekey} = :{wherekey};" return sql_string def getversion(self): @@ -608,9 +584,7 @@ def getlast(self, table_name, num, where=""): if where: sql_string = f"SELECT * FROM {table_name} WHERE {where} ORDER by {column_id_name} desc limit {num}" else: - sql_string = ( - f"SELECT * FROM {table_name} ORDER by {column_id_name} desc limit {num}" - ) + sql_string = f"SELECT * FROM {table_name} ORDER by {column_id_name} desc limit {num}" return self.api(f"{self.plugin_id}:{self.database_name}.select")(sql_string) diff --git a/src/bastproxy/plugins/core/timers/plugin/_timers.py b/src/bastproxy/plugins/core/timers/plugin/_timers.py index d350ad30..4a0ff92c 100644 --- a/src/bastproxy/plugins/core/timers/plugin/_timers.py +++ b/src/bastproxy/plugins/core/timers/plugin/_timers.py @@ -72,9 +72,7 @@ def get_first_fire(self) -> datetime.datetime: new_date = now + datetime.timedelta(seconds=self.seconds) if self.time: hour_minute = time.strptime(self.time, "%H%M") - new_date = now.replace( - hour=hour_minute.tm_hour, minute=hour_minute.tm_min, second=0 - ) + new_date = now.replace(hour=hour_minute.tm_hour, minute=hour_minute.tm_min, second=0) while new_date < now: new_date = new_date + datetime.timedelta(days=1) @@ -126,9 +124,7 @@ def _phook_initialize(self): )() # setup the task to check for timers to fire - self.api("libs.asynch:task.add")( - self.check_for_timers_to_fire, "Timer Plugin task" - ) + self.api("libs.asynch:task.add")(self.check_for_timers_to_fire, "Timer Plugin task") @RegisterToEvent(event_name="ev_plugin_unloaded") def _eventcb_plugin_unloaded(self): @@ -139,9 +135,7 @@ def _eventcb_plugin_unloaded(self): "debug", sources=[self.plugin_id, event_record["plugin_id"]], )() - self.api(f"{self.plugin_id}:remove.data.for.plugin")( - event_record["plugin_id"] - ) + self.api(f"{self.plugin_id}:remove.data.for.plugin")(event_record["plugin_id"]) @AddParser(description="toggle log flag for a timer") @AddArgument("timername", help="the timer name", default="", nargs="?") @@ -150,9 +144,7 @@ def _command_log(self) -> tuple[bool, list[str]]: args = self.api("plugins.core.commands:get.current.command.args")() message: list[str] = [] if args["timername"] in self.timer_lookup: - self.timer_lookup[args["timername"]].log = not self.timer_lookup[ - args["timername"] - ].log + self.timer_lookup[args["timername"]].log = not self.timer_lookup[args["timername"]].log message.append( f"changed log flag to {self.timer_lookup[args['timername']].log} for timer {args['timername']}" ) @@ -188,17 +180,11 @@ def _eventcb_timers_ev_plugins_stats(self): @RegisterToEvent(event_name="ev_plugin_stats") def _eventcb_event_get_stats_for_plugin(self) -> None: """Get stats for a plugin.""" - if not ( - event_record := self.api("plugins.core.events:get.current.event.record")() - ): + if not (event_record := self.api("plugins.core.events:get.current.event.record")()): return plugin_id = event_record["plugin_id"] - timers = [ - timer.name - for timer in self.timer_lookup.values() - if timer.owner_id == plugin_id - ] + timers = [timer.name for timer in self.timer_lookup.values() if timer.owner_id == plugin_id] if not timers: return @@ -237,8 +223,7 @@ def _command_list(self) -> tuple[bool, list[str]]: templatestring = "%-20s : %-25s %-9s %-8s %s" message.extend( ( - templatestring - % ("Name", "Defined in", "Enabled", "Fired", "Next Fire"), + templatestring % ("Name", "Defined in", "Enabled", "Fired", "Next Fire"), output_header_color + "-" * 80 + "@w", ) ) @@ -259,9 +244,7 @@ def _command_list(self) -> tuple[bool, list[str]]: return True, message @AddParser(description="get details for a timer") - @AddArgument( - "timers", help="a list of timers to get details", default=None, nargs="*" - ) + @AddArgument("timers", help="a list of timers to get details", default=None, nargs="*") def _command_detail(self) -> tuple[bool, list[str]]: """@G%(name)s@w - @B%(cmdname)s@w. @@ -292,9 +275,7 @@ def _command_detail(self) -> tuple[bool, list[str]]: last_fire_time = timer.last_fired_datetime.strftime( "%a %b %d %Y %H:%M:%S %Z" ) - message.append( - f"{'Last Fire':<{columnwidth}} : {last_fire_time}" - ) + message.append(f"{'Last Fire':<{columnwidth}} : {last_fire_time}") message.extend( ( f"{'Next Fire':<{columnwidth}} : {timer.next_fire_datetime.strftime('%a %b %d %Y %H:%M:%S %Z')}", @@ -332,9 +313,7 @@ def _api_get_timer_next_fire(self, name: str) -> datetime.datetime | None: return None @AddAPI("add.timer", description="add a timer") - def _api_add_timer( - self, name: str, func: Callable, seconds: int, **kwargs - ) -> Timer | None: + def _api_add_timer(self, name: str, func: Callable, seconds: int, **kwargs) -> Timer | None: """Add a timer. @Yname@w = The timer name @@ -349,9 +328,7 @@ def _api_add_timer( returns an Event instance """ - plugin_id: str = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + plugin_id: str = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) if "plugin_id" in kwargs: plugin_instance = self.api("libs.plugins.loader:get.plugin.instance")( @@ -409,9 +386,7 @@ def _api_remove_data_for_plugin(self, name: str): this function returns no values """ plugin_instance = self.api("libs.plugins.loader:get.plugin.instance")(name) - LogRecord( - f"removing timers for {name}", level="debug", sources=[self.plugin_id, name] - )() + LogRecord(f"removing timers for {name}", level="debug", sources=[self.plugin_id, name])() timers_to_remove: list[str] = [ i for i in self.timer_lookup @@ -469,10 +444,7 @@ def _add_timer_internal(self, timer: Timer): def _remove_timer_internal(self, timer: Timer): """Internally remove a timer.""" timer_next_time_to_fire = math.floor(timer.next_fire_datetime.timestamp()) - if ( - timer_next_time_to_fire != -1 - and timer_next_time_to_fire in self.timer_events - ): + if timer_next_time_to_fire != -1 and timer_next_time_to_fire in self.timer_events: if timer in self.timer_events[timer_next_time_to_fire]: self.timer_events[timer_next_time_to_fire].remove(timer) if not self.timer_events[timer_next_time_to_fire]: diff --git a/src/bastproxy/plugins/core/triggers/plugin/_triggers.py b/src/bastproxy/plugins/core/triggers/plugin/_triggers.py index c057dc1b..6cfd1141 100644 --- a/src/bastproxy/plugins/core/triggers/plugin/_triggers.py +++ b/src/bastproxy/plugins/core/triggers/plugin/_triggers.py @@ -70,9 +70,7 @@ def raisetrigger(self, args): args["trigger_name"] = self.trigger_name args["trigger_id"] = self.trigger_id - args = self.api("plugins.core.events:raise.event")( - self.event_name, event_args=args - ) + args = self.api("plugins.core.events:raise.event")(self.event_name, event_args=args) LogRecord( f"raisetrigger - trigger {self.trigger_id} raised event {self.event_name} with args {args}", level="debug", @@ -123,12 +121,8 @@ def _phook_initialize(self): self.plugin_id, "enabled", "True", bool, "enable triggers" ) - self.api("plugins.core.triggers:trigger.add")( - "beall", None, self.plugin_id, enabled=False - ) - self.api("plugins.core.triggers:trigger.add")( - "all", None, self.plugin_id, enabled=False - ) + self.api("plugins.core.triggers:trigger.add")("beall", None, self.plugin_id, enabled=False) + self.api("plugins.core.triggers:trigger.add")("all", None, self.plugin_id, enabled=False) self.api("plugins.core.triggers:trigger.add")( "emptyline", None, self.plugin_id, enabled=False ) @@ -153,9 +147,7 @@ def _eventcb_enabled_modify(self): def _eventcb_plugin_unloaded(self): """A plugin was unloaded.""" if event_record := self.api("plugins.core.events:get.current.event.record")(): - self.api(f"{self.plugin_id}:remove.data.for.owner")( - event_record["plugin_id"] - ) + self.api(f"{self.plugin_id}:remove.data.for.owner")(event_record["plugin_id"]) def rebuild_regexes(self): """Rebuild a regex for priority. @@ -197,9 +189,7 @@ def create_regex_id(self): def _api_trigger_register(self, trigger_name, function, **kwargs): """Register a function to a trigger.""" if trigger_name not in self.triggers: - owner_id = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + owner_id = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) trigger_name = self.create_trigger_id(trigger_name, owner_id) return self.api("plugins.core.events:register.to.event")( self.triggers[trigger_name].event_name, function, *kwargs @@ -209,9 +199,7 @@ def _api_trigger_register(self, trigger_name, function, **kwargs): def _api_trigger_unregister(self, trigger_name, function): """Unregister a function from a trigger.""" if trigger_name not in self.triggers: - owner_id = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + owner_id = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) trigger_name = self.create_trigger_id(trigger_name, owner_id) return self.api("plugins.core.events:unregister.from.event")( self.triggers[trigger_name].event_name, function @@ -220,9 +208,7 @@ def _api_trigger_unregister(self, trigger_name, function): @AddAPI("trigger.update", description="update a trigger without deleting it") def _api_trigger_update(self, trigger_name, trigger_data): """Update a trigger without deleting it.""" - owner_id = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + owner_id = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) trigger_id = self.create_trigger_id(trigger_name, owner_id) if trigger_id not in self.triggers: @@ -250,9 +236,7 @@ def _api_trigger_update(self, trigger_name, trigger_data): self.triggers[trigger_id].original_regex = orig_regex try: - self.triggers[trigger_id].original_regex_compiled = re.compile( - orig_regex - ) + self.triggers[trigger_id].original_regex_compiled = re.compile(orig_regex) except Exception: # pylint: disable=broad-except LogRecord( f"Could not compile regex for trigger: {trigger_name} : {orig_regex}", @@ -279,9 +263,7 @@ def _api_trigger_update(self, trigger_name, trigger_data): self.trigger_groups[old_value].remove(trigger_name) if self.triggers[trigger_name].group not in self.trigger_groups: self.trigger_groups[self.triggers[trigger_name].group] = [] - self.trigger_groups[self.triggers[trigger_name].group].append( - trigger_name - ) + self.trigger_groups[self.triggers[trigger_name].group].append(trigger_name) return None def find_regex_id(self, regex): @@ -323,9 +305,7 @@ def _api_trigger_add(self, trigger_name, regex, owner_id=None, **kwargs): # pyl this function returns no values """ if not owner_id: - owner_id = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + owner_id = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) if not owner_id: print("could not add a owner for trigger name", trigger_name) @@ -420,9 +400,7 @@ def _api_trigger_remove(self, trigger_name, force=False, owner_id=None): False if it wasn't """ if not owner_id: - owner_id = self.api("libs.api:get:caller:owner")( - ignore_owner_list=[self.plugin_id] - ) + owner_id = self.api("libs.api:get:caller:owner")(ignore_owner_list=[self.plugin_id]) if not owner_id: LogRecord( @@ -490,16 +468,12 @@ def _api_trigger_get(self, trigger_name, owner_id=None): @Ytrigger_name@w = The trigger name. """ if not owner_id: - owner_id = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + owner_id = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) trigger_id = self.create_trigger_id(trigger_name, owner_id) return self.triggers.get(trigger_id, None) - @AddAPI( - "remove.data.for.owner", description="remove all triggers related to a owner" - ) + @AddAPI("remove.data.for.owner", description="remove all triggers related to a owner") def _api_remove_data_for_owner(self, owner_id): """Remove all triggers related to a owner. @@ -530,9 +504,7 @@ def _api_trigger_toggle_enable(self, trigger_name, flag, owner_id=None): this function returns no values """ if not owner_id: - owner_id = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + owner_id = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) trigger_id = self.create_trigger_id(trigger_name, owner_id) if trigger_id in self.triggers: @@ -565,9 +537,7 @@ def _api_trigger_toggle_omit(self, trigger_name, flag, owner_id=None): this function returns no values """ if not owner_id: - owner_id = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + owner_id = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) trigger_id = self.create_trigger_id(trigger_name, owner_id) if trigger_id in self.triggers: @@ -641,9 +611,9 @@ def process_match(self, data_line, regex_match_data): if self.triggers[trigger_id].argtypes: for arg in self.triggers[trigger_id].argtypes: if arg in group_dict: - group_dict[arg] = self.triggers[trigger_id].argtypes[ - arg - ](group_dict[arg]) + group_dict[arg] = self.triggers[trigger_id].argtypes[arg]( + group_dict[arg] + ) args["matches"] = group_dict self.triggers[trigger_id].raisetrigger(args) if self.triggers[trigger_id].stopevaluating: @@ -652,9 +622,7 @@ def process_match(self, data_line, regex_match_data): @RegisterToEvent(event_name="ev_to_client_data_modify") def _eventcb_check_trigger(self): # pylint: disable=too-many-branches """Check a line of text from the mud to see if it matches any triggers.""" - if not ( - event_record := self.api("plugins.core.events:get.current.event.record")() - ): + if not (event_record := self.api("plugins.core.events:get.current.event.record")()): return # don't check internal data @@ -674,9 +642,7 @@ def _eventcb_check_trigger(self): # pylint: disable=too-many-branches match_data = None if match_data: - match_groups = { - k: v for k, v in match_data.groupdict().items() if v is not None - } + match_groups = {k: v for k, v in match_data.groupdict().items() if v is not None} else: match_groups = {} @@ -812,9 +778,7 @@ def _command_detail(self): ) if self.api("plugins.core.events:has.event")(event_name): message.extend(("", "Event Details:")) - event_details = self.api( - "plugins.core.events:get.event.detail" - )(event_name) + event_details = self.api("plugins.core.events:get.event.detail")(event_name) message.extend(event_details) else: message.extend(["", "No functions registered for this trigger"]) diff --git a/src/bastproxy/plugins/core/utils/plugin/_utils.py b/src/bastproxy/plugins/core/utils/plugin/_utils.py index ffd0f6e4..1a4bff08 100644 --- a/src/bastproxy/plugins/core/utils/plugin/_utils.py +++ b/src/bastproxy/plugins/core/utils/plugin/_utils.py @@ -64,9 +64,7 @@ def _api_format_list_into_columns(self, obj, cols=4, columnwise=True, gap=4): number_of_columns = min(cols, len(list_of_strings)) max_len = max(len(item) for item in list_of_strings) if columnwise: - number_of_columns = math.ceil( - float(len(list_of_strings)) / float(number_of_columns) - ) + number_of_columns = math.ceil(float(len(list_of_strings)) / float(number_of_columns)) plist = [ list_of_strings[i : i + number_of_columns] for i in range(0, len(list_of_strings), number_of_columns) @@ -300,9 +298,7 @@ def _api_center_colored_string( converted_colors_string = self.api("plugins.core.colors:colorcode.to.ansicode")( string_to_center ) - noncolored_string = self.api("plugins.core.colors:ansicode.strip")( - converted_colors_string - ) + noncolored_string = self.api("plugins.core.colors:ansicode.strip")(converted_colors_string) caplength = 0 if endcaps: @@ -316,7 +312,9 @@ def _api_center_colored_string( filler = filler_character * half_length filler_color_end = "@w" if filler_color else "" - new_str = f"{filler_color}{filler}{filler_color_end} {string_to_center} {filler_color}{filler}" + new_str = ( + f"{filler_color}{filler}{filler_color_end} {string_to_center} {filler_color}{filler}" + ) new_length = (half_length * 2) + noncolored_string_length @@ -355,9 +353,7 @@ def _api_check_list_for_match(self, arg, item_list: list[str]) -> list[str]: return matches["partofstring"] - @AddAPI( - "convert.timelength.to.secs", description="converts a time length to seconds" - ) + @AddAPI("convert.timelength.to.secs", description="converts a time length to seconds") def _api_convert_timelength_to_secs(self, timel): """Converts a time length to seconds. @@ -381,31 +377,19 @@ def _api_convert_timelength_to_secs(self, timel): return None days = timelength_match_groups["days"] - converted_days = ( - int(days[:-1]) if days.endswith("d") else int(days) if days else 0 - ) + converted_days = int(days[:-1]) if days.endswith("d") else int(days) if days else 0 hours = timelength_match_groups["hours"] - converted_hours = ( - int(hours[:-1]) if hours.endswith("h") else int(hours) if hours else 0 - ) + converted_hours = int(hours[:-1]) if hours.endswith("h") else int(hours) if hours else 0 minutes = timelength_match_groups["minutes"] converted_minutes = ( - int(minutes[:-1]) - if minutes.endswith("m") - else int(minutes) - if minutes - else 0 + int(minutes[:-1]) if minutes.endswith("m") else int(minutes) if minutes else 0 ) seconds = timelength_match_groups["seconds"] converted_seconds = ( - int(seconds[:-1]) - if seconds.endswith("s") - else int(seconds) - if seconds - else 0 + int(seconds[:-1]) if seconds.endswith("s") else int(seconds) if seconds else 0 ) return ( @@ -427,9 +411,7 @@ def _api_convert_data_to_output_table(self, table_name, data, columns, color="") 'key' : dictionary key, 'width' : the width of the column. """ - line_length_default = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "linelen" - ) + line_length_default = self.api("plugins.core.settings:get")("plugins.core.proxy", "linelen") output_color = color or self.api("plugins.core.settings:get")( "plugins.core.commands", "output_subheader_color" ) @@ -440,9 +422,7 @@ def _api_convert_data_to_output_table(self, table_name, data, columns, color="") column["width"] = max(len(str(item[column["key"]])) for item in temp_data) # build the template string - template_strings = [ - "{" + item["key"] + ":<" + str(item["width"]) + "}" for item in columns - ] + template_strings = ["{" + item["key"] + ":<" + str(item["width"]) + "}" for item in columns] template_string = f"{f' {output_color}|{color_end} '.join(template_strings)}" # build the header dict @@ -454,10 +434,7 @@ def _api_convert_data_to_output_table(self, table_name, data, columns, color="") largest_line = max( [ len(self.api("plugins.core.colors:colorcode.strip")(subheader_msg)), - *[ - len(self.api("plugins.core.colors:colorcode.strip")(line)) - for line in data_msg - ], + *[len(self.api("plugins.core.colors:colorcode.strip")(line)) for line in data_msg], ] ) @@ -508,9 +485,7 @@ def _api_cap_line( spacechar = " " if space else "" if not line_length: - line_length = self.api("plugins.core.settings:get")( - "plugins.core.proxy", "linelen" - ) + line_length = self.api("plugins.core.settings:get")("plugins.core.proxy", "linelen") capchar_len = len(capchar + spacechar) diff --git a/src/bastproxy/plugins/core/watch/plugin/_watch.py b/src/bastproxy/plugins/core/watch/plugin/_watch.py index 1060b4e7..4581411e 100644 --- a/src/bastproxy/plugins/core/watch/plugin/_watch.py +++ b/src/bastproxy/plugins/core/watch/plugin/_watch.py @@ -38,9 +38,7 @@ def _eventcb_plugin_unloaded(self): level="debug", sources=[self.plugin_id, event_record["plugin_id"]], )() - self.api(f"{self.plugin_id}:remove.all.data.for.plugin")( - event_record["plugin_id"] - ) + self.api(f"{self.plugin_id}:remove.all.data.for.plugin")(event_record["plugin_id"]) @AddParser(description="list watches") @AddArgument( @@ -80,9 +78,7 @@ def _command_detail(self): for watch in args["watch"]: if watch in self.watch_data: event_name = self.watch_data[watch]["event_name"] - watch_event = self.api("plugins.core.events:get.event.detail")( - event_name - ) + watch_event = self.api("plugins.core.events:get.event.detail")(event_name) message.extend( ( f"{'Name':<{columnwidth}} : {watch}", @@ -112,9 +108,7 @@ def _api_watch_add(self, watch_name, regex, owner=None, **kwargs): this function returns no values """ if not owner: - owner = self.api("libs.api:get.caller.owner")( - ignore_owner_list=[self.plugin_id] - ) + owner = self.api("libs.api:get.caller.owner")(ignore_owner_list=[self.plugin_id]) if not owner: LogRecord( @@ -222,17 +216,13 @@ def _api_remove_all_data_for_plugin(self, plugin): @RegisterToEvent(event_name="ev_to_mud_data_modify") def _eventcb_check_command(self): """Check input from the client and see if we are watching for it.""" - if not ( - event_record := self.api("plugins.core.events:get.current.event.record")() - ): + if not (event_record := self.api("plugins.core.events:get.current.event.record")()): return client_data = event_record["line"] for watch_name in self.watch_data: cmdre = self.watch_data[watch_name]["compiled"] if match_data := cmdre.match(client_data): - self.watch_data[watch_name]["hits"] = ( - self.watch_data[watch_name]["hits"] + 1 - ) + self.watch_data[watch_name]["hits"] = self.watch_data[watch_name]["hits"] + 1 match_args = { "matched": match_data.groupdict(), "cmdname": f"cmd_{watch_name}", diff --git a/src/bastproxy/plugins/debug/api/plugin/_api.py b/src/bastproxy/plugins/debug/api/plugin/_api.py index 618eec19..5e7b7ff3 100644 --- a/src/bastproxy/plugins/debug/api/plugin/_api.py +++ b/src/bastproxy/plugins/debug/api/plugin/_api.py @@ -20,9 +20,7 @@ class APIPlugin(BasePlugin): """a plugin to show api information.""" @AddParser(description="detail a function in the API") - @AddArgument( - "-a", "--api", help="the api to detail (optional)", default="", nargs="?" - ) + @AddArgument("-a", "--api", help="the api to detail (optional)", default="", nargs="?") @AddArgument("-s", "--stats", help="add stats", action="store_true") @AddArgument( "-sd", @@ -37,9 +35,7 @@ class APIPlugin(BasePlugin): help="use an API that is not from a plugin", action="store_true", ) - @AddArgument( - "-c", "--show-code", help="show the function code", action="store_true" - ) + @AddArgument("-c", "--show-code", help="show the function code", action="store_true") def _command_detail(self): """@G%(name)s@w - @B%(cmdname)s@w. @@ -68,9 +64,7 @@ def _command_detail(self): return True, tmsg @AddParser(description="list functions in the API") - @AddArgument( - "toplevel", help="the top level api to show (optional)", default="", nargs="?" - ) + @AddArgument("toplevel", help="the top level api to show (optional)", default="", nargs="?") @AddArgument( "-np", "--noplugin", @@ -107,9 +101,7 @@ def _command_list(self): return True, tmsg @AddParser(description="call an API") - @AddArgument( - "-a", "--api", help="the api to detail (optional)", default="", nargs="?" - ) + @AddArgument("-a", "--api", help="the api to detail (optional)", default="", nargs="?") @AddArgument("arguments", help="arguments to the api", default="", nargs="*") def _command_call(self): """@G%(name)s@w - @B%(cmdname)s@w. diff --git a/src/bastproxy/plugins/debug/plugins/plugin/_plugins.py b/src/bastproxy/plugins/debug/plugins/plugin/_plugins.py index e070a7d9..0e63429d 100644 --- a/src/bastproxy/plugins/debug/plugins/plugin/_plugins.py +++ b/src/bastproxy/plugins/debug/plugins/plugin/_plugins.py @@ -65,9 +65,7 @@ def _command_hooks(self): ) for hook in hooks: - tmsg.extend( - self.api("plugins.core.commands:format.output.subheader")(f"{hook}") - ) + tmsg.extend(self.api("plugins.core.commands:format.output.subheader")(f"{hook}")) priorities = hooks[hook].keys() priorities = sorted(priorities) for priority in priorities: diff --git a/src/bastproxy/plugins/debug/records/plugin/_records.py b/src/bastproxy/plugins/debug/records/plugin/_records.py index f0deb511..09bcaad9 100644 --- a/src/bastproxy/plugins/debug/records/plugin/_records.py +++ b/src/bastproxy/plugins/debug/records/plugin/_records.py @@ -132,14 +132,10 @@ def _command_detail(self): if update := record.get_update(args["update"]): tmsg.extend(update.format_detailed()) else: - tmsg.append( - f"update {args['update']} in record {args['uid']} not found" - ) + tmsg.append(f"update {args['update']} in record {args['uid']} not found") else: - showlogrecords = self.api("plugins.core.settings:get")( - self.plugin_id, "showLogRecords" - ) + showlogrecords = self.api("plugins.core.settings:get")(self.plugin_id, "showLogRecords") update_filter = [] if showlogrecords else ["LogRecord"] data = record.get_formatted_details( update_filter=update_filter, diff --git a/src/bastproxy/plugins/test/newmon/types/trackedrecord.py b/src/bastproxy/plugins/test/newmon/types/trackedrecord.py index ed8be369..9b8bf896 100644 --- a/src/bastproxy/plugins/test/newmon/types/trackedrecord.py +++ b/src/bastproxy/plugins/test/newmon/types/trackedrecord.py @@ -167,9 +167,7 @@ class name and tracking UUID will be used. self.owner_id = f"{owner_id}-{self._tracking_uuid}" else: self.owner_id = f"{self.__class__.__name__}-{self._tracking_uuid}" - self.api = API( - owner_id=self.owner_id or f"{self.__class__.__name__}:{self._tracking_uuid}" - ) + self.api = API(owner_id=self.owner_id or f"{self.__class__.__name__}:{self._tracking_uuid}") self.tracking_add_observer(self._tracking_onchange) self._tracking_record_updates = [] # create a unique id for this record diff --git a/src/bastproxy/plugins/test/newmon/utils/recordchangelog.py b/src/bastproxy/plugins/test/newmon/utils/recordchangelog.py index fe129877..176d8983 100644 --- a/src/bastproxy/plugins/test/newmon/utils/recordchangelog.py +++ b/src/bastproxy/plugins/test/newmon/utils/recordchangelog.py @@ -70,19 +70,11 @@ def __hash__(self): def __eq__(self, value: object) -> bool: """Compare equality based on uuid and class type.""" - return ( - self.uuid == value.uuid - if isinstance(value, RecordChangeLogEntry) - else False - ) + return self.uuid == value.uuid if isinstance(value, RecordChangeLogEntry) else False def __lt__(self, value: object) -> bool: """Allow sorting by creation time when available.""" - return ( - self.created_time < value.created_time - if hasattr(value, "created_time") - else False - ) # type: ignore + return self.created_time < value.created_time if hasattr(value, "created_time") else False # type: ignore def fix_stack(self, stack): """Normalize the captured stack for formatting.""" @@ -162,8 +154,7 @@ def format_detailed( data = self.data tmsg.append(f"{'Data':<15} :") tmsg.extend( - f"{'':<15} : {line}" - for line in pprint.pformat(data, width=120).splitlines() + f"{'':<15} : {line}" for line in pprint.pformat(data, width=120).splitlines() ) if show_stack and self.stack: tmsg.append(f"{'Stack':<15} :") diff --git a/tests/integration/test_proxy_integration.py b/tests/integration/test_proxy_integration.py index b204866c..c4a0850f 100644 --- a/tests/integration/test_proxy_integration.py +++ b/tests/integration/test_proxy_integration.py @@ -17,9 +17,7 @@ import pytest -pytestmark = pytest.mark.skip( - reason="Integration tests disabled - subprocess timing issues" -) +pytestmark = pytest.mark.skip(reason="Integration tests disabled - subprocess timing issues") class TestProxyConnection: @@ -76,9 +74,7 @@ async def test_banner_received(self, telnet_connection: tuple) -> None: assert b"password" in banner.lower() @pytest.mark.asyncio - async def test_authentication_with_default_password( - self, telnet_connection: tuple - ) -> None: + async def test_authentication_with_default_password(self, telnet_connection: tuple) -> None: """Test authentication with the default password. Args: @@ -106,9 +102,7 @@ async def test_authentication_with_default_password( assert b"logged in" in response.lower() @pytest.mark.asyncio - async def test_authentication_with_wrong_password( - self, telnet_connection: tuple - ) -> None: + async def test_authentication_with_wrong_password(self, telnet_connection: tuple) -> None: """Test that wrong password is rejected. Args: @@ -133,9 +127,7 @@ async def test_authentication_with_wrong_password( # Read response response = await asyncio.wait_for(reader.read(2048), timeout=2.0) - assert ( - b"Invalid password" in response or b"invalid password" in response.lower() - ) + assert b"Invalid password" in response or b"invalid password" in response.lower() class TestProxyCommands: diff --git a/tests/libs/test_api.py b/tests/libs/test_api.py index 3e61f94e..47b2e72f 100644 --- a/tests/libs/test_api.py +++ b/tests/libs/test_api.py @@ -82,9 +82,7 @@ def test_api_add_class_level(self) -> None: def test_api_add_instance_level(self) -> None: """Test adding a function to the instance-level API.""" api = API(owner_id="test_owner") - result = api.add( - "test", "function", helper_function_one, instance=True, description="Test" - ) + result = api.add("test", "function", helper_function_one, instance=True, description="Test") assert result is True # Check that the API exists at instance level @@ -203,9 +201,7 @@ def test_api_add_duplicate_with_force_succeeds(self) -> None: """Test that adding duplicate API with force=True overwrites.""" api = API(owner_id="test_owner") result1 = api.add("test", "function", helper_function_one, description="Test 1") - result2 = api.add( - "test", "function", helper_function_two, force=True, description="Test 2" - ) + result2 = api.add("test", "function", helper_function_two, force=True, description="Test 2") assert result1 is True assert result2 is True # Should succeed with force=True @@ -229,9 +225,7 @@ def test_api_overwritten_api_reference(self) -> None: """Test that overwritten_api stores reference to original API.""" api = API(owner_id="test_owner") api.add("test", "function", helper_function_one, description="Test 1") - api.add( - "test", "function", helper_function_two, force=True, description="Test 2" - ) + api.add("test", "function", helper_function_two, force=True, description="Test 2") # Check that overwritten_api reference exists api_item = api.get("test:function") @@ -283,9 +277,7 @@ def test_instance_and_class_apis_separate(self) -> None: api = API(owner_id="test_owner") # Add to class API - result1 = api.add( - "testsep", "classapi", helper_function_one, description="Class" - ) + result1 = api.add("testsep", "classapi", helper_function_one, description="Class") # Add to instance API with different name result2 = api.add( diff --git a/tests/libs/test_argp.py b/tests/libs/test_argp.py index 0ff9a9c5..56345173 100644 --- a/tests/libs/test_argp.py +++ b/tests/libs/test_argp.py @@ -143,9 +143,7 @@ def test_formatter_with_required_arg(self) -> None: def test_formatter_with_multiple_args(self) -> None: """Test formatter with multiple arguments.""" - parser = ArgumentParser( - description="Test parser", formatter_class=CustomFormatter - ) + parser = ArgumentParser(description="Test parser", formatter_class=CustomFormatter) parser.add_argument("--name", default="default_name", help="Name parameter") parser.add_argument("--count", type=int, default=5, help="Count parameter") parser.add_argument("--verbose", action="store_true", help="Verbose output") diff --git a/tests/libs/test_callback.py b/tests/libs/test_callback.py index 858d07ff..eaa8d32f 100644 --- a/tests/libs/test_callback.py +++ b/tests/libs/test_callback.py @@ -71,9 +71,7 @@ class TestCallbackBasics: def test_callback_creation(self) -> None: """Test that callbacks can be created with required parameters.""" - callback = Callback( - name="test_callback", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="test_callback", owner_id="test_owner", func=helper_test_function) assert callback.name == "test_callback" assert callback.owner_id == "test_owner" @@ -94,18 +92,14 @@ def test_callback_creation_with_disabled(self) -> None: def test_callback_has_creation_time(self) -> None: """Test that callbacks track their creation time.""" - callback = Callback( - name="test_callback", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="test_callback", owner_id="test_owner", func=helper_test_function) assert callback.created_time is not None assert callback.last_raised_datetime is None def test_callback_string_representation(self) -> None: """Test the string representation of a callback.""" - callback = Callback( - name="test_cb", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="test_cb", owner_id="test_owner", func=helper_test_function) str_repr = str(callback) @@ -120,9 +114,7 @@ def test_execute_callback_without_args(self) -> None: """Test executing a callback without arguments.""" test_execution_log.clear() - callback = Callback( - name="exec_callback", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="exec_callback", owner_id="test_owner", func=helper_test_function) result = callback.execute() @@ -149,9 +141,7 @@ def test_execute_callback_with_args(self) -> None: def test_execute_increments_raised_count(self) -> None: """Test that executing a callback increments the raised count.""" - callback = Callback( - name="count_callback", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="count_callback", owner_id="test_owner", func=helper_test_function) assert callback.raised_count == 0 @@ -166,9 +156,7 @@ def test_execute_increments_raised_count(self) -> None: def test_execute_updates_last_raised_time(self) -> None: """Test that executing a callback updates last raised time.""" - callback = Callback( - name="time_callback", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="time_callback", owner_id="test_owner", func=helper_test_function) assert callback.last_raised_datetime is None @@ -193,9 +181,7 @@ class TestCallbackEquality: def test_callback_hash(self) -> None: """Test that callbacks have a hash value.""" - callback = Callback( - name="hash_callback", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="hash_callback", owner_id="test_owner", func=helper_test_function) hash_value = hash(callback) @@ -203,20 +189,14 @@ def test_callback_hash(self) -> None: def test_different_callbacks_different_hashes(self) -> None: """Test that different callbacks have different hashes.""" - callback1 = Callback( - name="callback1", owner_id="owner1", func=helper_test_function - ) - callback2 = Callback( - name="callback2", owner_id="owner2", func=another_test_function - ) + callback1 = Callback(name="callback1", owner_id="owner1", func=helper_test_function) + callback2 = Callback(name="callback2", owner_id="owner2", func=another_test_function) assert hash(callback1) != hash(callback2) def test_callback_identity(self) -> None: """Test that a callback has consistent identity.""" - callback = Callback( - name="self_callback", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="self_callback", owner_id="test_owner", func=helper_test_function) # Test that the callback object has a consistent hash hash1 = hash(callback) @@ -226,16 +206,12 @@ def test_callback_identity(self) -> None: def test_callback_equals_same_callback(self) -> None: """Test that two callbacks with same properties are equal.""" # Note: They won't be equal because creation_time differs - callback1 = Callback( - name="same_callback", owner_id="test_owner", func=helper_test_function - ) + callback1 = Callback(name="same_callback", owner_id="test_owner", func=helper_test_function) # Small delay to ensure timestamps differ (Windows has low clock resolution) time.sleep(0.001) - callback2 = Callback( - name="same_callback", owner_id="test_owner", func=helper_test_function - ) + callback2 = Callback(name="same_callback", owner_id="test_owner", func=helper_test_function) # These are NOT equal because they have different creation times assert callback1 != callback2 @@ -243,18 +219,14 @@ def test_callback_equals_same_callback(self) -> None: def test_callback_equals_wrapped_function(self) -> None: """Test that a callback equals its wrapped function.""" - callback = Callback( - name="func_callback", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="func_callback", owner_id="test_owner", func=helper_test_function) # Callback should equal its wrapped function assert callback == helper_test_function def test_callback_not_equals_different_function(self) -> None: """Test that a callback doesn't equal a different function.""" - callback = Callback( - name="diff_callback", owner_id="test_owner", func=helper_test_function - ) + callback = Callback(name="diff_callback", owner_id="test_owner", func=helper_test_function) assert callback != another_test_function diff --git a/tools/doit/__init__.py b/tools/doit/__init__.py index 487e37da..3ff3da2d 100644 --- a/tools/doit/__init__.py +++ b/tools/doit/__init__.py @@ -1,4 +1,4 @@ -"""Dodo task modules for bastproxy. +"""Dodo task modules for the pyproject-template. This package contains modular doit task definitions organized by functionality. Tasks are auto-discovered from all modules in this package. diff --git a/tools/doit/adr.py b/tools/doit/adr.py new file mode 100644 index 00000000..3ec05ae8 --- /dev/null +++ b/tools/doit/adr.py @@ -0,0 +1,352 @@ +"""Architecture Decision Records (ADR) doit tasks.""" + +import os +import re +import subprocess # nosec B404 - subprocess is required for doit tasks +import sys +import tempfile +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from doit.tools import title_with_actions +from rich.console import Console +from rich.panel import Panel + +from tools.doit.templates import get_adr_required_sections, get_adr_template + +if TYPE_CHECKING: + from rich.console import Console as ConsoleType + +ADR_DIR = Path("docs/decisions") + + +def _get_next_adr_number() -> int: + """Get the next available ADR number. + + Scans existing ADR files and returns the next sequential number. + + Returns: + Next ADR number (1 if no ADRs exist) + """ + if not ADR_DIR.exists(): + return 1 + + pattern = re.compile(r"^(\d{4})-.*\.md$") + max_number = 0 + + for file in ADR_DIR.iterdir(): + if file.name == "adr-template.md" or file.name == "README.md": + continue + match = pattern.match(file.name) + if match: + number = int(match.group(1)) + max_number = max(max_number, number) + + return max_number + 1 + + +def _title_to_slug(title: str) -> str: + """Convert a title to a kebab-case slug. + + Args: + title: The ADR title + + Returns: + Kebab-case slug suitable for filename + """ + # Convert to lowercase + slug = title.lower() + # Replace spaces and underscores with hyphens + slug = re.sub(r"[\s_]+", "-", slug) + # Remove non-alphanumeric characters except hyphens + slug = re.sub(r"[^a-z0-9-]", "", slug) + # Collapse multiple hyphens + slug = re.sub(r"-+", "-", slug) + # Trim hyphens from ends + slug = slug.strip("-") + return slug + + +def _get_editor() -> str: + """Get the user's preferred editor.""" + return os.environ.get("EDITOR", os.environ.get("VISUAL", "vi")) + + +def _open_editor_with_template(template: str, suffix: str = ".md") -> str | None: + """Open editor with template and return the edited content. + + Args: + template: The template content to start with + suffix: File suffix for the temp file + + Returns: + The edited content, or None if aborted/unchanged + """ + console = Console() + editor = _get_editor() + + # Create temp file with template + with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as f: + f.write(template) + temp_path = f.name + + try: + # Open editor + console.print(f"[dim]Opening {editor}...[/dim]") + result = subprocess.run([editor, temp_path]) + + if result.returncode != 0: + console.print("[red]Editor exited with error.[/red]") + return None + + # Read the edited content + with open(temp_path) as f: + content = f.read() + + # Remove HTML comments + edited = re.sub(r"", "", content, flags=re.DOTALL) + + # Clean up extra blank lines + edited = re.sub(r"\n{3,}", "\n\n", edited).strip() + + return edited + + finally: + # Clean up temp file + if os.path.exists(temp_path): + os.remove(temp_path) + + +def _read_body_file(file_path: str, console: "ConsoleType") -> str | None: + """Read body content from a file. + + Args: + file_path: Path to the file + console: Rich console for output + + Returns: + File content, or None if error + """ + path = Path(file_path) + if not path.exists(): + console.print(f"[red]File not found: {file_path}[/red]") + return None + + try: + return path.read_text() + except Exception as e: + console.print(f"[red]Error reading file: {e}[/red]") + return None + + +def _validate_adr_content(content: str, console: "ConsoleType") -> bool: + """Validate that ADR has required sections with content. + + Required sections are read from the ADR template file. + + Args: + content: ADR markdown content + console: Rich console for output + + Returns: + True if valid, False otherwise + """ + required_sections = get_adr_required_sections() + + for section in required_sections: + # Look for ## Section header followed by content + # Use MULTILINE so ^ matches start of line + pattern = rf"^##\s+{re.escape(section)}\s*\n(.*?)(?=^##|\Z)" + match = re.search(pattern, content, re.DOTALL | re.IGNORECASE | re.MULTILINE) + + if not match: + console.print(f"[red]Missing required section: {section}[/red]") + return False + + section_content = match.group(1).strip() + if not section_content or _is_placeholder_content(section_content): + console.print( + f"[red]Section '{section}' is empty or contains only placeholder text.[/red]" + ) + return False + + return True + + +def _is_placeholder_content(content: str) -> bool: + """Check if content is just placeholder text from the template. + + Args: + content: Section content to check + + Returns: + True if content appears to be placeholder text + """ + placeholder_patterns = [ + r"^brief summary", + r"^why this decision", + r"^issue #xx", + ] + content_lower = content.lower().strip() + return any(re.match(pattern, content_lower) for pattern in placeholder_patterns) + + +def _prepare_editor_template(title: str, number: int, date: str) -> str: + """Prepare the editor template with title, number, and date filled in. + + Args: + title: ADR title + number: ADR number + date: Date string (YYYY-MM-DD format) + + Returns: + Template content ready for editing + """ + adr_template = get_adr_template() + template = adr_template.editor_template + + # Replace placeholders + template = template.replace("ADR-NNNN: Title", f"ADR-{number:04d}: {title}") + template = template.replace("ADR-NNNN:", f"ADR-{number:04d}:") + template = template.replace("YYYY-MM-DD", date) + + return template + + +def task_adr() -> dict[str, Any]: + """Create a new Architecture Decision Record (ADR). + + Creates a new ADR file with the next sequential number. + Required sections are determined by the ADR template file. + + Three modes: + 1. Interactive (default): Opens $EDITOR with template + 2. --body-file: Reads body from a file + 3. --title + --body: Provides content directly (for AI/scripts) + + Examples: + Interactive: doit adr --title="Use Redis for caching" + From file: doit adr --title="Use Redis" --body-file=adr.md + Direct: doit adr --title="Use Redis" --body="## Status\\nAccepted\\n..." + """ + + def create_adr( + title: str | None = None, + body: str | None = None, + body_file: str | None = None, + ) -> None: + console = Console() + console.print() + console.print( + Panel.fit( + "[bold cyan]Creating Architecture Decision Record[/bold cyan]", + border_style="cyan", + ) + ) + console.print() + + # Ensure ADR directory exists + ADR_DIR.mkdir(parents=True, exist_ok=True) + + # Get title if not provided + if not title: + console.print("[cyan]ADR title:[/cyan]") + title = input("> ").strip() + if not title: + console.print("[red]Title is required.[/red]") + sys.exit(1) + + # Generate filename and number + number = _get_next_adr_number() + slug = _title_to_slug(title) + filename = f"{number:04d}-{slug}.md" + adr_path = ADR_DIR / filename + today = datetime.now().strftime("%Y-%m-%d") + + console.print(f"[dim]ADR number: {number:04d}[/dim]") + console.print(f"[dim]Filename: {filename}[/dim]") + + # Show required sections + required = get_adr_required_sections() + console.print(f"[dim]Required sections: {', '.join(required)}[/dim]") + + # Determine body content + if body_file: + # Mode 2: Read from file + body_content = _read_body_file(body_file, console) + if body_content is None: + sys.exit(1) + elif body: + # Mode 3: Direct body provided + body_content = body + else: + # Mode 1: Interactive editor + template = _prepare_editor_template(title, number, today) + + console.print( + "[dim]Opening editor with ADR template. Fill in the sections, save, and exit.[/dim]" + ) + body_content = _open_editor_with_template(template) + if body_content is None: + console.print("[yellow]Aborted.[/yellow]") + sys.exit(0) + + # For non-interactive modes, ensure header is correct + if body_file or body: + # Check if content already has a header + if not body_content.startswith("# ADR-"): + # Prepend the header + body_content = f"# ADR-{number:04d}: {title}\n\n{body_content}" + else: + # Replace the header with correct number + body_content = re.sub( + r"^# ADR-\d+: .+", + f"# ADR-{number:04d}: {title}", + body_content, + ) + + # Ensure date is set + if "YYYY-MM-DD" in body_content: + body_content = body_content.replace("YYYY-MM-DD", today) + + # Validate content + if not _validate_adr_content(body_content, console): + console.print("[red]ADR content validation failed.[/red]") + sys.exit(1) + + # Write ADR file + adr_path.write_text(body_content + "\n") + + console.print() + console.print( + Panel.fit( + f"[bold green]ADR created successfully![/bold green]\n\n{adr_path}", + border_style="green", + ) + ) + + return { + "actions": [create_adr], + "params": [ + { + "name": "title", + "long": "title", + "default": None, + "help": "ADR title (e.g., 'Use Redis for caching')", + }, + { + "name": "body", + "long": "body", + "default": None, + "help": "ADR body content (markdown)", + }, + { + "name": "body_file", + "long": "body-file", + "default": None, + "help": "Read body from file", + }, + ], + "title": title_with_actions, + } diff --git a/tools/doit/base.py b/tools/doit/base.py index 44ed06db..8f784d79 100644 --- a/tools/doit/base.py +++ b/tools/doit/base.py @@ -23,9 +23,7 @@ def success_message() -> None: console.print() console.print( Panel.fit( - "[bold green]\u2713 All checks passed![/bold green]", - border_style="green", - padding=(1, 2), + "[bold green]✓ All checks passed![/bold green]", border_style="green", padding=(1, 2) ) ) console.print() diff --git a/tools/doit/build.py b/tools/doit/build.py index e501d98d..0c5646d1 100644 --- a/tools/doit/build.py +++ b/tools/doit/build.py @@ -22,7 +22,7 @@ def publish_cmd() -> str: token = os.environ.get("PYPI_TOKEN") if not token: raise RuntimeError("PYPI_TOKEN environment variable must be set.") - return f"uv publish --token '{token}'" + return "uv publish --token '{token}'" return { "actions": ["uv build", CmdAction(publish_cmd)], diff --git a/tools/doit/docs.py b/tools/doit/docs.py index 596245a3..3410e2e0 100644 --- a/tools/doit/docs.py +++ b/tools/doit/docs.py @@ -32,7 +32,16 @@ def task_docs_deploy() -> dict[str, Any]: def task_spell_check() -> dict[str, Any]: """Check spelling in code and documentation.""" return { - "actions": ["uv run codespell src/ tests/ docs/ README.md"], + "actions": ["uv run codespell src/ tests/ tools/ docs/ bootstrap.py README.md"], "title": title_with_actions, "verbosity": 0, } + + +def task_docs_toc() -> dict[str, Any]: + """Generate documentation table of contents from frontmatter.""" + return { + "actions": ["uv run python tools/generate_doc_toc.py"], + "title": title_with_actions, + "verbosity": 2, + } diff --git a/tools/doit/git.py b/tools/doit/git.py index 5be12b48..55fbdcd1 100644 --- a/tools/doit/git.py +++ b/tools/doit/git.py @@ -8,9 +8,7 @@ def task_commit() -> dict[str, Any]: """Interactive commit with commitizen (ensures conventional commit format).""" return { - "actions": [ - "uv run cz commit || echo 'commitizen not installed. Run: uv sync'" - ], + "actions": ["uv run cz commit || echo 'commitizen not installed. Run: uv sync'"], "title": title_with_actions, } @@ -26,9 +24,7 @@ def task_bump() -> dict[str, Any]: def task_changelog() -> dict[str, Any]: """Generate CHANGELOG from conventional commits.""" return { - "actions": [ - "uv run cz changelog || echo 'commitizen not installed. Run: uv sync'" - ], + "actions": ["uv run cz changelog || echo 'commitizen not installed. Run: uv sync'"], "title": title_with_actions, } diff --git a/tools/doit/github.py b/tools/doit/github.py new file mode 100644 index 00000000..10d85eb3 --- /dev/null +++ b/tools/doit/github.py @@ -0,0 +1,674 @@ +"""GitHub issue and PR creation doit tasks.""" + +import os +import re +import subprocess # nosec B404 - subprocess is required for doit tasks +import sys +import tempfile +from typing import TYPE_CHECKING, Any + +from doit.tools import title_with_actions +from rich.console import Console +from rich.panel import Panel + +from tools.doit.templates import get_issue_template, get_pr_template, get_required_sections + +if TYPE_CHECKING: + from rich.console import Console as ConsoleType + + +def _get_editor() -> str: + """Get the user's preferred editor.""" + return os.environ.get("EDITOR", os.environ.get("VISUAL", "vi")) + + +def _open_editor_with_template(template: str, suffix: str = ".md") -> str | None: + """Open editor with template and return the edited content. + + Args: + template: The template content to start with + suffix: File suffix for the temp file + + Returns: + The edited content (without comment lines), or None if aborted/unchanged + """ + console = Console() + editor = _get_editor() + + # Create temp file with template + with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as f: + f.write(template) + temp_path = f.name + + try: + # Open editor + console.print(f"[dim]Opening {editor}...[/dim]") + result = subprocess.run([editor, temp_path]) + + if result.returncode != 0: + console.print("[red]Editor exited with error.[/red]") + return None + + # Read the edited content + with open(temp_path) as f: + content = f.read() + + # Remove comment lines (starting with #) and HTML comments + lines = [] + for line in content.split("\n"): + stripped = line.strip() + if stripped.startswith("#") and not stripped.startswith("##"): + continue # Skip comment lines but keep ## headers + lines.append(line) + + edited = "\n".join(lines) + + # Remove HTML comments + edited = re.sub(r"", "", edited, flags=re.DOTALL) + + # Clean up extra blank lines + edited = re.sub(r"\n{3,}", "\n\n", edited).strip() + + return edited + + finally: + # Clean up temp file + if os.path.exists(temp_path): + os.remove(temp_path) + + +def _parse_markdown_sections(content: str) -> dict[str, str]: + """Parse markdown content into sections by ## headers. + + Args: + content: Markdown content with ## headers + + Returns: + Dict mapping section names to their content + """ + sections: dict[str, str] = {} + current_section = "" + current_content: list[str] = [] + + for line in content.split("\n"): + if line.startswith("## "): + # Save previous section + if current_section: + sections[current_section] = "\n".join(current_content).strip() + # Start new section + current_section = line[3:].strip() + current_content = [] + else: + current_content.append(line) + + # Save last section + if current_section: + sections[current_section] = "\n".join(current_content).strip() + + return sections + + +def _validate_issue_content( + sections: dict[str, str], issue_type: str, console: "ConsoleType" +) -> bool: + """Validate that required sections have content. + + Args: + sections: Parsed markdown sections + issue_type: Type of issue (feature, bug, refactor, doc, chore) + console: Rich console for output + + Returns: + True if valid, False otherwise + """ + # Get required sections dynamically from templates + required_sections = get_required_sections(issue_type) + + missing = [] + placeholder_patterns = [ + "describe the", + "clear description", + "paste error", + "any other relevant", + "delete section if not needed", + ] + + for section_name in required_sections: + content = sections.get(section_name, "").strip() + if not content: + missing.append(section_name) + continue + + # Check for placeholder text + content_lower = content.lower() + for pattern in placeholder_patterns: + if pattern in content_lower: + console.print( + f"[yellow]Warning: '{section_name}' may contain placeholder text.[/yellow]" + ) + break + + if missing: + console.print(f"[red]Missing required sections: {', '.join(missing)}[/red]") + return False + + return True + + +def _read_body_file(file_path: str, console: "ConsoleType") -> str | None: + """Read body content from a file. + + Args: + file_path: Path to the file + console: Rich console for output + + Returns: + File content, or None if error + """ + from pathlib import Path + + path = Path(file_path) + if not path.exists(): + console.print(f"[red]File not found: {file_path}[/red]") + return None + + try: + return path.read_text() + except Exception as e: + console.print(f"[red]Error reading file: {e}[/red]") + return None + + +def task_issue() -> dict[str, Any]: + """Create issue from .github/ISSUE_TEMPLATE/.yml (feature/bug/refactor/doc/chore). + + Labels are automatically applied based on the issue type. + + Three modes: + 1. Interactive (default): Opens $EDITOR with template + 2. --body-file: Reads body from a file + 3. --title + --body: Provides content directly (for AI/scripts) + + Examples: + Interactive: doit issue --type=feature + From file: doit issue --type=doc --title="Add guide" --body-file=issue.md + Direct: doit issue --type=chore --title="Update CI" --body="## Description\\n..." + """ + + def create_issue( + type: str, + title: str | None = None, + body: str | None = None, + body_file: str | None = None, + ) -> None: + console = Console() + console.print() + console.print( + Panel.fit( + f"[bold cyan]Creating {type} Issue[/bold cyan]", + border_style="cyan", + ) + ) + console.print() + + # Get template (validates type and retrieves labels/template dynamically) + try: + issue_template = get_issue_template(type) + except ValueError as e: + console.print(f"[red]{e}[/red]") + sys.exit(1) + except FileNotFoundError as e: + console.print(f"[red]Template error: {e}[/red]") + sys.exit(1) + + labels = issue_template.labels + + # Determine body content + if body_file: + # Mode 2: Read from file + body_content = _read_body_file(body_file, console) + if body_content is None: + sys.exit(1) + elif body: + # Mode 3: Direct body provided + body_content = body + else: + # Mode 1: Interactive editor + console.print( + f"[dim]Opening editor with {type} template. " + "Fill in the sections, save, and exit.[/dim]" + ) + body_content = _open_editor_with_template(issue_template.editor_template) + if body_content is None: + console.print("[yellow]Aborted.[/yellow]") + sys.exit(0) + + # Parse and validate + sections = _parse_markdown_sections(body_content) + if not _validate_issue_content(sections, type, console): + console.print("[red]Issue content validation failed.[/red]") + console.print() + console.print(f"[yellow]See template: .github/ISSUE_TEMPLATE/{type}.yml[/yellow]") + sys.exit(1) + + # Get title if not provided + if not title: + console.print("[cyan]Issue title:[/cyan]") + title = input("> ").strip() + if not title: + console.print("[red]Title is required.[/red]") + sys.exit(1) + + # Create the issue + console.print("\n[cyan]Creating issue...[/cyan]") + try: + result = subprocess.run( + [ + "gh", + "issue", + "create", + "--title", + title, + "--body", + body_content, + "--label", + labels, + ], + capture_output=True, + text=True, + check=True, + ) + issue_url = result.stdout.strip() + console.print() + console.print( + Panel.fit( + f"[bold green]Issue created successfully![/bold green]\n\n{issue_url}", + border_style="green", + ) + ) + except subprocess.CalledProcessError as e: + console.print("[red]Failed to create issue:[/red]") + console.print(f"[red]{e.stderr}[/red]") + sys.exit(1) + + return { + "actions": [create_issue], + "params": [ + { + "name": "type", + "short": "t", + "long": "type", + "default": "feature", + "help": "Issue type: feature, bug, refactor, doc, chore", + }, + {"name": "title", "long": "title", "default": None, "help": "Issue title"}, + {"name": "body", "long": "body", "default": None, "help": "Issue body (markdown)"}, + { + "name": "body_file", + "long": "body-file", + "default": None, + "help": "Read body from file", + }, + ], + "title": title_with_actions, + } + + +def task_pr() -> dict[str, Any]: + """Create PR from .github/pull_request_template.md (auto-detects branch and linked issue). + + Three modes: + 1. Interactive (default): Opens $EDITOR with template + 2. --body-file: Reads body from a file + 3. --title + --body: Provides content directly (for AI/scripts) + + Examples: + Interactive: doit pr + From file: doit pr --title="feat: add export" --body-file=pr.md + Direct: doit pr --title="feat: add export" --body="## Description\\n..." + """ + + def create_pr( + title: str | None = None, + body: str | None = None, + body_file: str | None = None, + draft: bool = False, + ) -> None: + console = Console() + console.print() + console.print( + Panel.fit("[bold cyan]Creating Pull Request[/bold cyan]", border_style="cyan") + ) + console.print() + + # Check we're not on main + current_branch = subprocess.run( + ["git", "branch", "--show-current"], + capture_output=True, + text=True, + check=True, + ).stdout.strip() + + if current_branch == "main": + console.print("[red]Cannot create PR from main branch.[/red]") + console.print("[yellow]Create a feature branch first.[/yellow]") + sys.exit(1) + + console.print(f"[dim]Current branch: {current_branch}[/dim]") + + # Try to extract issue number from branch name (e.g., feat/42-description) + detected_issue = None + branch_issue_match = re.search(r"/(\d+)-", current_branch) + if branch_issue_match: + detected_issue = branch_issue_match.group(1) + console.print(f"[dim]Detected issue from branch: #{detected_issue}[/dim]") + + # Determine body content + if body_file: + # Mode 2: Read from file + body_content = _read_body_file(body_file, console) + if body_content is None: + sys.exit(1) + elif body: + # Mode 3: Direct body provided + body_content = body + else: + # Mode 1: Interactive editor + # Get PR template dynamically from .github/ + try: + template = get_pr_template() + except FileNotFoundError as e: + console.print(f"[red]Template error: {e}[/red]") + sys.exit(1) + + # Pre-fill issue number if detected + if detected_issue: + template = template.replace("#(issue)", f"#{detected_issue}") + + console.print( + "[dim]Opening editor with PR template. Fill in the sections, save, and exit.[/dim]" + ) + body_content = _open_editor_with_template(template) + if body_content is None: + console.print("[yellow]Aborted.[/yellow]") + sys.exit(0) + + # Validate PR has description + sections = _parse_markdown_sections(body_content) + description = sections.get("Description", "").strip() + if not description: + console.print("[red]Description is required.[/red]") + console.print() + console.print("[yellow]See template: .github/pull_request_template.md[/yellow]") + sys.exit(1) + + # Get title if not provided + if not title: + console.print("[cyan]PR title (e.g., 'feat: add export feature'):[/cyan]") + title = input("> ").strip() + if not title: + console.print("[red]Title is required.[/red]") + sys.exit(1) + + # Create the PR + console.print("\n[cyan]Creating PR...[/cyan]") + cmd = ["gh", "pr", "create", "--title", title, "--body", body_content] + if draft: + cmd.append("--draft") + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + pr_url = result.stdout.strip() + console.print() + console.print( + Panel.fit( + f"[bold green]PR created successfully![/bold green]\n\n{pr_url}", + border_style="green", + ) + ) + except subprocess.CalledProcessError as e: + console.print("[red]Failed to create PR:[/red]") + console.print(f"[red]{e.stderr}[/red]") + sys.exit(1) + + return { + "actions": [create_pr], + "params": [ + {"name": "title", "long": "title", "default": None, "help": "PR title"}, + {"name": "body", "long": "body", "default": None, "help": "PR body (markdown)"}, + { + "name": "body_file", + "long": "body-file", + "default": None, + "help": "Read body from file", + }, + { + "name": "draft", + "long": "draft", + "type": bool, + "default": False, + "help": "Create as draft PR", + }, + ], + "title": title_with_actions, + } + + +def _get_pr_info(pr_number: str | None, console: "ConsoleType") -> dict[str, Any] | None: + """Get PR information from GitHub API. + + Args: + pr_number: PR number, or None to use current branch's PR + console: Rich console for output + + Returns: + Dict with PR info, or None if not found + """ + import json + + cmd = ["gh", "pr", "view", "--json", "number,title,body,state"] + if pr_number: + cmd.append(pr_number) + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + data: dict[str, Any] = json.loads(result.stdout) + return data + except subprocess.CalledProcessError as e: + if "no pull requests found" in e.stderr.lower(): + console.print("[red]No PR found for current branch.[/red]") + else: + console.print(f"[red]Failed to get PR info: {e.stderr}[/red]") + return None + + +def _extract_linked_issues(body: str) -> dict[str, list[str]]: + """Extract linked issue numbers from PR body with relationship type. + + Looks for patterns like: + - Closes #123, Fixes #456, Resolves #789 → "closes" + - Part of #101 → "part_of" + + Args: + body: PR body text + + Returns: + Dict with "closes" and "part_of" keys, each containing list of issue numbers + """ + result: dict[str, list[str]] = {"closes": [], "part_of": []} + seen: set[str] = set() + + # Pattern for closes/fixes/resolves + closes_pattern = r"(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)\s+#(\d+)" + for match in re.finditer(closes_pattern, body, re.IGNORECASE): + issue = match.group(1) + if issue not in seen: + seen.add(issue) + result["closes"].append(issue) + + # Pattern for part of + part_of_pattern = r"part\s+of\s+#(\d+)" + for match in re.finditer(part_of_pattern, body, re.IGNORECASE): + issue = match.group(1) + if issue not in seen: + seen.add(issue) + result["part_of"].append(issue) + + return result + + +def _format_merge_subject(title: str, pr_number: int, issues: dict[str, list[str]]) -> str: + """Format the merge commit subject line. + + Args: + title: PR title (should be in conventional commit format) + pr_number: PR number + issues: Dict with "closes" and "part_of" keys containing issue numbers + + Returns: + Formatted subject: ": (merges PR #XX, closes #YY)" + or: ": (merges PR #XX, part of #YY)" + """ + closes = issues.get("closes", []) + part_of = issues.get("part_of", []) + + # Build the suffix parts + parts = [f"merges PR #{pr_number}"] + + if closes: + issue_refs = ", ".join(f"#{i}" for i in closes) + parts.append(f"closes {issue_refs}") + + if part_of: + issue_refs = ", ".join(f"#{i}" for i in part_of) + parts.append(f"part of {issue_refs}") + + suffix = f"({', '.join(parts)})" + return f"{title} {suffix}" + + +def task_pr_merge() -> dict[str, Any]: + """Merge a PR with properly formatted commit message. + + This task enforces the merge commit format: + : (merges PR #XX, closes #YY) + + Uses squash merge with a custom subject line to ensure consistent + commit history that matches the documented format. + + Examples: + doit pr_merge # Merge PR for current branch + doit pr_merge --pr=123 # Merge specific PR + doit pr_merge --delete-branch # Also delete the branch after merge + """ + + def merge_pr( + pr: str | None = None, + delete_branch: bool = True, + ) -> None: + console = Console() + console.print() + console.print(Panel.fit("[bold cyan]Merging Pull Request[/bold cyan]", border_style="cyan")) + console.print() + + # Get PR info + pr_info = _get_pr_info(pr, console) + if not pr_info: + sys.exit(1) + + pr_number = pr_info["number"] + pr_title = pr_info["title"] + pr_body = pr_info.get("body", "") or "" + pr_state = pr_info.get("state", "") + + console.print(f"[dim]PR #{pr_number}: {pr_title}[/dim]") + + # Check PR state + if pr_state.upper() != "OPEN": + console.print(f"[red]PR is not open (state: {pr_state}).[/red]") + sys.exit(1) + + # Validate PR title format + title_pattern = re.compile(r"^(feat|fix|refactor|docs|test|chore|ci|perf)(\(.+\))?:\s.+") + if not title_pattern.match(pr_title): + console.print("[red]PR title does not follow conventional commit format.[/red]") + console.print("[yellow]Expected: : [/yellow]") + console.print("[yellow]Example: feat: add user authentication[/yellow]") + sys.exit(1) + + # Extract linked issues + issues = _extract_linked_issues(pr_body) + has_issues = issues["closes"] or issues["part_of"] + if has_issues: + all_issues = issues["closes"] + issues["part_of"] + console.print(f"[dim]Linked issues: {', '.join(f'#{i}' for i in all_issues)}[/dim]") + else: + console.print("[yellow]Warning: No linked issues found in PR body.[/yellow]") + + # Format merge subject + merge_subject = _format_merge_subject(pr_title, pr_number, issues) + console.print("\n[cyan]Merge commit subject:[/cyan]") + console.print(f" [bold]{merge_subject}[/bold]") + + # Build merge command + cmd = [ + "gh", + "pr", + "merge", + str(pr_number), + "--squash", + "--subject", + merge_subject, + ] + if delete_branch: + cmd.append("--delete-branch") + + # Execute merge + console.print("\n[cyan]Merging...[/cyan]") + try: + subprocess.run(cmd, check=True) + console.print() + console.print( + Panel.fit( + f"[bold green]PR #{pr_number} merged successfully![/bold green]\n\n" + f"Commit: {merge_subject}", + border_style="green", + ) + ) + + # Reminder to update linked issues + console.print() + console.print( + Panel.fit( + "[bold yellow]Reminder: Update linked issues[/bold yellow]\n\n" + "Examples:\n" + f' gh issue close --comment "Fixed in PR #{pr_number}"\n' + f' gh issue comment --body "Addressed in PR #{pr_number}"', + border_style="yellow", + ) + ) + except subprocess.CalledProcessError as e: + console.print("[red]Failed to merge PR.[/red]") + if e.stderr: + console.print(f"[red]{e.stderr}[/red]") + sys.exit(1) + + return { + "actions": [merge_pr], + "params": [ + { + "name": "pr", + "long": "pr", + "default": None, + "help": "PR number to merge (default: PR for current branch)", + }, + { + "name": "delete_branch", + "long": "delete-branch", + "type": bool, + "default": True, + "help": "Delete branch after merge (default: True)", + }, + ], + "title": title_with_actions, + } diff --git a/tools/doit/install.py b/tools/doit/install.py index 2e2b4525..b7550701 100644 --- a/tools/doit/install.py +++ b/tools/doit/install.py @@ -36,7 +36,7 @@ def _install_direnv() -> None: text=True, check=True, ).stdout.strip() - print(f"\u2713 direnv already installed: {version}") + print(f"✓ direnv already installed: {version}") return print("Installing direnv...") @@ -49,7 +49,9 @@ def _install_direnv() -> None: os.makedirs(install_dir, exist_ok=True) if system == "linux": - bin_url = f"https://github.com/direnv/direnv/releases/download/v{version}/direnv.linux-amd64" + bin_url = ( + f"https://github.com/direnv/direnv/releases/download/v{version}/direnv.linux-amd64" + ) bin_path = os.path.join(install_dir, "direnv") print(f"Downloading {bin_url}...") urllib.request.urlretrieve(bin_url, bin_path) # nosec B310 - downloading from hardcoded GitHub release URL @@ -60,7 +62,7 @@ def _install_direnv() -> None: print(f"Unsupported OS: {system}") sys.exit(1) - print("\u2713 direnv installed.") + print("✓ direnv installed.") print("\nIMPORTANT: Add direnv hook to your shell:") print(" Bash: echo 'eval \"$(direnv hook bash)\"'") print(" Zsh: echo 'eval \"$(direnv hook zsh)\"'") @@ -76,7 +78,7 @@ def task_install() -> dict[str, Any]: } -def task_dev() -> dict[str, Any]: +def task_install_dev() -> dict[str, Any]: """Install package with dev dependencies.""" return { "actions": [ @@ -86,16 +88,6 @@ def task_dev() -> dict[str, Any]: } -def task_sync() -> dict[str, Any]: - """Sync virtualenv with all extras and dev deps (alias of dev).""" - return { - "actions": [ - "uv sync --all-extras --dev", - ], - "title": title_with_actions, - } - - def task_install_direnv() -> dict[str, Any]: """Install direnv for automatic environment loading.""" return { diff --git a/tools/doit/maintenance.py b/tools/doit/maintenance.py index 8781a58e..bbdf1c4f 100644 --- a/tools/doit/maintenance.py +++ b/tools/doit/maintenance.py @@ -86,7 +86,7 @@ def clean_artifacts() -> None: console.print() console.print( Panel.fit( - "[bold green]\u2713 Deep clean complete![/bold green]", + "[bold green]✓ Deep clean complete![/bold green]", border_style="green", padding=(1, 2), ) @@ -105,9 +105,7 @@ def update_dependencies() -> None: console = Console() console.print() console.print( - Panel.fit( - "[bold cyan]Updating Dependencies[/bold cyan]", border_style="cyan" - ) + Panel.fit("[bold cyan]Updating Dependencies[/bold cyan]", border_style="cyan") ) console.print() @@ -128,12 +126,11 @@ def update_dependencies() -> None: # Update dependencies and refresh lockfile result = subprocess.run( ["uv", "sync", "--all-extras", "--dev", "--upgrade"], - check=False, env={**os.environ, "UV_CACHE_DIR": UV_CACHE_DIR}, ) if result.returncode != 0: - print("\n\u274c Dependency update failed!") + print("\n❌ Dependency update failed!") sys.exit(1) print() @@ -143,12 +140,12 @@ def update_dependencies() -> None: print() # Run all checks - check_result = subprocess.run(["doit", "check"], check=False) + check_result = subprocess.run(["doit", "check"]) print() if check_result.returncode == 0: print("=" * 70) - print(" " * 20 + "\u2713 All checks passed!") + print(" " * 20 + "✓ All checks passed!") print("=" * 70) print() print("Next steps:") @@ -157,7 +154,7 @@ def update_dependencies() -> None: print("3. Commit the updated dependencies") else: print("=" * 70) - print("\u26a0 Warning: Some checks failed after update") + print("⚠ Warning: Some checks failed after update") print("=" * 70) print() print("You may need to:") @@ -178,3 +175,152 @@ def task_fmt_pyproject() -> dict[str, Any]: "actions": ["uv run pyproject-fmt pyproject.toml"], "title": title_with_actions, } + + +def task_completions() -> dict[str, Any]: + """Generate shell completion scripts for doit tasks.""" + + def generate_completions() -> None: + console = Console() + console.print() + console.print( + Panel.fit("[bold cyan]Generating Shell Completions[/bold cyan]", border_style="cyan") + ) + console.print() + + # Ensure completions directory exists + os.makedirs("completions", exist_ok=True) + + # Generate bash completion + console.print("[cyan]Generating bash completion...[/cyan]") + bash_result = subprocess.run( + ["doit", "tabcompletion", "--shell", "bash"], + capture_output=True, + text=True, + check=True, + ) + with open("completions/doit.bash", "w") as f: + f.write(bash_result.stdout) + console.print(" [dim]Created completions/doit.bash[/dim]") + + # Generate zsh completion + console.print("[cyan]Generating zsh completion...[/cyan]") + zsh_result = subprocess.run( + ["doit", "tabcompletion", "--shell", "zsh"], + capture_output=True, + text=True, + check=True, + ) + with open("completions/doit.zsh", "w") as f: + f.write(zsh_result.stdout) + console.print(" [dim]Created completions/doit.zsh[/dim]") + + console.print() + console.print( + Panel.fit( + "[bold green]✓ Completions generated![/bold green]\n\n" + "[dim]To enable, add to your shell config:[/dim]\n" + " Bash: source completions/doit.bash\n" + " Zsh: source completions/doit.zsh", + border_style="green", + padding=(1, 2), + ) + ) + + return { + "actions": [generate_completions], + "title": title_with_actions, + } + + +def task_completions_install() -> dict[str, Any]: + """Install doit completions to your shell config (~/.bashrc or ~/.zshrc).""" + + def install_completions() -> None: + console = Console() + console.print() + console.print( + Panel.fit("[bold cyan]Installing Shell Completions[/bold cyan]", border_style="cyan") + ) + console.print() + + # Get absolute path to completions + project_dir = os.path.abspath(os.getcwd()) + bash_completion = os.path.join(project_dir, "completions", "doit.bash") + zsh_completion = os.path.join(project_dir, "completions", "doit.zsh") + + # Check if completions exist + if not os.path.exists(bash_completion) or not os.path.exists(zsh_completion): + console.print("[yellow]Completions not found. Generating...[/yellow]") + subprocess.run(["doit", "completions"], check=True) + console.print() + + # Source line to add (with unique marker for identification) + project_name = os.path.basename(project_dir) + bash_source_line = ( + f"\n# Doit completions for {project_name}\n" + f'if [ -f "{bash_completion}" ]; then source "{bash_completion}"; fi\n' + ) + zsh_source_line = ( + f"\n# Doit completions for {project_name}\n" + f'if [ -f "{zsh_completion}" ]; then source "{zsh_completion}"; fi\n' + ) + + home = os.path.expanduser("~") + installed = [] + + # Install bash completion + bashrc = os.path.join(home, ".bashrc") + if os.path.exists(bashrc): + with open(bashrc) as f: + content = f.read() + if bash_completion not in content: + with open(bashrc, "a") as f: + f.write(bash_source_line) + installed.append(("Bash", bashrc)) + console.print(f"[green]✓ Added to {bashrc}[/green]") + else: + console.print(f"[dim]Already in {bashrc}[/dim]") + + # Install zsh completion + zshrc = os.path.join(home, ".zshrc") + if os.path.exists(zshrc): + with open(zshrc) as f: + content = f.read() + if zsh_completion not in content: + with open(zshrc, "a") as f: + f.write(zsh_source_line) + installed.append(("Zsh", zshrc)) + console.print(f"[green]✓ Added to {zshrc}[/green]") + else: + console.print(f"[dim]Already in {zshrc}[/dim]") + + console.print() + if installed: + shells = ", ".join(s[0] for s in installed) + console.print( + Panel.fit( + f"[bold green]✓ Completions installed for {shells}![/bold green]\n\n" + "[dim]Reload your shell or run:[/dim]\n" + " source ~/.bashrc (for Bash)\n" + " source ~/.zshrc (for Zsh)", + border_style="green", + padding=(1, 2), + ) + ) + else: + console.print( + Panel.fit( + "[yellow]No shell config files found or already installed.[/yellow]\n\n" + "[dim]Manually add to your shell config:[/dim]\n" + f' source "{bash_completion}" (Bash)\n' + f' source "{zsh_completion}" (Zsh)', + border_style="yellow", + padding=(1, 2), + ) + ) + + return { + "actions": [install_completions], + "title": title_with_actions, + } diff --git a/tools/doit/quality.py b/tools/doit/quality.py index 0781ae45..4ad801d6 100644 --- a/tools/doit/quality.py +++ b/tools/doit/quality.py @@ -10,7 +10,7 @@ def task_lint() -> dict[str, Any]: """Run ruff linting.""" return { - "actions": ["uv run ruff check src/ tests/"], + "actions": ["uv run ruff check src/ tests/ tools/ "], "title": title_with_actions, "verbosity": 0, } @@ -20,8 +20,8 @@ def task_format() -> dict[str, Any]: """Format code with ruff.""" return { "actions": [ - "uv run ruff format src/ tests/", - "uv run ruff check --fix src/ tests/", + "uv run ruff format src/ tests/ tools/ ", + "uv run ruff check --fix src/ tests/ tools/ ", ], "title": title_with_actions, } @@ -30,7 +30,7 @@ def task_format() -> dict[str, Any]: def task_format_check() -> dict[str, Any]: """Check code formatting without modifying files.""" return { - "actions": ["uv run ruff format --check src/ tests/"], + "actions": ["uv run ruff format --check src/ tests/ tools/ "], "title": title_with_actions, "verbosity": 0, } @@ -39,7 +39,7 @@ def task_format_check() -> dict[str, Any]: def task_type_check() -> dict[str, Any]: """Run mypy type checking (uses pyproject.toml configuration).""" return { - "actions": ["uv run mypy src/"], + "actions": ["uv run mypy src/ tools/doit/ "], "title": title_with_actions, "verbosity": 0, } @@ -70,9 +70,9 @@ def task_maintainability() -> dict[str, Any]: def task_check() -> dict[str, Any]: - """Run all checks (format, lint, type check, test).""" + """Run all checks (format, lint, type check, security, spelling, test).""" return { "actions": [success_message], - "task_dep": ["format_check", "lint", "type_check", "test"], + "task_dep": ["format_check", "lint", "type_check", "security", "spell_check", "test"], "title": title_with_actions, } diff --git a/tools/doit/release.py b/tools/doit/release.py index 611c3aca..b9f18629 100644 --- a/tools/doit/release.py +++ b/tools/doit/release.py @@ -1,16 +1,141 @@ """Release-related doit tasks.""" +import json import os import re import subprocess # nosec B404 - subprocess is required for doit tasks import sys -from typing import Any +from typing import TYPE_CHECKING, Any from doit.tools import title_with_actions from rich.console import Console from .base import UV_CACHE_DIR +if TYPE_CHECKING: + from rich.console import Console as ConsoleType + + +def validate_merge_commits(console: "ConsoleType") -> bool: + """Validate that all merge commits follow the required format. + + Returns: + bool: True if all merge commits are valid, False otherwise. + """ + console.print("\n[cyan]Validating merge commit format...[/cyan]") + + # Get merge commits since last tag (or all if no tags) + try: + result = subprocess.run( + ["git", "describe", "--tags", "--abbrev=0"], + capture_output=True, + text=True, + ) + last_tag = result.stdout.strip() if result.returncode == 0 else "" + range_spec = f"{last_tag}..HEAD" if last_tag else "HEAD" + + result = subprocess.run( + ["git", "log", "--merges", "--pretty=format:%h %s", range_spec], + capture_output=True, + text=True, + ) + merge_commits = result.stdout.strip().split("\n") if result.stdout.strip() else [] + + except Exception as e: + console.print(f"[yellow]⚠ Could not check merge commits: {e}[/yellow]") + return True # Don't block on this check + + if not merge_commits or merge_commits == [""]: + console.print("[green]✓ No merge commits to validate.[/green]") + return True + + # Pattern: : (merges PR #XX, closes #YY) or (merges PR #XX) + merge_pattern = re.compile( + r"^[a-f0-9]+\s+(feat|fix|refactor|docs|test|chore|ci|perf):\s.+\s" + r"\(merges PR #\d+(?:, closes #\d+)?\)$" + ) + + invalid_commits = [] + for commit in merge_commits: + if commit and not merge_pattern.match(commit): + invalid_commits.append(commit) + + if invalid_commits: + console.print("[bold red]❌ Invalid merge commit format found:[/bold red]") + for commit in invalid_commits: + console.print(f" [red]{commit}[/red]") + console.print("\n[yellow]Expected format:[/yellow]") + console.print(" : (merges PR #XX, closes #YY)") + console.print(" : (merges PR #XX)") + return False + + console.print("[green]✓ All merge commits follow required format.[/green]") + return True + + +def validate_issue_links(console: "ConsoleType") -> bool: + """Validate that commits (except docs) reference issues. + + Returns: + bool: True if validation passes, False otherwise. + """ + console.print("\n[cyan]Validating issue links in commits...[/cyan]") + + try: + # Get commits since last tag + result = subprocess.run( + ["git", "describe", "--tags", "--abbrev=0"], + capture_output=True, + text=True, + ) + last_tag = result.stdout.strip() if result.returncode == 0 else "" + # If no tags, check last 10 commits + range_spec = f"{last_tag}..HEAD" if last_tag else "HEAD~10..HEAD" + + result = subprocess.run( + ["git", "log", "--pretty=format:%h %s", range_spec], + capture_output=True, + text=True, + ) + commits = result.stdout.strip().split("\n") if result.stdout.strip() else [] + + except Exception as e: + console.print(f"[yellow]⚠ Could not check issue links: {e}[/yellow]") + return True # Don't block on this check + + if not commits or commits == [""]: + console.print("[green]✓ No commits to validate.[/green]") + return True + + issue_pattern = re.compile(r"#\d+") + docs_pattern = re.compile(r"^[a-f0-9]+\s+docs:", re.IGNORECASE) + + commits_without_issues = [] + for commit in commits: + if commit: + # Skip docs commits + if docs_pattern.match(commit): + continue + # Skip merge commits (already validated separately) + if "merge" in commit.lower(): + continue + # Check for issue reference + if not issue_pattern.search(commit): + commits_without_issues.append(commit) + + if commits_without_issues: + console.print("[bold yellow]⚠ Warning: Some commits don't reference issues:[/bold yellow]") + for commit in commits_without_issues[:5]: # Show first 5 + console.print(f" [yellow]{commit}[/yellow]") + if len(commits_without_issues) > 5: + console.print(f" [dim]...and {len(commits_without_issues) - 5} more[/dim]") + console.print("\n[dim]This is a warning only - release can continue.[/dim]") + console.print("[dim]Consider linking commits to issues for better traceability.[/dim]") + else: + console.print("[green]✓ All non-docs commits reference issues.[/green]") + + return True # Warning only, don't block release + def task_release_dev(type: str = "alpha") -> dict[str, Any]: """Create a pre-release (alpha/beta) tag for TestPyPI and push to GitHub. @@ -35,12 +160,12 @@ def create_dev_release() -> None: ).stdout.strip() if current_branch != "main": console.print( - f"[bold yellow]\u26a0 Warning: Not on main branch " + f"[bold yellow]⚠ Warning: Not on main branch " f"(currently on {current_branch})[/bold yellow]" ) response = input("Continue anyway? (y/N) ").strip().lower() if response != "y": - console.print("[bold red]\u274c Release cancelled.[/bold red]") + console.print("[bold red]❌ Release cancelled.[/bold red]") sys.exit(1) # Check for uncommitted changes @@ -51,9 +176,7 @@ def create_dev_release() -> None: check=True, ).stdout.strip() if status: - console.print( - "[bold red]\u274c Error: Uncommitted changes detected.[/bold red]" - ) + console.print("[bold red]❌ Error: Uncommitted changes detected.[/bold red]") console.print(status) sys.exit(1) @@ -61,9 +184,9 @@ def create_dev_release() -> None: console.print("\n[cyan]Pulling latest changes...[/cyan]") try: subprocess.run(["git", "pull"], check=True, capture_output=True, text=True) - console.print("[green]\u2713 Git pull successful.[/green]") + console.print("[green]✓ Git pull successful.[/green]") except subprocess.CalledProcessError as e: - console.print("[bold red]\u274c Error pulling latest changes:[/bold red]") + console.print("[bold red]❌ Error pulling latest changes:[/bold red]") console.print(f"[red]Stdout: {e.stdout}[/red]") console.print(f"[red]Stderr: {e.stderr}[/red]") sys.exit(1) @@ -71,13 +194,11 @@ def create_dev_release() -> None: # Run checks console.print("\n[cyan]Running all pre-release checks...[/cyan]") try: - subprocess.run( - ["doit", "check"], check=True, capture_output=True, text=True - ) - console.print("[green]\u2713 All checks passed.[/green]") + subprocess.run(["doit", "check"], check=True, capture_output=True, text=True) + console.print("[green]✓ All checks passed.[/green]") except subprocess.CalledProcessError as e: console.print( - "[bold red]\u274c Pre-release checks failed! " + "[bold red]❌ Pre-release checks failed! " "Please fix issues before tagging.[/bold red]" ) console.print(f"[red]Stdout: {e.stdout}[/red]") @@ -85,9 +206,7 @@ def create_dev_release() -> None: sys.exit(1) # Automated version bump and tagging - console.print( - f"\n[cyan]Bumping version ({type}) and updating changelog...[/cyan]" - ) + console.print(f"\n[cyan]Bumping version ({type}) and updating changelog...[/cyan]") try: # Use cz bump --prerelease --changelog result = subprocess.run( @@ -97,16 +216,14 @@ def create_dev_release() -> None: capture_output=True, text=True, ) - console.print(f"[green]\u2713 Version bumped to {type}.[/green]") + console.print(f"[green]✓ Version bumped to {type}.[/green]") console.print(f"[dim]{result.stdout}[/dim]") # Extract new version - version_match = re.search( - r"Bumping to version (\d+\.\d+\.\d+[^\s]*)", result.stdout - ) + version_match = re.search(r"Bumping to version (\d+\.\d+\.\d+[^\s]*)", result.stdout) new_version = version_match.group(1) if version_match else "unknown" except subprocess.CalledProcessError as e: - console.print("[bold red]\u274c commitizen bump failed![/bold red]") + console.print("[bold red]❌ commitizen bump failed![/bold red]") console.print(f"[red]Stdout: {e.stdout}[/red]") console.print(f"[red]Stderr: {e.stderr}[/red]") sys.exit(1) @@ -119,22 +236,18 @@ def create_dev_release() -> None: capture_output=True, text=True, ) - console.print("[green]\u2713 Tags pushed to origin.[/green]") + console.print("[green]✓ Tags pushed to origin.[/green]") except subprocess.CalledProcessError as e: - console.print("[bold red]\u274c Error pushing tag to origin:[/bold red]") + console.print("[bold red]❌ Error pushing tag to origin:[/bold red]") console.print(f"[red]Stdout: {e.stdout}[/red]") console.print(f"[red]Stderr: {e.stderr}[/red]") sys.exit(1) console.print("\n" + "=" * 70) - console.print( - f"[bold green]\u2713 Development release {new_version} complete![/bold green]" - ) + console.print(f"[bold green]✓ Development release {new_version} complete![/bold green]") console.print("=" * 70) console.print("\nNext steps:") - console.print( - "1. Monitor GitHub Actions (testpypi.yml) for the TestPyPI publish." - ) + console.print("1. Monitor GitHub Actions (testpypi.yml) for the TestPyPI publish.") console.print("2. Verify on TestPyPI once the workflow completes.") return { @@ -152,8 +265,12 @@ def create_dev_release() -> None: } -def task_release() -> dict[str, Any]: - """Automate release: bump version, update CHANGELOG, and push to GitHub (triggers CI/CD).""" +def task_release(increment: str = "") -> dict[str, Any]: + """Automate release: bump version, update CHANGELOG, and push to GitHub (triggers CI/CD). + + Args: + increment (str): Force version increment type (MAJOR, MINOR, PATCH). Auto-detects if empty. + """ def automated_release() -> None: console = Console() @@ -171,12 +288,12 @@ def automated_release() -> None: ).stdout.strip() if current_branch != "main": console.print( - f"[bold yellow]\u26a0 Warning: Not on main branch " + f"[bold yellow]⚠ Warning: Not on main branch " f"(currently on {current_branch})[/bold yellow]" ) response = input("Continue anyway? (y/N) ").strip().lower() if response != "y": - console.print("[bold red]\u274c Release cancelled.[/bold red]") + console.print("[bold red]❌ Release cancelled.[/bold red]") sys.exit(1) # Check for uncommitted changes @@ -187,9 +304,7 @@ def automated_release() -> None: check=True, ).stdout.strip() if status: - console.print( - "[bold red]\u274c Error: Uncommitted changes detected.[/bold red]" - ) + console.print("[bold red]❌ Error: Uncommitted changes detected.[/bold red]") console.print(status) sys.exit(1) @@ -197,23 +312,36 @@ def automated_release() -> None: console.print("\n[cyan]Pulling latest changes...[/cyan]") try: subprocess.run(["git", "pull"], check=True, capture_output=True, text=True) - console.print("[green]\u2713 Git pull successful.[/green]") + console.print("[green]✓ Git pull successful.[/green]") except subprocess.CalledProcessError as e: - console.print("[bold red]\u274c Error pulling latest changes:[/bold red]") + console.print("[bold red]❌ Error pulling latest changes:[/bold red]") console.print(f"[red]Stdout: {e.stdout}[/red]") console.print(f"[red]Stderr: {e.stderr}[/red]") sys.exit(1) + # Governance validation + console.print("\n[bold cyan]Running governance validations...[/bold cyan]") + + # Validate merge commit format (blocking) + if not validate_merge_commits(console): + console.print("\n[bold red]❌ Merge commit validation failed![/bold red]") + console.print("[yellow]Please ensure all merge commits follow the format:[/yellow]") + console.print("[yellow] : (merges PR #XX, closes #YY)[/yellow]") + sys.exit(1) + + # Validate issue links (warning only) + validate_issue_links(console) + + console.print("[bold green]✓ Governance validations complete.[/bold green]") + # Run all checks console.print("\n[cyan]Running all pre-release checks...[/cyan]") try: - subprocess.run( - ["doit", "check"], check=True, capture_output=True, text=True - ) - console.print("[green]\u2713 All checks passed.[/green]") + subprocess.run(["doit", "check"], check=True, capture_output=True, text=True) + console.print("[green]✓ All checks passed.[/green]") except subprocess.CalledProcessError as e: console.print( - "[bold red]\u274c Pre-release checks failed! " + "[bold red]❌ Pre-release checks failed! " "Please fix issues before releasing.[/bold red]" ) console.print(f"[red]Stdout: {e.stdout}[/red]") @@ -221,34 +349,34 @@ def automated_release() -> None: sys.exit(1) # Automated version bump and CHANGELOG generation using commitizen - console.print( - "\n[cyan]Bumping version and generating CHANGELOG with commitizen...[/cyan]" - ) + console.print("\n[cyan]Bumping version and generating CHANGELOG with commitizen...[/cyan]") try: # Use cz bump --changelog --merge-prerelease to update version, # changelog, commit, and tag. This consolidates pre-release changes # into the final release entry + bump_cmd = ["uv", "run", "cz", "bump", "--changelog", "--merge-prerelease"] + if increment: + bump_cmd.extend(["--increment", increment.upper()]) + console.print(f"[dim]Forcing {increment.upper()} version bump[/dim]") result = subprocess.run( - ["uv", "run", "cz", "bump", "--changelog", "--merge-prerelease"], + bump_cmd, env={**os.environ, "UV_CACHE_DIR": UV_CACHE_DIR}, check=True, capture_output=True, text=True, ) console.print( - "[green]\u2713 Version bumped and CHANGELOG updated (merged pre-releases).[/green]" + "[green]✓ Version bumped and CHANGELOG updated (merged pre-releases).[/green]" ) console.print(f"[dim]{result.stdout}[/dim]") # Extract new version from cz output (example: "Bumping to version 1.0.0") - version_match = re.search( - r"Bumping to version (\d+\.\d+\.\d+)", result.stdout - ) + version_match = re.search(r"Bumping to version (\d+\.\d+\.\d+)", result.stdout) # Fallback to "unknown" if regex fails new_version = version_match.group(1) if version_match else "unknown" except subprocess.CalledProcessError as e: console.print( - "[bold red]\u274c commitizen bump failed! " + "[bold red]❌ commitizen bump failed! " "Ensure your commit history is conventional.[/bold red]" ) console.print(f"[red]Stdout: {e.stdout}[/red]") @@ -256,7 +384,7 @@ def automated_release() -> None: sys.exit(1) except Exception as e: console.print( - f"[bold red]\u274c An unexpected error occurred during commitizen bump: {e}[/bold red]" + f"[bold red]❌ An unexpected error occurred during commitizen bump: {e}[/bold red]" ) sys.exit(1) @@ -269,31 +397,438 @@ def automated_release() -> None: capture_output=True, text=True, ) - console.print( - "[green]\u2713 Pushed new commits and tags to GitHub.[/green]" - ) + console.print("[green]✓ Pushed new commits and tags to GitHub.[/green]") except subprocess.CalledProcessError as e: - console.print("[bold red]\u274c Error pushing to GitHub:[/bold red]") + console.print("[bold red]❌ Error pushing to GitHub:[/bold red]") console.print(f"[red]Stdout: {e.stdout}[/red]") console.print(f"[red]Stderr: {e.stderr}[/red]") sys.exit(1) console.print("\n" + "=" * 70) - console.print( - f"[bold green]\u2713 Automated release {new_version} complete![/bold green]" - ) + console.print(f"[bold green]✓ Automated release {new_version} complete![/bold green]") console.print("=" * 70) console.print("\nNext steps:") console.print("1. Monitor GitHub Actions for build and publish.") console.print( - "2. Check TestPyPI: [link=https://test.pypi.org/project/bastproxy/]https://test.pypi.org/project/bastproxy/[/link]" + "2. Check TestPyPI: [link=https://test.pypi.org/project/package-name/]https://test.pypi.org/project/package-name/[/link]" ) console.print( - "3. Check PyPI: [link=https://pypi.org/project/bastproxy/]https://pypi.org/project/bastproxy/[/link]" + "3. Check PyPI: [link=https://pypi.org/project/package-name/]https://pypi.org/project/package-name/[/link]" ) console.print("4. Verify the updated CHANGELOG.md in the repository.") return { "actions": [automated_release], + "params": [ + { + "name": "increment", + "short": "i", + "long": "increment", + "default": "", + "help": "Force increment (MAJOR, MINOR, PATCH). Auto-detects if empty.", + } + ], + "title": title_with_actions, + } + + +def task_release_pr(increment: str = "") -> dict[str, Any]: + """Create a release PR with changelog updates (PR-based workflow). + + This task creates a release branch, updates the changelog, and opens a PR. + After the PR is merged, use `doit release_tag` to tag the release. + + Args: + increment (str): Force version increment type (MAJOR, MINOR, PATCH). Auto-detects if empty. + """ + + def create_release_pr() -> None: + console = Console() + console.print("=" * 70) + console.print("[bold green]Starting PR-based release process...[/bold green]") + console.print("=" * 70) + console.print() + + # Check if on main branch + current_branch = subprocess.run( + ["git", "branch", "--show-current"], + capture_output=True, + text=True, + check=True, + ).stdout.strip() + if current_branch != "main": + console.print( + f"[bold red]❌ Error: Must be on main branch " + f"(currently on {current_branch})[/bold red]" + ) + sys.exit(1) + + # Check for uncommitted changes + status = subprocess.run( + ["git", "status", "-s"], + capture_output=True, + text=True, + check=True, + ).stdout.strip() + if status: + console.print("[bold red]❌ Error: Uncommitted changes detected.[/bold red]") + console.print(status) + sys.exit(1) + + # Pull latest changes + console.print("\n[cyan]Pulling latest changes...[/cyan]") + try: + subprocess.run(["git", "pull"], check=True, capture_output=True, text=True) + console.print("[green]✓ Git pull successful.[/green]") + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Error pulling latest changes:[/bold red]") + console.print(f"[red]Stdout: {e.stdout}[/red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + # Governance validation + console.print("\n[bold cyan]Running governance validations...[/bold cyan]") + + # Validate merge commit format (blocking) + if not validate_merge_commits(console): + console.print("\n[bold red]❌ Merge commit validation failed![/bold red]") + console.print("[yellow]Please ensure all merge commits follow the format:[/yellow]") + console.print("[yellow] : (merges PR #XX, closes #YY)[/yellow]") + sys.exit(1) + + # Validate issue links (warning only) + validate_issue_links(console) + + console.print("[bold green]✓ Governance validations complete.[/bold green]") + + # Run all checks + console.print("\n[cyan]Running all pre-release checks...[/cyan]") + try: + subprocess.run(["doit", "check"], check=True, capture_output=True, text=True) + console.print("[green]✓ All checks passed.[/green]") + except subprocess.CalledProcessError as e: + console.print( + "[bold red]❌ Pre-release checks failed! " + "Please fix issues before releasing.[/bold red]" + ) + console.print(f"[red]Stdout: {e.stdout}[/red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + # Get next version using commitizen + console.print("\n[cyan]Determining next version...[/cyan]") + try: + get_next_cmd = ["uv", "run", "cz", "bump", "--get-next"] + if increment: + get_next_cmd.extend(["--increment", increment.upper()]) + console.print(f"[dim]Forcing {increment.upper()} version bump[/dim]") + result = subprocess.run( + get_next_cmd, + env={**os.environ, "UV_CACHE_DIR": UV_CACHE_DIR}, + check=True, + capture_output=True, + text=True, + ) + next_version = result.stdout.strip() + console.print(f"[green]✓ Next version: {next_version}[/green]") + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Failed to determine next version.[/bold red]") + console.print(f"[red]Stdout: {e.stdout}[/red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + # Create release branch + branch_name = f"release/v{next_version}" + console.print(f"\n[cyan]Creating branch {branch_name}...[/cyan]") + try: + subprocess.run( + ["git", "checkout", "-b", branch_name], + check=True, + capture_output=True, + text=True, + ) + console.print(f"[green]✓ Created branch {branch_name}[/green]") + except subprocess.CalledProcessError as e: + console.print(f"[bold red]❌ Failed to create branch {branch_name}.[/bold red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + # Update changelog + console.print("\n[cyan]Updating CHANGELOG.md...[/cyan]") + try: + changelog_cmd = ["uv", "run", "cz", "changelog", "--incremental"] + subprocess.run( + changelog_cmd, + env={**os.environ, "UV_CACHE_DIR": UV_CACHE_DIR}, + check=True, + capture_output=True, + text=True, + ) + console.print("[green]✓ CHANGELOG.md updated.[/green]") + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Failed to update changelog.[/bold red]") + console.print(f"[red]Stdout: {e.stdout}[/red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + # Cleanup: go back to main + subprocess.run(["git", "checkout", "main"], capture_output=True) + subprocess.run(["git", "branch", "-D", branch_name], capture_output=True) + sys.exit(1) + + # Commit changelog + console.print("\n[cyan]Committing changelog...[/cyan]") + try: + subprocess.run( + ["git", "add", "CHANGELOG.md"], + check=True, + capture_output=True, + text=True, + ) + subprocess.run( + ["git", "commit", "-m", f"chore: update changelog for v{next_version}"], + check=True, + capture_output=True, + text=True, + ) + console.print("[green]✓ Changelog committed.[/green]") + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Failed to commit changelog.[/bold red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + # Cleanup + subprocess.run(["git", "checkout", "main"], capture_output=True) + subprocess.run(["git", "branch", "-D", branch_name], capture_output=True) + sys.exit(1) + + # Push branch + console.print(f"\n[cyan]Pushing branch {branch_name}...[/cyan]") + try: + subprocess.run( + ["git", "push", "-u", "origin", branch_name], + check=True, + capture_output=True, + text=True, + ) + console.print("[green]✓ Branch pushed.[/green]") + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Failed to push branch.[/bold red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + # Create PR using doit pr + console.print("\n[cyan]Creating pull request...[/cyan]") + try: + pr_title = f"release: v{next_version}" + pr_body = f"""## Description +Release v{next_version} + +## Type of Change +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (would cause existing functionality to not work as expected) +- [ ] Documentation update +- [x] Release + +## Changes Made +- Updated CHANGELOG.md for v{next_version} + +## Testing +- [ ] All existing tests pass + +## Checklist +- [x] My changes generate no new warnings + +## Additional Notes +After this PR is merged, run `doit release_tag` to create the version tag +and trigger the release workflow. +""" + # Use gh CLI directly since we're in a non-interactive context + subprocess.run( + [ + "gh", + "pr", + "create", + "--title", + pr_title, + "--body", + pr_body, + ], + check=True, + capture_output=True, + text=True, + ) + console.print("[green]✓ Pull request created.[/green]") + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Failed to create PR.[/bold red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + console.print("\n" + "=" * 70) + console.print(f"[bold green]✓ Release PR for v{next_version} created![/bold green]") + console.print("=" * 70) + console.print("\nNext steps:") + console.print("1. Review and merge the PR.") + console.print("2. After merge, run: doit release_tag") + + return { + "actions": [create_release_pr], + "params": [ + { + "name": "increment", + "short": "i", + "long": "increment", + "default": "", + "help": "Force increment (MAJOR, MINOR, PATCH). Auto-detects if empty.", + } + ], + "title": title_with_actions, + } + + +def task_release_tag() -> dict[str, Any]: + """Tag the release after a release PR is merged. + + This task finds the most recently merged release PR, extracts the version, + creates a git tag, and pushes it to trigger the release workflow. + """ + + def create_release_tag() -> None: + console = Console() + console.print("=" * 70) + console.print("[bold green]Creating release tag...[/bold green]") + console.print("=" * 70) + console.print() + + # Check if on main branch + current_branch = subprocess.run( + ["git", "branch", "--show-current"], + capture_output=True, + text=True, + check=True, + ).stdout.strip() + if current_branch != "main": + console.print( + f"[bold red]❌ Error: Must be on main branch " + f"(currently on {current_branch})[/bold red]" + ) + sys.exit(1) + + # Pull latest changes + console.print("\n[cyan]Pulling latest changes...[/cyan]") + try: + subprocess.run(["git", "pull"], check=True, capture_output=True, text=True) + console.print("[green]✓ Git pull successful.[/green]") + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Error pulling latest changes:[/bold red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + # Find the most recently merged release PR + console.print("\n[cyan]Finding merged release PR...[/cyan]") + try: + result = subprocess.run( + [ + "gh", + "pr", + "list", + "--state", + "merged", + "--search", + "release: v in:title", + "--limit", + "1", + "--json", + "title,mergedAt,headRefName", + ], + check=True, + capture_output=True, + text=True, + ) + prs = json.loads(result.stdout) + if not prs: + console.print("[bold red]❌ No merged release PR found.[/bold red]") + console.print( + "[yellow]Ensure a release PR with title 'release: vX.Y.Z' was merged.[/yellow]" + ) + sys.exit(1) + + pr = prs[0] + pr_title = pr["title"] + branch_name = pr["headRefName"] + + # Extract version from PR title (format: "release: vX.Y.Z") + version_match = re.search(r"release:\s*v?(\d+\.\d+\.\d+)", pr_title) + if not version_match: + # Try extracting from branch name (format: "release/vX.Y.Z") + version_match = re.search(r"release/v?(\d+\.\d+\.\d+)", branch_name) + + if not version_match: + console.print("[bold red]❌ Could not extract version from PR.[/bold red]") + console.print(f"[yellow]PR title: {pr_title}[/yellow]") + console.print(f"[yellow]Branch: {branch_name}[/yellow]") + sys.exit(1) + + version = version_match.group(1) + tag_name = f"v{version}" + console.print(f"[green]✓ Found release PR: {pr_title}[/green]") + console.print(f"[green]✓ Version to tag: {tag_name}[/green]") + + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Failed to find release PR.[/bold red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + # Check if tag already exists + existing_tags = subprocess.run( + ["git", "tag", "-l", tag_name], + capture_output=True, + text=True, + ).stdout.strip() + if existing_tags: + console.print(f"[bold red]❌ Tag {tag_name} already exists.[/bold red]") + sys.exit(1) + + # Create tag + console.print(f"\n[cyan]Creating tag {tag_name}...[/cyan]") + try: + subprocess.run( + ["git", "tag", tag_name], + check=True, + capture_output=True, + text=True, + ) + console.print(f"[green]✓ Tag {tag_name} created.[/green]") + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Failed to create tag.[/bold red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + # Push tag + console.print(f"\n[cyan]Pushing tag {tag_name}...[/cyan]") + try: + subprocess.run( + ["git", "push", "origin", tag_name], + check=True, + capture_output=True, + text=True, + ) + console.print(f"[green]✓ Tag {tag_name} pushed.[/green]") + except subprocess.CalledProcessError as e: + console.print("[bold red]❌ Failed to push tag.[/bold red]") + console.print(f"[red]Stderr: {e.stderr}[/red]") + sys.exit(1) + + console.print("\n" + "=" * 70) + console.print(f"[bold green]✓ Release {tag_name} tagged![/bold green]") + console.print("=" * 70) + console.print("\nNext steps:") + console.print("1. Monitor GitHub Actions for build and publish.") + console.print( + "2. Check TestPyPI: [link=https://test.pypi.org/project/package-name/]https://test.pypi.org/project/package-name/[/link]" + ) + console.print( + "3. Check PyPI: [link=https://pypi.org/project/package-name/]https://pypi.org/project/package-name/[/link]" + ) + + return { + "actions": [create_release_tag], "title": title_with_actions, } diff --git a/tools/doit/security.py b/tools/doit/security.py index 4d6b4e8e..6c10a3b3 100644 --- a/tools/doit/security.py +++ b/tools/doit/security.py @@ -20,7 +20,7 @@ def task_security() -> dict[str, Any]: """Run security checks with bandit (requires security extras).""" return { "actions": [ - "uv run bandit -c pyproject.toml -r src/ || " + "uv run bandit -c pyproject.toml -r src/ tools/ bootstrap.py || " "echo 'bandit not installed. Run: uv sync --extra security'" ], "title": title_with_actions, diff --git a/tools/doit/template_clean.py b/tools/doit/template_clean.py new file mode 100644 index 00000000..a2ff110c --- /dev/null +++ b/tools/doit/template_clean.py @@ -0,0 +1,114 @@ +"""Template cleanup doit task. + +Provides a task to remove template-specific files from projects +created from pyproject-template. +""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Any + +from doit.tools import title_with_actions +from rich.console import Console + +# Add tools directory to path for imports +_tools_dir = Path(__file__).parent.parent +if str(_tools_dir) not in sys.path: + sys.path.insert(0, str(_tools_dir)) + + +def task_template_clean() -> dict[str, Any]: + """Remove template-specific files from the project. + + Options: + --setup: Remove setup files only (keep update checking) + --all: Remove all template files (no future updates) + --dry-run: Show what would be deleted without deleting + """ + + def run_cleanup(setup: bool, all_files: bool, dry_run: bool) -> None: + # Import from tools directory + cleanup_module_path = _tools_dir / "pyproject_template" / "cleanup.py" + if not cleanup_module_path.exists(): + console = Console() + console.print( + "[red]Error: cleanup.py not found. Template files may have been removed.[/red]" + ) + sys.exit(1) + + import importlib.util + + spec = importlib.util.spec_from_file_location("cleanup", cleanup_module_path) + if spec is None or spec.loader is None: + console = Console() + console.print("[red]Error: Could not load cleanup module.[/red]") + sys.exit(1) + + cleanup_mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(cleanup_mod) + + cleanup_mode_enum = cleanup_mod.CleanupMode + cleanup_template_files = cleanup_mod.cleanup_template_files + prompt_cleanup = cleanup_mod.prompt_cleanup + + console = Console() + + # Determine mode + if setup and all_files: + console.print("[red]Error: Cannot specify both --setup and --all[/red]") + sys.exit(1) + elif setup: + mode = cleanup_mode_enum.SETUP_ONLY + elif all_files: + mode = cleanup_mode_enum.ALL + else: + # Interactive mode + mode = prompt_cleanup() + if mode is None: + console.print("[cyan]Keeping all template files[/cyan]") + return + + # Perform cleanup + result = cleanup_template_files(mode, dry_run=dry_run) + + if dry_run: + console.print() + console.print("[yellow]Dry run complete. No files were deleted.[/yellow]") + elif result.failed: + console.print() + console.print("[red]Some files could not be deleted.[/red]") + sys.exit(1) + + return { + "actions": [run_cleanup], + "params": [ + { + "name": "setup", + "short": "s", + "long": "setup", + "type": bool, + "default": False, + "help": "Remove setup files only (keep update checking)", + }, + { + "name": "all_files", + "short": "a", + "long": "all", + "type": bool, + "default": False, + "help": "Remove all template files (no future updates)", + }, + { + "name": "dry_run", + "short": "n", + "long": "dry-run", + "type": bool, + "default": False, + "help": "Show what would be deleted without deleting", + }, + ], + "title": title_with_actions, + "verbosity": 2, + } diff --git a/tools/doit/templates.py b/tools/doit/templates.py new file mode 100644 index 00000000..04fd1e17 --- /dev/null +++ b/tools/doit/templates.py @@ -0,0 +1,359 @@ +"""Template parser for GitHub issue and PR templates, and ADR templates. + +Reads templates from .github/ directory and docs/decisions/ and converts them +to editor-friendly markdown. +""" + +from __future__ import annotations + +import re +from functools import lru_cache +from pathlib import Path +from typing import Any, NamedTuple + +import yaml + + +class AdrTemplate(NamedTuple): + """Parsed ADR template with editor content and metadata.""" + + editor_template: str + required_sections: list[str] + all_sections: list[str] + + +class IssueTemplate(NamedTuple): + """Parsed issue template with editor content and metadata.""" + + name: str + labels: str + editor_template: str + required_sections: list[str] + + +# Map issue types to template filenames +ISSUE_TYPE_TO_FILE = { + "feature": "feature_request.yml", + "bug": "bug_report.yml", + "refactor": "refactor.yml", + "doc": "documentation.yml", + "chore": "chore.yml", +} + +# Map YAML field IDs to section names for validation +FIELD_ID_TO_SECTION = { + "problem": "Problem", + "proposed-solution": "Proposed Solution", + "success-criteria": "Success Criteria", + "additional-context": "Additional Context", + "bug-description": "Bug Description", + "steps-to-reproduce": "Steps to Reproduce", + "expected-vs-actual": "Expected vs Actual Behavior", + "environment": "Environment", + "error-output": "Error Output", + "current-code-issue": "Current Code Issue", + "proposed-improvement": "Proposed Improvement", + "doc-type": "Documentation Type", + "description": "Description", + "location": "Suggested Location", + "chore-type": "Chore Type", + "proposed-changes": "Proposed Changes", +} + + +def _get_github_dir() -> Path: + """Get the .github directory path.""" + # Find project root by looking for .github directory + current = Path(__file__).resolve() + for parent in [current, *list(current.parents)]: + github_dir = parent / ".github" + if github_dir.is_dir(): + return github_dir + raise FileNotFoundError("Could not find .github directory") + + +def _parse_yaml_template(template_path: Path) -> dict[str, Any]: + """Parse a YAML template file. + + Args: + template_path: Path to the YAML template file + + Returns: + Parsed YAML content as dict + """ + with open(template_path, encoding="utf-8") as f: + result: dict[str, Any] = yaml.safe_load(f) + return result + + +def _yaml_to_editor_markdown(yaml_data: dict) -> tuple[str, list[str]]: + """Convert YAML template structure to editor-friendly markdown. + + Args: + yaml_data: Parsed YAML template data + + Returns: + Tuple of (editor_template, required_sections) + """ + lines = [ + "# Lines starting with # are comments and will be ignored.", + "# Fill in the sections below, save, and exit.", + "# Delete the placeholder text and add your content.", + "", + ] + required_sections: list[str] = [] + + body = yaml_data.get("body", []) + for item in body: + item_type = item.get("type") + attrs = item.get("attributes", {}) + validations = item.get("validations", {}) + + if item_type == "markdown": + # Skip markdown intro sections + continue + + if item_type in ("textarea", "dropdown"): + field_id = item.get("id", "") + label = attrs.get("label", "") + description = attrs.get("description", "") + placeholder = attrs.get("placeholder", "") + is_required = validations.get("required", False) + + # Map field ID to section name + section_name = FIELD_ID_TO_SECTION.get(field_id) or label + + if is_required and section_name: + required_sections.append(section_name) + + # Build section + lines.append(f"## {section_name}") + + # Add requirement indicator in comment + req_text = "Required" if is_required else "Optional" + if description: + lines.append(f"") + else: + lines.append(f"") + + # Add placeholder content + if item_type == "dropdown": + # For dropdowns, list options as choices + options = attrs.get("options", []) + lines.append(" / ".join(options)) + elif placeholder: + # Clean up placeholder - remove leading pipe formatting + placeholder_clean = placeholder.strip() + lines.append(placeholder_clean) + else: + lines.append("") + + lines.append("") + + return "\n".join(lines), required_sections + + +@lru_cache(maxsize=10) +def get_issue_template(issue_type: str) -> IssueTemplate: + """Get the issue template for a given type. + + Args: + issue_type: One of 'feature', 'bug', 'refactor', 'doc', 'chore' + + Returns: + IssueTemplate with editor content and metadata + + Raises: + ValueError: If issue_type is not valid + FileNotFoundError: If template file doesn't exist + """ + if issue_type not in ISSUE_TYPE_TO_FILE: + valid_types = list(ISSUE_TYPE_TO_FILE.keys()) + raise ValueError(f"Invalid issue type: {issue_type}. Must be one of: {valid_types}") + + github_dir = _get_github_dir() + template_file = github_dir / "ISSUE_TEMPLATE" / ISSUE_TYPE_TO_FILE[issue_type] + + if not template_file.exists(): + raise FileNotFoundError(f"Template file not found: {template_file}") + + yaml_data = _parse_yaml_template(template_file) + + # Extract metadata + name = yaml_data.get("name", issue_type) + labels_list = yaml_data.get("labels", []) + labels = ",".join(labels_list) if isinstance(labels_list, list) else str(labels_list) + + # Convert to editor template + editor_template, required_sections = _yaml_to_editor_markdown(yaml_data) + + return IssueTemplate( + name=name, + labels=labels, + editor_template=editor_template, + required_sections=required_sections, + ) + + +@lru_cache(maxsize=1) +def get_pr_template() -> str: + """Get the PR template content. + + Returns: + PR template as markdown string with editor comments added + + Raises: + FileNotFoundError: If PR template doesn't exist + """ + github_dir = _get_github_dir() + template_file = github_dir / "pull_request_template.md" + + if not template_file.exists(): + raise FileNotFoundError(f"PR template not found: {template_file}") + + content = template_file.read_text(encoding="utf-8") + + # Add editor instructions at the top + header = """\ +# Lines starting with # are comments and will be ignored. +# Fill in the sections below, save, and exit. +# Delete the placeholder text and add your content. +# Mark checkboxes with [x] where applicable. + +""" + return header + content + + +def get_issue_labels(issue_type: str) -> str: + """Get the labels for a given issue type. + + Args: + issue_type: One of 'feature', 'bug', 'refactor', 'doc', 'chore' + + Returns: + Comma-separated string of labels + """ + template = get_issue_template(issue_type) + return template.labels + + +def get_required_sections(issue_type: str) -> list[str]: + """Get the required sections for a given issue type. + + Args: + issue_type: One of 'feature', 'bug', 'refactor', 'doc', 'chore' + + Returns: + List of required section names + """ + template = get_issue_template(issue_type) + return template.required_sections + + +def clear_template_cache() -> None: + """Clear the template cache. Useful for testing.""" + get_issue_template.cache_clear() + get_pr_template.cache_clear() + get_adr_template.cache_clear() + + +def _get_docs_dir() -> Path: + """Get the docs directory path.""" + # Find project root by looking for docs directory + current = Path(__file__).resolve() + for parent in [current, *list(current.parents)]: + docs_dir = parent / "docs" + if docs_dir.is_dir(): + return docs_dir + raise FileNotFoundError("Could not find docs directory") + + +def _parse_adr_template(template_path: Path) -> AdrTemplate: + """Parse an ADR markdown template file. + + Extracts section headers and identifies required sections marked with + comments. + + Args: + template_path: Path to the ADR template file + + Returns: + AdrTemplate with editor content and metadata + """ + content = template_path.read_text(encoding="utf-8") + + all_sections: list[str] = [] + required_sections: list[str] = [] + + # Find all ## headers and check for marker + # Pattern: ## Section Name followed optionally by + lines = content.split("\n") + i = 0 + while i < len(lines): + line = lines[i] + # Check for ## header (but not ### or more) + header_match = re.match(r"^##\s+(.+)$", line) + if header_match: + section_name = header_match.group(1).strip() + all_sections.append(section_name) + + # Check next line for marker + if i + 1 < len(lines): + next_line = lines[i + 1].strip() + if next_line == "": + required_sections.append(section_name) + i += 1 + + # Create editor template with instructions + editor_header = """\ +# Lines starting with # are comments and will be ignored. +# Fill in the sections below, save, and exit. +# Sections marked must have content. + +""" + editor_template = editor_header + content + + return AdrTemplate( + editor_template=editor_template, + required_sections=required_sections, + all_sections=all_sections, + ) + + +@lru_cache(maxsize=1) +def get_adr_template() -> AdrTemplate: + """Get the ADR template with parsed metadata. + + Returns: + AdrTemplate with editor content and required sections + + Raises: + FileNotFoundError: If template file doesn't exist + """ + docs_dir = _get_docs_dir() + template_file = docs_dir / "decisions" / "adr-template.md" + + if not template_file.exists(): + raise FileNotFoundError(f"ADR template not found: {template_file}") + + return _parse_adr_template(template_file) + + +def get_adr_required_sections() -> list[str]: + """Get the required sections for ADRs. + + Returns: + List of required section names + """ + template = get_adr_template() + return template.required_sections + + +def get_adr_all_sections() -> list[str]: + """Get all sections defined in the ADR template. + + Returns: + List of all section names + """ + template = get_adr_template() + return template.all_sections diff --git a/tools/doit/testing.py b/tools/doit/testing.py index 947c0087..5cc760ad 100644 --- a/tools/doit/testing.py +++ b/tools/doit/testing.py @@ -19,7 +19,7 @@ def task_coverage() -> dict[str, Any]: return { "actions": [ "uv run pytest " - "--cov=bastproxy --cov-report=term-missing " + "--cov=package_name --cov-report=term-missing " "--cov-report=html:tmp/htmlcov --cov-report=xml:tmp/coverage.xml -v" ], "title": title_with_actions, diff --git a/tools/pyproject_template/repo_settings.py b/tools/pyproject_template/repo_settings.py new file mode 100644 index 00000000..078c728d --- /dev/null +++ b/tools/pyproject_template/repo_settings.py @@ -0,0 +1,476 @@ +#!/usr/bin/env python3 +""" +repo_settings.py - GitHub repository settings configuration. + +This module provides functions to configure GitHub repository settings, +branch protection, labels, GitHub Pages, and CodeQL. It can be used +independently of the full setup process. + +These functions are used by both setup_repo.py (initial setup) and +manage.py (updating existing repositories). +""" + +from __future__ import annotations + +import subprocess # nosec B404 +import sys +from pathlib import Path +from typing import Any + +# Support running as script or as module +_script_dir = Path(__file__).parent +if str(_script_dir) not in sys.path: + sys.path.insert(0, str(_script_dir)) + +from utils import TEMPLATE_REPO, GitHubCLI, Logger # noqa: E402 + + +def configure_repository_settings( + repo_full: str, + description: str, + visibility: str | None = None, + template_repo: str = TEMPLATE_REPO, +) -> bool: + """Configure repository settings to match template. + + Args: + repo_full: Full repository name (owner/repo) + description: Repository description + visibility: Repository visibility ('public' or 'private'), used for + determining which security features are available + template_repo: Template repository to copy settings from + + Returns: + True if successful, False otherwise + """ + Logger.step("Configuring repository settings...") + + try: + # Get ALL settings from template repository + template_settings = GitHubCLI.api(f"repos/{template_repo}") + + # Read-only fields that should not be copied + readonly_fields = { + # URLs + "archive_url", + "assignees_url", + "blobs_url", + "branches_url", + "clone_url", + "collaborators_url", + "comments_url", + "commits_url", + "compare_url", + "contents_url", + "contributors_url", + "deployments_url", + "downloads_url", + "events_url", + "forks_url", + "git_commits_url", + "git_refs_url", + "git_tags_url", + "git_url", + "hooks_url", + "html_url", + "issue_comment_url", + "issue_events_url", + "issues_url", + "keys_url", + "labels_url", + "languages_url", + "merges_url", + "milestones_url", + "notifications_url", + "pulls_url", + "releases_url", + "ssh_url", + "stargazers_url", + "statuses_url", + "subscribers_url", + "subscription_url", + "svn_url", + "tags_url", + "teams_url", + "trees_url", + "url", + # IDs and metadata + "id", + "node_id", + "owner", + "full_name", + "name", + # Timestamps + "created_at", + "updated_at", + "pushed_at", + # Counts and computed values + "forks", + "forks_count", + "open_issues", + "open_issues_count", + "size", + "stargazers_count", + "watchers", + "watchers_count", + "subscribers_count", + "network_count", + # Other read-only + "fork", + "language", + "license", + "permissions", + "disabled", + "mirror_url", + "default_branch", # Keep as main + "private", # Set separately via visibility + "is_template", # Don't make new repos templates + # Deprecated + "use_squash_pr_title_as_default", + } + + # Build settings data by copying all writable fields from template + data: dict[str, Any] = {} + for key, value in template_settings.items(): + if key not in readonly_fields and value is not None: + data[key] = value + + # Override description with user's description + data["description"] = description + + # Check if repository is in an organization + repo_owner = repo_full.split("/")[0] + owner_info = GitHubCLI.api(f"users/{repo_owner}") + is_org = owner_info.get("type") == "Organization" + + # Remove allow_forking if not an org repo (only applies to orgs) + if not is_org and "allow_forking" in data: + data.pop("allow_forking") + + # Remove security_and_analysis - we'll handle it separately + security_settings = data.pop("security_and_analysis", None) + + # Apply all settings in one call + GitHubCLI.api(f"repos/{repo_full}", method="PATCH", data=data) + Logger.success("Repository settings configured") + + # Configure security and analysis settings separately + if security_settings: + _configure_security_settings(repo_full, security_settings, visibility) + + return True + + except subprocess.CalledProcessError as e: + Logger.warning("Repository settings configuration failed") + if e.stderr: + print(f" Error: {e.stderr.strip()}") + Logger.info("You can configure settings manually at:") + Logger.info(f" https://github.com/{repo_full}/settings") + return False + + +def _configure_security_settings( + repo_full: str, + security_settings: dict[str, Any], + visibility: str | None = None, +) -> None: + """Configure security and analysis settings. + + Args: + repo_full: Full repository name (owner/repo) + security_settings: Security settings from template + visibility: Repository visibility for determining feature availability + """ + if not security_settings: + return + + # Enable secret scanning if template has it + if security_settings.get("secret_scanning", {}).get("status") == "enabled": + try: + GitHubCLI.api( + f"repos/{repo_full}/secret-scanning", + method="PATCH", + data={"status": "enabled"}, + ) + Logger.success("Secret scanning enabled") + except subprocess.CalledProcessError as e: + # 404 is expected for free/private repos that don't support this + if "404" in str(e.stderr): + if visibility == "public": + Logger.success("Secret scanning enabled (default for public repos)") + else: + Logger.info("Secret scanning not available (requires GHAS or public repo)") + else: + Logger.warning("Secret scanning configuration failed") + if e.stderr: + print(f" Error: {e.stderr.strip()}") + + # Enable secret scanning push protection if template has it + if security_settings.get("secret_scanning_push_protection", {}).get("status") == "enabled": + try: + GitHubCLI.api( + f"repos/{repo_full}/secret-scanning/push-protection", + method="PATCH", + data={"status": "enabled"}, + ) + Logger.success("Secret scanning push protection enabled") + except subprocess.CalledProcessError as e: + if "404" in str(e.stderr): + if visibility == "public": + Logger.success( + "Secret scanning push protection enabled (default for public repos)" + ) + else: + Logger.info("Secret scanning push protection not available for this repository") + else: + Logger.warning("Secret scanning push protection configuration failed") + if e.stderr: + print(f" Error: {e.stderr.strip()}") + + # Enable Dependabot security updates if template has it + if security_settings.get("dependabot_security_updates", {}).get("status") == "enabled": + try: + GitHubCLI.api( + f"repos/{repo_full}/automated-security-fixes", + method="PUT", + ) + Logger.success("Dependabot security updates enabled") + except subprocess.CalledProcessError as e: + Logger.warning("Dependabot security updates configuration failed") + if e.stderr: + print(f" Error: {e.stderr.strip()}") + + +def configure_branch_protection( + repo_full: str, + template_repo: str = TEMPLATE_REPO, +) -> bool: + """Configure branch protection using rulesets. + + Args: + repo_full: Full repository name (owner/repo) + template_repo: Template repository to copy rulesets from + + Returns: + True if successful, False otherwise + """ + Logger.step("Configuring branch protection rulesets...") + + try: + # Get rulesets from template + template_rulesets = GitHubCLI.api(f"repos/{template_repo}/rulesets") + + if not template_rulesets: + Logger.warning("No rulesets found in template repository") + return True # Not a failure, just nothing to do + + # Get existing rulesets from target repository to check for duplicates + existing_rulesets = GitHubCLI.api(f"repos/{repo_full}/rulesets") + existing_by_name: dict[str, int] = { + ruleset["name"]: ruleset["id"] for ruleset in existing_rulesets + } + + # Replicate each ruleset + for template_ruleset in template_rulesets: + # Get full ruleset details + ruleset_id = template_ruleset["id"] + full_ruleset = GitHubCLI.api(f"repos/{template_repo}/rulesets/{ruleset_id}") + + # Prepare ruleset data (remove read-only fields) + ruleset_data = { + "name": full_ruleset["name"], + "target": full_ruleset["target"], + "enforcement": full_ruleset["enforcement"], + "bypass_actors": full_ruleset.get("bypass_actors", []), + "conditions": full_ruleset.get("conditions", {}), + "rules": full_ruleset.get("rules", []), + } + + ruleset_name = full_ruleset["name"] + + # Check if ruleset already exists + if ruleset_name in existing_by_name: + # Update existing ruleset + existing_id = existing_by_name[ruleset_name] + GitHubCLI.api( + f"repos/{repo_full}/rulesets/{existing_id}", + method="PUT", + data=ruleset_data, + ) + Logger.success(f"Ruleset '{ruleset_name}' updated") + else: + # Create new ruleset + GitHubCLI.api( + f"repos/{repo_full}/rulesets", + method="POST", + data=ruleset_data, + ) + Logger.success(f"Ruleset '{ruleset_name}' created") + + return True + + except subprocess.CalledProcessError as e: + Logger.warning("Branch protection ruleset configuration failed") + if e.stderr: + print(f" Error: {e.stderr.strip()}") + Logger.info("You can configure rulesets manually at:") + Logger.info(f" https://github.com/{repo_full}/settings/rules") + return False + + +def replicate_labels( + repo_full: str, + template_repo: str = TEMPLATE_REPO, +) -> bool: + """Replicate labels from template. + + Args: + repo_full: Full repository name (owner/repo) + template_repo: Template repository to copy labels from + + Returns: + True if successful, False otherwise + """ + Logger.step("Replicating labels from template...") + + try: + # Get labels from template + labels = GitHubCLI.api(f"repos/{template_repo}/labels") + + if not labels: + Logger.warning("Could not retrieve labels from template") + return False + + # Create each label + for label in labels: + try: + label_data = { + "name": label["name"], + "color": label["color"], + "description": label.get("description", ""), + } + GitHubCLI.api( + f"repos/{repo_full}/labels", + method="POST", + data=label_data, + ) + except subprocess.CalledProcessError: + # Label might already exist, skip + pass + + Logger.success("Labels replicated") + return True + + except subprocess.CalledProcessError as e: + Logger.warning("Failed to retrieve labels from template") + if e.stderr: + print(f" Error: {e.stderr.strip()}") + return False + + +def enable_github_pages(repo_full: str) -> bool: + """Enable GitHub Pages. + + Args: + repo_full: Full repository name (owner/repo) + + Returns: + True if successful, False otherwise + """ + Logger.step("Enabling GitHub Pages...") + + try: + data = { + "source": { + "branch": "gh-pages", + "path": "/", + } + } + GitHubCLI.api(f"repos/{repo_full}/pages", method="POST", data=data) + Logger.success("GitHub Pages enabled") + return True + except subprocess.CalledProcessError: + Logger.warning("GitHub Pages not enabled (gh-pages branch doesn't exist yet)") + Logger.info("Pages will be enabled automatically after first docs deployment") + return False + + +def configure_codeql( + repo_full: str, + template_repo: str = TEMPLATE_REPO, +) -> bool: + """Configure CodeQL code scanning to match template. + + Args: + repo_full: Full repository name (owner/repo) + template_repo: Template repository to copy CodeQL config from + + Returns: + True if successful, False otherwise + """ + Logger.step("Configuring CodeQL code scanning...") + + try: + # Get CodeQL setup from template + template_codeql = GitHubCLI.api(f"repos/{template_repo}/code-scanning/default-setup") + + if template_codeql.get("state") != "configured": + Logger.info("CodeQL not configured in template, skipping") + return True # Not a failure, just nothing to do + + # Replicate CodeQL configuration + codeql_data: dict[str, Any] = { + "state": "configured", + "query_suite": template_codeql.get("query_suite", "default"), + } + + # Add languages if specified (will auto-detect if not provided) + if template_codeql.get("languages"): + codeql_data["languages"] = template_codeql["languages"] + + GitHubCLI.api( + f"repos/{repo_full}/code-scanning/default-setup", + method="PATCH", + data=codeql_data, + ) + Logger.success( + f"CodeQL configured with {template_codeql.get('query_suite', 'default')} query suite" + ) + return True + + except subprocess.CalledProcessError as e: + Logger.warning("CodeQL configuration failed") + if e.stderr: + print(f" Error: {e.stderr.strip()}") + Logger.info("You can configure CodeQL manually at:") + Logger.info(f" https://github.com/{repo_full}/security/code-scanning") + return False + + +def update_all_repo_settings( + repo_full: str, + description: str, + visibility: str | None = None, + template_repo: str = TEMPLATE_REPO, +) -> bool: + """Update all repository settings to match template. + + This is a convenience function that runs all configuration steps. + + Args: + repo_full: Full repository name (owner/repo) + description: Repository description + visibility: Repository visibility ('public' or 'private') + template_repo: Template repository to copy settings from + + Returns: + True if all steps successful, False if any failed + """ + results = [ + configure_repository_settings(repo_full, description, visibility, template_repo), + configure_branch_protection(repo_full, template_repo), + replicate_labels(repo_full, template_repo), + enable_github_pages(repo_full), + configure_codeql(repo_full, template_repo), + ] + return all(results)