Code style and quality rules for Megatron Bridge — ruff configuration, naming conventions, type hints, mypy rules, docstrings, copyright headers, logging, and the code review checklist.
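Install the skill:

```bash
npx skill4agent add nvidia/skills linting-and-formatting
```

Lint and format:

```bash
uv run ruff check --fix .
uv run ruff format .
```

Run the pre-commit hooks on staged changes:

```bash
git add -u
pre-commit run
# if it auto-fixed files:
git add -u
pre-commit run
```

Rules configured in `ruff.toml`:

| Rule | ID | Description |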
|---|---|---|
| Line length | — | 119 characters (formatter) |
| Quote style | — | Double quotes |
| f-string without placeholders | F541 | Error |
| Unused local variable | F841 | Auto-removed |
| Unused import | F401 | Auto-removed |
| Ambiguous variable name | E741 | Error (e.g., `l`, `O`, `I`) |
| Undefined name | F821 | Error |
| Block comment format | E266 | Error (too many leading `#`) |
| Import sorting | I | isort-compatible, auto-fixed |
| Public class docstring | D101 | Warning (ignored in test files) |
| Public function docstring | D103 | Warning (ignored in test files) |
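
To make a few of the rules in the table concrete, here is a minimal sketch of code that should pass them. The function names and strings are invented for illustration and do not come from Megatron Bridge.

```python
"""Illustrative module that follows several rules from the table above."""


def count_nonempty(lines: list[str]) -> int:
    """Count the lines that contain non-whitespace characters."""
    # E266: a block comment starts with a single '#', not '##'.
    # E741: 'line' is used instead of an ambiguous name such as 'l'.
    return sum(1 for line in lines if line.strip())


def describe(name: str, lines: list[str]) -> str:
    """Summarize a named block of text."""
    # F841: the local 'total' is assigned and then read.
    total = count_nonempty(lines)
    # F541: the 'f' prefix is justified because the string has placeholders.
    return f"{name}: {total} non-empty lines"
```

Both functions carry docstrings, which is what D103 asks for outside test files.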

File patterns that get per-file rule overrides: `__init__.py`, `test_*.py`, `*_test.py`, `tests/*.py`.

Naming conventions:

| Kind | Convention | Example |
|---|---|---|
| Files | snake_case | |
| Classes | PascalCase | |
| Functions/methods | snake_case | |
| Local variables | snake_case | |
| Variables starting with digit | prefix | |
| Global variables | UPPER_SNAKE + prefix | |
| Constants | UPPER_SNAKE | |
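
As a rough illustration of the file, class, function, local variable, and constant rows, the sketch below uses hypothetical names (`weight_loader.py`, `WeightLoader`, `DEFAULT_SHARD_COUNT`) that are not taken from the project. The two prefix rows are left out because the table does not show the prefix.

```python
# weight_loader.py  (file name: snake_case, hypothetical)

DEFAULT_SHARD_COUNT = 8  # constant: UPPER_SNAKE


class WeightLoader:  # class: PascalCase
    """Hypothetical loader that assigns parameter names to shards."""

    def __init__(self, shard_count: int = DEFAULT_SHARD_COUNT) -> None:
        self.shard_count = shard_count

    def shard_for(self, param_name: str) -> int:  # method: snake_case
        """Return the shard index assigned to a parameter name."""
        name_length = len(param_name)  # local variable: snake_case
        return name_length % self.shard_count
```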
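
Imports are grouped in isort order: `__future__` first, then the standard library, then third-party packages (`megatron.core`, `torch`, `transformers`), then first-party `megatron.bridge.*`. Ruff rule `I` applies the ordering.

Type hints: use `T | None` instead of `Optional[T]`, `X | Y` instead of `Union[X, Y]`, and the built-in `list`, `dict`, and `tuple` instead of their `typing` equivalents.

```python
import torch


def get_module_by_name(
    model: torch.nn.Module,
    name: str,
    default: torch.nn.Module | None = None,
) -> torch.nn.Module | None:
    ...
```

Check a file with mypy:

```bash
uv run mypy --strict path/to/file.py
```

mypy rules:

- Avoid `Any`; use `object` or a `TypeVar` where the type is unknown or generic.
- Make `Optional` defaults explicit: `x: int | None = None`, not `x: int = None`.
- Use `typing.cast()` where the checker cannot infer a narrower type.
- Prefer `TypedDict` over `dict[str, Any]` for dictionaries with known keys.
- Annotate callables as `Callable[[ArgType], ReturnType]`, or define a `Protocol`.
- Scope suppressions to an error code: `# type: ignore[code]`.
- Use `*` to make same-typed parameters keyword-only, so call sites cannot swap them silently.

```python
from torch import Tensor
from torch.distributed import ProcessGroup


# Don't
def scatter_weights(tensor: Tensor, tp_group: ProcessGroup, ep_group: ProcessGroup): ...


# Do
def scatter_weights(tensor: Tensor, *, tp_group: ProcessGroup, ep_group: ProcessGroup): ...
```

Docstrings use Google style with `Args`, `Returns`, and `Raises` sections:

```python
def convert_weights(
    source_model: torch.nn.Module,
    target_model: torch.nn.Module,
    mapping: MegatronParamMapping,
) -> dict[str, torch.Tensor]:
    """Convert weights from source to target model format.

    Args:
        source_model: The source model containing weights to convert.
        target_model: The target model that will receive converted weights.
        mapping: Parameter mapping defining the conversion rules.

    Returns:
        Dictionary mapping parameter names to converted weight tensors.

    Raises:
        ValueError: If source and target models have incompatible shapes.
    """
```

Logging: use `logging.getLogger(__name__)`, or `print_rank_0` and `warn_rank_0` for rank-0 output, instead of `print()`.

```python
import logging

# Don't
print(f"Loading weights for {model_name}")

# Do
logger = logging.getLogger(__name__)
logger.info("Loading weights for %s", model_name)
```

Error handling: keep the `try` body limited to the call that can fail, and put code that depends on its success in `else`.

```python
try:
    state_dict = torch.load(path)
except FileNotFoundError:
    # "from None" hides the original traceback in favor of the clearer message.
    raise ValueError(f"Checkpoint not found at {path}") from None
else:
    result = convert(state_dict)
```

Prefer explicit parameters and return values over `*args` and `locals()`:

```python
# Don't
def make_config(*args):
    x, y = args
    return dict(**locals())


# Do
def make_config(x, y):
    return {"x": x, "y": y}
```

For structured data, use `dataclasses` or `NamedTuple`.

Copyright header:

```python
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```

Code review checklist:

- No `print()`; use `logger` or `print_rank_0`.
- No `Any` where a more specific type is available.
- No bare `# type: ignore`.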
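
The checklist items `print()`, `Any`, and `# type: ignore` can be approximated with a small text scan before review. The sketch below is an assumption-laden illustration, not project tooling: `CHECKS` and `find_review_issues` are hypothetical names, and the regular expressions are rough heuristics that do not replace ruff or mypy.

```python
import re

# Hypothetical patterns for three checklist items (heuristics only).
CHECKS = {
    # A call to print(), but not print_rank_0() or obj.print().
    "print() call": re.compile(r"(?<![\w.])print\("),
    # "# type: ignore" that is not followed by an error code in brackets.
    "bare type: ignore": re.compile(r"#\s*type:\s*ignore(?!\[)"),
    # Any used directly as a parameter or return annotation.
    "Any annotation": re.compile(r"(:|->)\s*Any\b"),
}


def find_review_issues(source: str) -> list[tuple[int, str]]:
    """Return (line_number, issue) pairs for lines matching a checklist pattern."""
    issues = []
    for line_number, line in enumerate(source.splitlines(), start=1):
        for issue, pattern in CHECKS.items():
            if pattern.search(line):
                issues.append((line_number, issue))
    return issues
```

A scan like this only flags candidates; a reviewer still decides whether each hit is a real problem.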