Conversation
- Add Enzyme.jl dependency
- Implement compute_grads() using Enzyme.autodiff with runtime activity mode
- Add train/test mode switching to handle BatchNorm mutation issues
- Refactor mlogloss to use direct indexing instead of onehotbatch
- Configure Enzyme strictAliasing in module __init__

Signed-off-by: AdityaPandeyCN <adityapand3y666@gmail.com>
…m Reactant cache
…rnal onehotbatch calls.
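The first commit message above mentions compute_grads() via Enzyme.autodiff with runtime activity, a direct-indexing mlogloss, and a strictAliasing setting. Below is a minimal sketch of what those pieces might look like; it assumes Enzyme's `make_zero` and `set_runtime_activity` helpers and a Flux-style callable model, and the names and signatures are illustrative, not the PR's actual code.

```julia
# Sketch only: Enzyme reverse-mode gradients with runtime activity, plus a
# multiclass logloss that indexes targets directly instead of building a onehotbatch.
using Enzyme

# Enzyme.API.strictAliasing!(false)   # the commit configures something like this in __init__ (assumed form)

# p: (K, N) matrix of class probabilities, y: length-N vector of class indices in 1:K
mlogloss(p, y) = -sum(log(p[y[j], j]) for j in axes(p, 2)) / size(p, 2)

function compute_grads(loss, m, x, y)
    dm = Enzyme.make_zero(m)   # shadow structure that accumulates the gradient of m
    Enzyme.autodiff(
        Enzyme.set_runtime_activity(Enzyme.Reverse),          # runtime activity mode
        Enzyme.Const((model, xb, yb) -> loss(model(xb), yb)),  # scalar objective
        Enzyme.Active,                                         # scalar return activity
        Enzyme.Duplicated(m, dm),
        Enzyme.Const(x),
        Enzyme.Const(y),
    )
    return dm
end

# usage sketch (with a Flux model):
# m  = Chain(Dense(4 => 3), softmax)
# dm = compute_grads(mlogloss, m, randn(Float32, 4, 8), rand(1:3, 8))
```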
jeremiedb left a comment
@AdityaPandeyCN I made some comments from a fork of your PR #23
import Optimisers
import Optimisers: OptimiserChain, WeightDecay, Adam, NAdam, Nesterov, Descent, Momentum, AdaDelta
import Flux: trainmode!, gradient, cpu, gpu
import Optimisers: OptimiserChain, WeightDecay, Momentum, Nesterov, Descent, Adam, NAdam
I had to add back Adam, as init uses Adam and only NAdam was imported.
outsize ÷= 2
chain = Chain(
    BatchNorm(nfeats),
    BatchNorm(nfeats, track_stats=false),
Was this track_stats=false intended as a temporary fix? I'd expect this to result in different behavior, and the default option true should also be compatible with Reactant.
Yes, this was a temporary fix. I was facing some problems with BatchNorm, particularly the running stats (μ and σ²) not updating. I will look into the docs more to see how to handle it.
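For context, a minimal Flux sketch of the behavior being discussed, assuming standard Flux semantics: with the default track_stats=true the running μ/σ² are only mutated in train mode, which is the mutation the commit's train/test switching works around.

```julia
using Flux

bn = BatchNorm(4)        # track_stats=true by default
x  = randn(Float32, 4, 16)

Flux.trainmode!(bn)      # normalizes with batch stats and updates the running μ/σ²
_ = bn(x)

Flux.testmode!(bn)       # normalizes with the stored running μ/σ²; no mutation
_ = bn(x)
```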
offset_name=nothing,
)

device = config.device
I think that device should still be used here, since it's a parameter defined in the learner, and as such should be defined with the "model constructor" (per MLJ terminology).
using NeuroTabModels

using Reactant
Reactant.set_default_backend("cpu")
device should continue to be defined in the model constructor (i.e. NeuroTabRegressor at line 60 below).
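A hypothetical usage sketch of the pattern the review is asking for; the keyword names and the training call are assumptions, and NeuroTabRegressor's real signature may differ.

```julia
using NeuroTabModels

# device stays a learner hyperparameter, set once in the model constructor,
# rather than being hard-coded via Reactant.set_default_backend in the script.
config = NeuroTabRegressor(
    loss = :mse,      # assumed kwarg
    nrounds = 200,    # assumed kwarg
    device = :gpu,    # assumed kwarg; internals would map it to the Reactant backend
)
# m = fit(config, dtrain)   # illustrative entry point; the actual API may differ
```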
# desktop: 0.771839 seconds (369.20 k allocations: 1.522 GiB, 5.94% gc time)
@time p_train = m(dtrain; device=:gpu);
@time p_train = m(dtrain);
It should still be possible to request inference execution on a specific device (cpu or gpu) via the kwarg.
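A hedged sketch of how such a kwarg could keep working on top of Reactant; `predict`, the backend strings, and the conversion calls are assumptions, with `Reactant.@jit` taken from the diff further down.

```julia
using Reactant

# Keep a per-call device kwarg for inference and route it to the Reactant backend.
function predict(m, x; device::Symbol = :cpu)
    Reactant.set_default_backend(device === :gpu ? "gpu" : "cpu")
    m_ra = Reactant.to_rarray(m)    # move model parameters to Reactant arrays
    x_ra = Reactant.to_rarray(x)
    return Array(Reactant.@jit m_ra(x_ra))   # convert the result back to a plain Array (assumed)
end
```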
using Random: seed!

using Reactant
Reactant.set_default_backend("gpu")
Testing with gpu as the backend resulted in quite poor performance, ~15 sec, which is even a little slower than the cpu timing at ~14 sec.
Have you observed such poor performance? This needs to be investigated, as performance should be at least within the parity zone of the current jdb/neuro-v3 / Zygote implementation.
It shows the following warnings:
[ Info: Init training
I0000 00:00:1770966229.664446 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966230.970942 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966232.270221 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966233.563376 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966234.858976 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966236.155050 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966237.464304 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966238.780420 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966240.100689 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
I0000 00:00:1770966241.403202 2648952 dot_merger.cc:481] Merging Dots in computation: main.6
14.681320 seconds (4.47 M allocations: 343.948 MiB, 0.30% gc time, 0.00% compilation time)
1.300183 seconds (4.06 k allocations: 1.673 GiB, 7.01% gc time)

For reference, with the jdb/neuro-v3 branch:
- cpu: ~ 21 sec (so this PR is an improvement)
- gpu: ~3.2 sec, so this PR is almost 5X slower. Reactant needs to achieve performance around this 3.2 sec to be a usable substitute.
Let me have another look at the gpu performance. I will report back.
if is_full
    opts_ra, m_ra = cache[:compiled_step](cache[:loss], m_ra, opts_ra, args_ra...)
else
    opts_ra, m_ra = Reactant.@jit _train_step!(cache[:loss], m_ra, opts_ra, args_ra...)
this is going to be inefficient; you should pad the last (partial) batch with zeros if needed so the compiled step can be reused
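A minimal sketch of the padding the comment suggests, assuming a (features, batch) column layout; `pad_batch` is an illustrative name, not code from the PR.

```julia
# Pad a partial final batch with zeros up to the fixed batch size so the
# precompiled step (which expects fixed shapes) can be reused instead of
# re-tracing via Reactant.@jit for a new shape.
function pad_batch(x::AbstractMatrix, batchsize::Integer)
    n = size(x, 2)
    n == batchsize && return x
    padded = zeros(eltype(x), size(x, 1), batchsize)
    padded[:, 1:n] .= x
    return padded
end
```

The padded columns would then also need to be masked (or weighted to zero) in the loss so they don't bias the gradients.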
@@ -61,24 +62,23 @@ function update(
    while fitresult.info[:nrounds] < model.nrounds
you can even compile the whole while loop
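One possible reading of this, sketched under assumptions: wrap several `_train_step!` calls in a single function and compile that once, so the iteration is traced into one executable rather than dispatched round by round. The helper name and argument layout mirror the diff above but are not the PR's code, and the `@compile` usage is an assumed pattern.

```julia
# Illustrative only: trace `nsteps` training rounds into one compiled call.
function _train_loop!(loss, m_ra, opts_ra, nsteps::Int, args_ra...)
    for _ in 1:nsteps   # unrolled at trace time for a fixed nsteps
        opts_ra, m_ra = _train_step!(loss, m_ra, opts_ra, args_ra...)
    end
    return opts_ra, m_ra
end

# compiled once, then reused across outer iterations (assumed pattern):
# compiled_loop = Reactant.@compile _train_loop!(loss, m_ra, opts_ra, nsteps, args_ra...)
# opts_ra, m_ra = compiled_loop(loss, m_ra, opts_ra, nsteps, args_ra...)
```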
@AdityaPandeyCN I've made some explorations for migrating to Lux and it looks encouraging. GPU perf seems to improve meaningfully with Reactant. See branch …
Thanks @jeremiedb, this looks really great, I will complete the adaptation. I locally addressed the comments on this branch, but some of the things got quite complex, which I am assuming can be handled really well by Lux's …
Note that something I'm a little concerned about regarding Lux is its loss function interface: https://lux.csail.mit.edu/stable/api/Lux/utilities#Loss-Functions
So, something like a custom objective function instead of using MSELoss() to handle them correctly?
Yeah, there already are custom loss functions for NeuroTabModels defined in https://github.com/Evovest/NeuroTabModels.jl/blob/lux-new/src/losses.jl. These would need to be adapted for compatibility with Lux. Looking under the hood at how Lux deals with the …
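For reference, a small sketch of what adapting one of those losses might look like under Lux's documented training-loss convention, where the objective receives `(model, ps, st, data)` and returns `(loss, st, stats)`; the MSE body here is illustrative, not the package's code.

```julia
using Lux, Statistics

# Custom loss in the shape Lux.Training expects: (model, ps, st, data) -> (loss, st, stats)
function mse_loss(model, ps, st, (x, y))
    ŷ, st = model(x, ps, st)              # Lux models return (output, updated_state)
    return mean(abs2, ŷ .- y), st, NamedTuple()
end
```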