Skip to content

Add type hints and standardize docstrings in torchutils#1815

Merged
janfb merged 3 commits intosbi-dev:mainfrom
khaledeslam20:add-type-hints-torchutils
Mar 17, 2026
Merged

Add type hints and standardize docstrings in torchutils#1815
janfb merged 3 commits intosbi-dev:mainfrom
khaledeslam20:add-type-hints-torchutils

Conversation

@khaledeslam20
Copy link
Contributor

Summary

Adds type hints and improves docstrings for 16 utility functions in sbi/utils/torchutils.py, as discussed in #1805.

Changes

Added type hints to all 16 functions that were missing them:
tile, sum_except_batch, split_leading_dim, merge_leading_dims, repeat_rows, tensor2numpy, logabsdet, random_orthogonal, get_num_parameters, create_alternating_binary_mask, create_mid_split_binary_mask, create_random_binary_mask, searchsorted, cbrt, get_temperature, gaussian_kde_log_eval

Also standardized docstrings to Google style (Args/Returns) for the functions that were still using the old :param/:return format, to be consistent with the rest of the file.

Small fix in get_temperature: changed 1 to torch.ones(1) for correct tensor comparison in torch.min.

Testing

All tests pass: pytest tests/ -k torchutils -v18 passed, 1 skipped

No functionality changes (except minor get_temperature fix).

Closes #1805

- Added type hints to 16 functions: tile, sum_except_batch,
  split_leading_dim, merge_leading_dims, repeat_rows, tensor2numpy,
  logabsdet, random_orthogonal, get_num_parameters,
  create_alternating_binary_mask, create_mid_split_binary_mask,
  create_random_binary_mask, searchsorted, cbrt, get_temperature,
  gaussian_kde_log_eval
- Converted :param/:return docstrings to Google style (Args/Returns)
  for consistency with the rest of the file
- Fixed get_temperature to use torch.ones(1) instead of  1
@codecov
Copy link

codecov bot commented Mar 16, 2026

Codecov Report

❌ Patch coverage is 84.21053% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 87.86%. Comparing base (cfed0a1) to head (8919eb1).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
sbi/utils/torchutils.py 84.21% 3 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1815   +/-   ##
=======================================
  Coverage   87.86%   87.86%           
=======================================
  Files         140      140           
  Lines       12726    12726           
=======================================
  Hits        11182    11182           
  Misses       1544     1544           
Flag Coverage Δ
fast 82.72% <84.21%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/utils/torchutils.py 68.11% <84.21%> (ø)

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks a lot @khaledeslam20 , this looks good!

Added a couple of minor comments.

Union,
)

from typing import List
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate import, please run ruff or all pre-commit hooks locally.

Comment on lines +305 to +307
"""Creates a binary mask of a given dimension which splits at the midpoint.

:param features: Dimension of mask.
:return: Binary mask with half of its entries set to 1s, of type torch.Tensor.
The first half of the mask is set to 1s and the second half to 0s.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this docstring is not correct, it describes sth like the create_mid_split_binary_mask function.

please contruct the new docstring from the old docstring description. only adapt the format to Google docstring format.

@khaledeslam20
Copy link
Contributor Author

Both are fixed, thanks for your guidance.

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks again @khaledeslam20 🙏

@janfb
Copy link
Contributor

janfb commented Mar 17, 2026

@khaledeslam20 before we can merge you would need to address the one type mismatch between Tensor and float. Please fix and then verify locally by running pyright sbi in the root dir of the repo. Thanks!

@khaledeslam20
Copy link
Contributor Author

Fixed the type mismatch in get_temperature by using separate variable names for the tensor conversions. Verified locally with pyright sbi, 0 errors.

@janfb janfb merged commit 3c60a97 into sbi-dev:main Mar 17, 2026
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add type hints to tensor utility functions in torchutils.py

2 participants