fix: prevent broadcasting errors in r2_score using da.where()#1013
fix: prevent broadcasting errors in r2_score using da.where()#1013TomAugspurger merged 1 commit intodask:mainfrom
Conversation
d7f7b86 to
17d3a02
Compare
|
Thanks. I'm not entirely sure what the best action is, but I think we ought to avoid anything that triggers computation unnecessarily, including Can you say a bit more about getting |
17d3a02 to
e98c538
Compare
|
Thanks @TomAugspurger. |
|
I'm probably missing something, but why do we care that the size of the test dataset matches the size of the training dataset ( |
The goal is not for the test dataset to match the training dataset's overall size. The focus is ensuring each estimator, trained on a specific data block, receives a matching block from the test set. |
Resolves "cannot broadcast shape (nan,) to shape (nan,)" errors.
e98c538 to
9c984a0
Compare
|
@TomAugspurger I limited the PR to changes in dask_ml/metrics/regression.py. Merging this would unblock my use case and allow me to update dask. Thanks! :). (previous state) |
|
Thanks for triggering the tests, Tom. Is there anything I can/should do to fix the failing runs? The errors seem unrelated and similar to other recent runs e.g. https://github.com/dask/dask-ml/actions/runs/14849752234/job/41691028477. |
|
I'll take a look in the next few days. |
|
Thanks @wietzesuijker. There's still one error in |
Closes #1012
First PR here. Curious to hear your feedback.
Problem
After updating to Dask 2025.2.0, tests fail with a ValueError due to changes in chunk size handling.
Solution
Refactor r2_score() to use da.where() for correct broadcasting.
Testing
Test added to ensure r2_score() works correctly with arrays that have different chunk configurations.