Add type hints and standardize docstrings in torchutils#1815
Add type hints and standardize docstrings in torchutils#1815janfb merged 3 commits intosbi-dev:mainfrom
Conversation
- Added type hints to 16 functions: tile, sum_except_batch, split_leading_dim, merge_leading_dims, repeat_rows, tensor2numpy, logabsdet, random_orthogonal, get_num_parameters, create_alternating_binary_mask, create_mid_split_binary_mask, create_random_binary_mask, searchsorted, cbrt, get_temperature, gaussian_kde_log_eval - Converted :param/:return docstrings to Google style (Args/Returns) for consistency with the rest of the file - Fixed get_temperature to use torch.ones(1) instead of 1
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1815 +/- ##
=======================================
Coverage 87.86% 87.86%
=======================================
Files 140 140
Lines 12726 12726
=======================================
Hits 11182 11182
Misses 1544 1544
Flags with carried forward coverage won't be shown. Click here to find out more.
|
janfb
left a comment
There was a problem hiding this comment.
thanks a lot @khaledeslam20 , this looks good!
Added a couple of minor comments.
sbi/utils/torchutils.py
Outdated
| Union, | ||
| ) | ||
|
|
||
| from typing import List |
There was a problem hiding this comment.
duplicate import, please run ruff or all pre-commit hooks locally.
sbi/utils/torchutils.py
Outdated
| """Creates a binary mask of a given dimension which splits at the midpoint. | ||
|
|
||
| :param features: Dimension of mask. | ||
| :return: Binary mask with half of its entries set to 1s, of type torch.Tensor. | ||
| The first half of the mask is set to 1s and the second half to 0s. |
There was a problem hiding this comment.
this docstring is not correct, it describes sth like the create_mid_split_binary_mask function.
please contruct the new docstring from the old docstring description. only adapt the format to Google docstring format.
|
Both are fixed, thanks for your guidance. |
janfb
left a comment
There was a problem hiding this comment.
Looks good, thanks again @khaledeslam20 🙏
|
@khaledeslam20 before we can merge you would need to address the one type mismatch between |
|
Fixed the type mismatch in |
Summary
Adds type hints and improves docstrings for 16 utility functions in
sbi/utils/torchutils.py, as discussed in #1805.Changes
Added type hints to all 16 functions that were missing them:
tile,sum_except_batch,split_leading_dim,merge_leading_dims,repeat_rows,tensor2numpy,logabsdet,random_orthogonal,get_num_parameters,create_alternating_binary_mask,create_mid_split_binary_mask,create_random_binary_mask,searchsorted,cbrt,get_temperature,gaussian_kde_log_evalAlso standardized docstrings to Google style (Args/Returns) for the functions that were still using the old :param/:return format, to be consistent with the rest of the file.
Small fix in
get_temperature: changed1totorch.ones(1)for correct tensor comparison intorch.min.Testing
All tests pass:
pytest tests/ -k torchutils -v→ 18 passed, 1 skippedNo functionality changes (except minor
get_temperaturefix).Closes #1805