
Lux new #26

Open
jeremiedb wants to merge 84 commits into main from lux-new

Conversation

@jeremiedb
Member

Migration from v0.3 to v0.4 by moving from Flux/Zygote to Lux/Reactant

jeremiedb and others added 30 commits January 20, 2026 22:10
- Add Enzyme.jl dependency
- Implement compute_grads() using Enzyme.autodiff with runtime activity mode
- Add train/test mode switching to handle BatchNorm mutation issues
- Refactor mlogloss to use direct indexing instead of onehotbatch
- Configure Enzyme strictAliasing in module __init__
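A minimal sketch of what a compute_grads() built on Enzyme.autodiff with runtime activity could look like — the helper name, signature, and loss arguments here are assumptions, not the package's actual code:

```julia
using Enzyme

# Hedged sketch of a compute_grads-style helper; only the Enzyme calls
# (autodiff, set_runtime_activity, make_zero!) are real API.
function compute_grads!(dps, loss_fn, ps, x, y)
    Enzyme.make_zero!(dps)  # reset the gradient shadow before accumulating
    Enzyme.autodiff(
        Enzyme.set_runtime_activity(Enzyme.Reverse),  # runtime activity mode
        Enzyme.Const(loss_fn),
        Enzyme.Active,               # the scalar loss is the active return
        Enzyme.Duplicated(ps, dps),  # parameters paired with their gradient shadow
        Enzyme.Const(x),
        Enzyme.Const(y),
    )
    return dps
end

# Per the commit note, module __init__ would also configure aliasing,
# e.g. Enzyme.API.strictAliasing!(false) (assumed).
```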

Add _sync_to_cpu! to MLJ interface to synchronize trained weights from Reactant cache

Update mlogloss to accept pre-encoded one-hot matrices, removing internal onehotbatch calls.

AdityaPandeyCN and others added 24 commits February 19, 2026 23:45

fix dimension flow

* Replace Zygote with Enzyme for gradient computation

- Add Enzyme.jl dependency
- Implement compute_grads() using Enzyme.autodiff with runtime activity mode
- Add train/test mode switching to handle BatchNorm mutation issues
- Refactor mlogloss to use direct indexing instead of onehotbatch
- Configure Enzyme strictAliasing in module __init__

* revert back loss logic

* up

* up

* use reactant and flux approach

* formatting

* revert changes

* add back imports

* update callback.jl

* revert back NAdam and remove CUDA dependency

* Compile gradient function once in init(), reuse across all epochs
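
A minimal sketch of the compile-once idea, assuming a scalar-returning lossfn(model, ps, st, x, y) and data already moved to the Reactant device; the names here are illustrative:

```julia
using Lux, Reactant, Enzyme

# Gradient of the loss w.r.t. ps; Const marks non-differentiated arguments.
lossgrad(model, ps, st, x, y) =
    Enzyme.gradient(Enzyme.Reverse, lossfn, Enzyme.Const(model), ps,
                    Enzyme.Const(st), Enzyme.Const(x), Enzyme.Const(y))

# init(): trace and compile once on a representative batch (fixed shapes),
grad_c = @compile lossgrad(model, ps, st, x, y)

# then reuse the compiled function across all epochs/batches:
# gs = grad_c(model, ps, st, xb, yb)
```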

* remove CuDNN dependency

* refactor fit.jl

* Add _sync_to_cpu! to MLJ interface to synchronize trained weights from Reactant cache
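
One way such a CPU sync could work — using Lux's cpu_device(), which recursively materializes Reactant device arrays back into plain Arrays (the surrounding usage is assumed):

```julia
using Lux

cdev = cpu_device()
ps_cpu = cdev(ps)  # trained weights pulled out of the Reactant cache onto CPU
st_cpu = cdev(st)
```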

* Update mlogloss to accept pre-encoded one-hot matrices, removing internal onehotbatch calls.
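
A minimal sketch of an mlogloss that consumes a pre-encoded one-hot target matrix rather than calling onehotbatch internally; the signature is an assumption:

```julia
using NNlib: logsoftmax

# p: raw logits (K×N); y: one-hot targets (K×N), encoded once by the caller
mlogloss(p, y) = -sum(y .* logsoftmax(p; dims=1)) / size(p, 2)
```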

* make track_stats false

* update benchmark file

* simplify cpu sync logic

* clean MLJ.jl

* fix imports

* fix inference, callbacks, and classification

* simplify get_device

* model cleanup

* fix imports

* fix imports

* add gpu inference support

* revert test

* cleanup

* merge conflicts

* merge conflicts

* Switch to lazy data loading to reduce memory usage and replace custom  with native  for cleaner residual stacking.

* revert Project.toml dependency

* revert comments

* revert y_pred to p

* operate callback directly on active object

* cleanup callback.jl

* handle MLogLoss with Lux.CrossEntropyLoss
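
For reference, Lux ships a numerically stable cross-entropy that can stand in for a hand-rolled MLogLoss when fed raw logits; the variable names are illustrative:

```julia
using Lux

lossfn = CrossEntropyLoss(; logits=Val(true))  # fuses softmax + NLL
# loss = lossfn(logits, y_onehot)              # logits: K×N, y_onehot: K×N
```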

* use train state for metrics

* fix data loader

* fix fit function arguments

* fix dimension flow

* file changes cleanup

* fix array typing

* sync

* up

* up

* apply partial=false and fix infer

* revert project.toml

* up

* up

* bump version

@jeremiedb
Member Author

@AdityaPandeyCN I've just merged jdb/grant-review along with your branch into lux-new, so that the bulk of the Lux adaptation now lives in lux-new.
However, issues with deval and infer still seem present:

  • If deval is specified, training on GPU crashes with scalar indexing. And on CPU, using dtrain as deval doubles the training time, which is abnormal: evaluation shouldn't take more than a single epoch's training time.
  • Inference should return a prediction for each of the requested datapoints. With partial=false for DataIter, this condition can't be met; the idea of having partial=false was just to simplify the debugging process. It should be possible to get a complete inference efficiently with Lux tools (likely a combination of looping over the complete batches with the compiled batch inference (ex: apply(x, ps, st)) along with a single @jit apply step for the last batch), as sketched below.
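
A minimal sketch of that combination, assuming ps and st already live on the Reactant device and x is a features × observations matrix; infer_all and batchsize are illustrative names, not package API:

```julia
using Lux, Reactant

function infer_all(model, ps, st, x::AbstractMatrix; batchsize::Int=2048)
    n = size(x, 2)
    nfull = div(n, batchsize)
    st = Lux.testmode(st)
    outs = Any[]

    # Compile once on a representative full batch, reuse for every full batch.
    if nfull > 0
        xb = Reactant.to_rarray(x[:, 1:batchsize])
        apply_c = @compile Lux.apply(model, xb, ps, st)
        for b in 1:nfull
            xb = Reactant.to_rarray(x[:, (b-1)*batchsize+1:b*batchsize])
            y, _ = apply_c(model, xb, ps, st)
            push!(outs, Array(y))
        end
    end

    # Single @jit step for the trailing partial batch, so every datapoint
    # gets a prediction without padding or dropping.
    r = n - nfull * batchsize
    if r > 0
        xr = Reactant.to_rarray(x[:, end-r+1:end])
        y, _ = @jit Lux.apply(model, xr, ps, st)
        push!(outs, Array(y))
    end
    return reduce(hcat, outs)
end
```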

@AdityaPandeyCN

AdityaPandeyCN commented Feb 22, 2026

@jeremiedb Let me have a look at it again. I was able to run this locally on GPU, but I may have missed committing some part of the code.

@AdityaPandeyCN
Copy link

AdityaPandeyCN commented Feb 22, 2026

Hello @jeremiedb, I have opened PR #27 with fixes for the above problems:

  1. Inference is working now.
  2. For the deval case when running on GPU, the timings are good (not much difference), but in the CPU case the timings are still a big problem. I managed to bring them down a bit, but they are still not good. I have tried moving more of the computation under compilation to speed things up, but the CPU case remains slow.

jeremiedb changed the base branch from jdb/neuro-v3 to main on February 25, 2026.