
Commit 300abb3

Adding non-layer param count to summary (Lightning-AI#17005)

Authored-by: rhiga2
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>

1 parent d7b668e commit 300abb3

File tree: 4 files changed, +81 −1 lines

src/lightning/pytorch/CHANGELOG.md (3 additions, 0 deletions)

```diff
@@ -51,6 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added CLI option `--map-to-cpu` to the checkpoint upgrade script to enable converting GPU checkpoints on a CPU-only machine ([#17527](https://github.com/Lightning-AI/lightning/pull/17527))
 
 
+- Added non-layer param count to the model summary ([#17005](https://github.com/Lightning-AI/lightning/pull/17005))
+
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
```

src/lightning/pytorch/utilities/model_summary/model_summary.py (25 additions, 0 deletions)

```diff
@@ -32,6 +32,8 @@
 
 PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
 UNKNOWN_SIZE = "?"
+LEFTOVER_PARAMS_NAME = "other params"
+NOT_APPLICABLE = "n/a"
 
 
 class LayerSummary:
@@ -141,6 +143,9 @@ class ModelSummary:
     intermediate input- and output shapes of all layers. Supported are tensors and
     nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?`
     in the summary table. The summary will also display `?` for layers not used in the forward pass.
+    If there are parameters not associated with any layers or modules, the count of those parameters
+    will be displayed in the table under `other params`. The summary will display `n/a` for module type,
+    in size, and out size.
 
     Example::
 
@@ -235,6 +240,10 @@ def trainable_parameters(self) -> int:
             p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
         )
 
+    @property
+    def total_layer_params(self) -> int:
+        return sum(self.param_nums)
+
     @property
     def model_size(self) -> float:
         return self.total_parameters * self._precision_megabytes
@@ -292,8 +301,24 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]:
             arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
             arrays.append(("Out sizes", [str(x) for x in self.out_sizes]))
 
+        total_leftover_params = self.total_parameters - self.total_layer_params
+        if total_leftover_params > 0:
+            self._add_leftover_params_to_summary(arrays, total_leftover_params)
+
         return arrays
 
+    def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None:
+        """Add summary of params not associated with module or layer to model summary."""
+        layer_summaries = dict(arrays)
+        layer_summaries[" "].append(" ")
+        layer_summaries["Name"].append(LEFTOVER_PARAMS_NAME)
+        layer_summaries["Type"].append(NOT_APPLICABLE)
+        layer_summaries["Params"].append(get_human_readable_count(total_leftover_params))
+        if "In sizes" in layer_summaries:
+            layer_summaries["In sizes"].append(NOT_APPLICABLE)
+        if "Out sizes" in layer_summaries:
+            layer_summaries["Out sizes"].append(NOT_APPLICABLE)
+
     def __str__(self) -> str:
         arrays = self._get_summary_data()
```

src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py (11 additions, 0 deletions)

```diff
@@ -25,6 +25,7 @@
     get_human_readable_count,
     LayerSummary,
     ModelSummary,
+    NOT_APPLICABLE,
 )
 
 
@@ -96,4 +97,14 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]:
             arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
             arrays.append(("Out sizes", [str(x) for x in self.out_sizes]))
 
+        total_leftover_params = self.total_parameters - self.total_layer_params
+        if total_leftover_params > 0:
+            self._add_leftover_params_to_summary(arrays, total_leftover_params)
+
         return arrays
+
+    def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None:
+        """Add summary of params not associated with module or layer to model summary."""
+        super()._add_leftover_params_to_summary(arrays, total_leftover_params)
+        layer_summaries = dict(arrays)
+        layer_summaries["Params per Device"].append(NOT_APPLICABLE)
```

tests/tests_pytorch/utilities/test_model_summary.py (42 additions, 1 deletion)

```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
 from typing import Any
 
 import pytest
@@ -19,7 +20,13 @@
 
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.utilities.model_summary.model_summary import ModelSummary, summarize, UNKNOWN_SIZE
+from lightning.pytorch.utilities.model_summary.model_summary import (
+    LEFTOVER_PARAMS_NAME,
+    ModelSummary,
+    NOT_APPLICABLE,
+    summarize,
+    UNKNOWN_SIZE,
+)
 from tests_pytorch.helpers.advanced_models import ParityModuleRNN
 from tests_pytorch.helpers.runif import RunIf
 
@@ -137,6 +144,18 @@ def forward(self, inp):
         return self.head(self.branch1(inp), self.branch2(inp))
 
 
+class NonLayerParamsModel(LightningModule):
+    """A model with parameters not associated with pytorch layer."""
+
+    def __init__(self):
+        super().__init__()
+        self.param = torch.nn.Parameter(torch.ones(2, 2))
+        self.layer = torch.nn.Linear(2, 2)
+
+    def forward(self, inp):
+        self.layer(self.param @ inp)
+
+
 def test_invalid_max_depth():
     """Test that invalid value for max_depth raises an error."""
     model = LightningModule()
@@ -358,3 +377,25 @@ def example_input_array(self) -> Any:
     summary_data = summary._get_summary_data()
     for column_name, entries in summary_data:
         assert all(isinstance(entry, str) for entry in entries)
+
+
+@pytest.mark.parametrize("example_input", [None, torch.ones(2, 2)])
+def test_summary_data_with_non_layer_params(example_input):
+    model = NonLayerParamsModel()
+    model.example_input_array = example_input
+
+    summary = summarize(model)
+    summary_data = OrderedDict(summary._get_summary_data())
+    assert summary_data[" "][-1] == " "
+    assert summary_data["Name"][-1] == LEFTOVER_PARAMS_NAME
+    assert summary_data["Type"][-1] == NOT_APPLICABLE
+    assert int(summary_data["Params"][-1]) == 4
+    if example_input is not None:
+        assert summary_data["In sizes"][-1] == NOT_APPLICABLE
+        assert summary_data["Out sizes"][-1] == NOT_APPLICABLE
+
+
+def test_summary_data_with_no_non_layer_params():
+    summary = summarize(PreCalculatedModel())
+    summary_data = OrderedDict(summary._get_summary_data())
+    assert summary_data["Name"][-1] != LEFTOVER_PARAMS_NAME
```
