Jdb/grant review #24

Open
jeremiedb wants to merge 20 commits into jdb/neuro-v3 from jdb/grant-review

Conversation

@jeremiedb
Member

No description provided.

AdityaPandeyCN and others added 20 commits January 30, 2026 19:16
- Add Enzyme.jl dependency
- Implement compute_grads() using Enzyme.autodiff with runtime activity mode
- Add train/test mode switching to handle BatchNorm mutation issues
- Refactor mlogloss to use direct indexing instead of onehotbatch
- Configure Enzyme strictAliasing in module __init__

Signed-off-by: AdityaPandeyCN <adityapand3y666@gmail.com>
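
As a rough Julia sketch of what a compute_grads() along the lines described above could look like (the loss signature, activity annotations, and helper name are assumptions for illustration, not the PR's actual code):

import Enzyme

# Hedged sketch only: reverse-mode gradients of a scalar loss w.r.t. the model,
# with Enzyme's runtime activity mode enabled as mentioned in the commit above.
function compute_grads_sketch(loss, model, x, y)
    dmodel = Enzyme.make_zero(model)                  # shadow structure accumulating the gradients
    Enzyme.autodiff(
        Enzyme.set_runtime_activity(Enzyme.Reverse),  # runtime activity mode
        loss,                                         # loss(model, x, y) -> scalar
        Enzyme.Active,                                # the scalar return is active
        Enzyme.Duplicated(model, dmodel),             # differentiate w.r.t. the model
        Enzyme.Const(x),
        Enzyme.Const(y),
    )
    return dmodel
end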
…m Reactant cache

Signed-off-by: AdityaPandeyCN <adityapand3y666@gmail.com>
…rnal onehotbatch calls.

Signed-off-by: AdityaPandeyCN <adityapand3y666@gmail.com>
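
For reference, the "direct indexing instead of onehotbatch" refactor mentioned in the commits could look roughly like this (the function name and the K × N probability layout are assumptions):

# Hedged sketch: pick each observation's predicted probability for its true class
# by indexing directly, instead of materializing a one-hot matrix.
function mlogloss_direct(p::AbstractMatrix, y::AbstractVector{<:Integer})
    # p is K × N (class probabilities per column), y holds class indices in 1:K
    idx = CartesianIndex.(y, eachindex(y))
    return -sum(log.(p[idx])) / length(y)
end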
@jeremiedb jeremiedb changed the base branch from main to jdb/neuro-v3 February 13, 2026 06:52
Member Author

@jeremiedb jeremiedb left a comment


@AdityaPandeyCN I made some comments from a fork of your PR #23

import Optimisers
import Optimisers: OptimiserChain, WeightDecay, Adam, NAdam, Nesterov, Descent, Momentum, AdaDelta
import Flux: trainmode!, gradient, cpu, gpu
import Optimisers: OptimiserChain, WeightDecay, Momentum, Nesterov, Descent, Adam, NAdam
Member Author


I had to add back Adam, as init uses Adam and only NAdam was imported.

outsize ÷= 2
chain = Chain(
BatchNorm(nfeats),
BatchNorm(nfeats, track_stats=false),
Member Author


Was this track_stats=false intended as a temporary fix? I'd expect it to result in different behavior, and the default option (true) should also be compatible with Reactant.


Yes, this was a fix. I was facing some problems with BN, particularly the running stats (μ and σ²) not changing. I will look into the docs more to see how to handle it.
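
For what it's worth, a minimal Flux sketch of the default behavior in question; with track_stats=true the running μ/σ² only get mutated while the layer is in train mode, which is presumably the mutation the traced/compiled path struggles with:

using Flux

bn = BatchNorm(4)        # track_stats=true is the default
x  = randn(Float32, 4, 32)

Flux.trainmode!(bn)
bn(x)                    # forward pass updates the running stats bn.μ and bn.σ²

Flux.testmode!(bn)
bn(x)                    # uses the accumulated bn.μ / bn.σ², no mutation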

offset_name=nothing,
)

device = config.device
Member Author


I think that device should still be used here, since it's a parameter defined in the learner, and as such should be defined with the "model constructor" (per MLJ terminology).

using NeuroTabModels

using Reactant
Reactant.set_default_backend("cpu")
Member Author


device should continue to be defined in the model constructor (i.e. NeuroTabRegressor at line 60 below).
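
To illustrate the suggestion (the exact keyword list is an assumption, not the full NeuroTabModels API), device would sit with the other hyper-parameters of the learner:

using NeuroTabModels

# Hedged illustration: device defined in the model constructor, so the fitting
# code reads config.device rather than relying on a global backend setting.
config = NeuroTabRegressor(
    nrounds = 200,
    device  = :gpu,
)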


# desktop: 0.771839 seconds (369.20 k allocations: 1.522 GiB, 5.94% gc time)
@time p_train = m(dtrain; device=:gpu);
@time p_train = m(dtrain);
Member Author


It should still be possible to request inference execution on a specific device (cpu or gpu) via the kwarg.
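
A rough sketch of honouring that kwarg at inference time (the helper name and the use of Flux's cpu/gpu movers are assumptions for illustration):

import Flux

# Hedged sketch: move the model and data to the requested device per call,
# independently of the globally configured backend.
function infer(m, x; device::Symbol = :cpu)
    to_dev = device == :gpu ? Flux.gpu : Flux.cpu
    return to_dev(m)(to_dev(x))
end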

using Random: seed!

using Reactant
Reactant.set_default_backend("gpu")
Member Author


Testing with gpu as the backend resulted in quite poor performance, ~15 sec, which is even a little slower than the cpu timing at ~14 sec.
Have you observed such poor performance? This would need to be investigated, as performance should be at least within the parity zone of the current jdb/neuro-v3 / Zygote implementation.

It shows the following warnings:

[ Info: Init training
I0000 00:00:1770966229.664446 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966230.970942 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966232.270221 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966233.563376 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966234.858976 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966236.155050 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966237.464304 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966238.780420 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966240.100689 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966241.403202 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
 14.681320 seconds (4.47 M allocations: 343.948 MiB, 0.30% gc time, 0.00% compilation time)
  1.300183 seconds (4.06 k allocations: 1.673 GiB, 7.01% gc time)

For reference, with jdb/neuro-v3 branch:

  • cpu: ~ 21 sec (so this PR is an improvement)
  • gpu: ~ 3.2 sec, so the Reactant version is almost 5X slower. Reactant needs to achieve performance around this 3.2 sec to be a usable substitute.


Let me have another look at the gpu performance. I will report back to you.

if is_full
opts_ra, m_ra = cache[:compiled_step](cache[:loss], m_ra, opts_ra, args_ra...)
else
opts_ra, m_ra = Reactant.@jit _train_step!(cache[:loss], m_ra, opts_ra, args_ra...)


This is going to be inefficient; you should pad the last batch with zeros if needed.
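
A sketch of the padding idea, so the last partial batch keeps the compiled shape and the jitted step can be reused (names are placeholders):

# Hedged sketch: pad the last partial batch with zeros up to the fixed batch size.
function pad_to_batchsize(x::AbstractMatrix{T}, batchsize::Int) where {T}
    n = size(x, 2)
    n == batchsize && return x
    padded = zeros(T, size(x, 1), batchsize)
    padded[:, 1:n] .= x
    return padded
end

The padded observations would also need zero weights so they don't contribute to the loss.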

@@ -61,24 +62,23 @@ function update(
while fitresult.info[:nrounds] < model.nrounds


you can even compile the whole while loop
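
A sketch of that idea, reusing the names from the snippet above; whether the full loop traces cleanly through Reactant.@compile would need to be verified:

# Hedged sketch: compile one outer function that runs all the training steps,
# instead of jitting each step separately.
function _train_loop!(loss, m_ra, opts_ra, nrounds, args_ra...)
    for _ in 1:nrounds
        opts_ra, m_ra = _train_step!(loss, m_ra, opts_ra, args_ra...)
    end
    return opts_ra, m_ra
end

compiled_loop = Reactant.@compile _train_loop!(cache[:loss], m_ra, opts_ra, nrounds, args_ra...)
opts_ra, m_ra = compiled_loop(cache[:loss], m_ra, opts_ra, nrounds, args_ra...)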

@jeremiedb
Member Author

@AdityaPandeyCN I've made some explorations around migrating to Lux and it looks encouraging. GPU perf seems to improve meaningfully with Reactant. See branch lux-new: https://github.com/Evovest/NeuroTabModels.jl/tree/lux-new
Notably, the logic and device handling get much simplified and quite aligned with basic Lux tutorials.
Note that inference (i.e. m(x)) and eval metrics tracking no longer work, as adaptation is needed.
But I think these should be quite straightforward.
If it makes sense to you, I'd suggest starting from this branch and completing the adaptation, notably the eval callback / inference support, and adding proper functionality to allow training on either cpu (cpu_device) or GPU / Reactant device, etc.
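
For context, the kind of training step Lux's TrainState gives you looks roughly like the standard tutorial pattern below (the model, sizes, and AD backend are placeholders, not the branch's actual code; presumably on lux-new this would run through Reactant rather than plain Zygote):

using Lux, Optimisers, Random, ADTypes, Zygote

rng    = Random.default_rng()
model  = Chain(Dense(10 => 32, relu), Dense(32 => 1))
ps, st = Lux.setup(rng, model)
tstate = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(1f-3))

x = randn(rng, Float32, 10, 64)
y = randn(rng, Float32, 1, 64)

# single_train_step! returns (grads, loss, stats, updated train state)
_, loss, _, tstate = Lux.Training.single_train_step!(
    AutoZygote(), MSELoss(), (x, y), tstate)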

@AdityaPandeyCN

Thanks @jeremiedb, this looks really great, I will complete the adaptation. I locally addressed the comments on this branch, but some of the things got quite complex, which I assume can be handled really well by Lux's TrainState.

@jeremiedb
Member Author

Note that something I'm a little concerned about regarding Lux is its loss function interface: https://lux.csail.mit.edu/stable/api/Lux/utilities#Loss-Functions
It seems to assume only an x, y pair, while in NeuroTabModels there's a need to support losses where a weights and/or offset vector are also present. It would be worth taking a look at this to validate that such 3- and 4-input data to the loss function can be adapted to respect the Lux interface.

@AdityaPandeyCN

So something like a custom objective function instead of using MSELoss() to handle them correctly?

@jeremiedb
Member Author

Yeah, there already are custom loss functions for NeuroTabModels defined in https://github.com/Evovest/NeuroTabModels.jl/blob/lux-new/src/losses.jl. These would need to be adapted for compatibility with Lux.

Looking under the hood at how Lux deals with the x, y pairs, it appears that the gradients get updated by treating them as just a single data object, hence a tuple. So I'd expect a loss function acting on (x, y, w) or (x, y, w, offset) to work just fine.
So I don't think there's a need to worry about this on your end for now; instead, focus on cleaning up the implementation so that inference and eval metrics tracking work (e.g. with apply(m, x, ps, st)), as well as the StackTree, which was simplified into a single NeuroTree (a custom wrapper layer like https://lux.csail.mit.edu/stable/introduction/#Defining-Custom-Layers shall likely work).
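
To make that concrete, a weighted loss written in the (model, ps, st, data) -> (loss, st, stats) objective form that Lux's training API expects could look roughly like this (the actual NeuroTabModels losses live in src/losses.jl; this is only the interface shape):

using Lux

# Hedged sketch: data is treated as a single tuple, so it can carry (x, y, w)
# (or (x, y, w, offset)) and the objective unpacks it itself.
function weighted_mse_objective(model, ps, st, data)
    x, y, w = data
    ŷ, st = Lux.apply(model, x, ps, st)
    loss = sum(w .* abs2.(ŷ .- y)) / sum(w)
    return loss, st, NamedTuple()
end

Such an objective would then be passed as-is to Training.single_train_step!(backend, weighted_mse_objective, (x, y, w), tstate).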
