Speed up PairPixelSampler#3452
Conversation
|
Not sure about the test failure: |
|
|
||
| unique_vals = torch.unique(tensor) | ||
| if any(val not in (0, 1) for val in unique_vals) or tensor.dtype != torch.float32: | ||
| if not (tensor.dtype == torch.float32 and torch.all((tensor == 0) | (tensor == 1))): |
There was a problem hiding this comment.
Sweet, this is so much cleaner
akristoffersen
left a comment
There was a problem hiding this comment.
This looks great to me, thanks!
Is there a reason we aren't also using the rejection sampling for the PatchPixelSampler? Maybe not in this PR, but now the rejection sampling is its own separate fn it seems like it would be a drop-in replacement.
a24cd94 to
c153976
Compare
There is no reason, and it is a drop-in replacement. However, I agree that a separate PR seems better here. I opened one here (rebased this branch and started the new PR from it). |
This speeds up the extremely slow
PairPixelSampler. The main issue seems to be with mask erosion being slow. I managed to speed it up a bit but it still isn't great. I also added the functionality from #2585 toPairPixelSampler.Before any of the changes:

With only erosion:

With both erosion and rejection sampling:

This is an initial step towards #3446