Commits (86)
29b3705
up
jeremiedb Jan 21, 2026
4ab436f
up
jeremiedb Jan 21, 2026
57e10fb
Update model.jl
jeremiedb Jan 21, 2026
5ea33c4
cleanup
jeremiedb Jan 23, 2026
5024688
up
jeremiedb Jan 24, 2026
31776f4
up
jeremiedb Jan 24, 2026
0859d48
up
jeremiedb Jan 24, 2026
97bd3da
up
jeremiedb Jan 26, 2026
4bbdc79
up
jeremiedb Jan 28, 2026
e2a3dfc
Replace Zygote with Enzyme for gradient computation
AdityaPandeyCN Jan 30, 2026
e877c06
revert back loss logic
AdityaPandeyCN Jan 31, 2026
2fe0010
up
jeremiedb Jan 31, 2026
9e3d29d
up
jeremiedb Jan 31, 2026
020d4d0
fit return cpu model
jeremiedb Feb 4, 2026
a342bd5
use reactant and flux approach
AdityaPandeyCN Feb 5, 2026
7c3d33e
formatting
AdityaPandeyCN Feb 5, 2026
c2ad582
revert changes
AdityaPandeyCN Feb 6, 2026
c0fc0ff
add back imports
AdityaPandeyCN Feb 6, 2026
01477ac
update callback.jl
AdityaPandeyCN Feb 6, 2026
9c832e7
revert back NAdam and remove CUDA dependency
AdityaPandeyCN Feb 6, 2026
389d624
Compile gradient function once in init(), reuse across all epochs
AdityaPandeyCN Feb 9, 2026
afcc486
remove CuDNN dependency
AdityaPandeyCN Feb 12, 2026
4258de8
refactor fit.jl
AdityaPandeyCN Feb 12, 2026
30897d0
Add _sync_to_cpu! to MLJ interface to synchronize trained weights fro…
AdityaPandeyCN Feb 12, 2026
7950273
Update mlogloss to accept pre-encoded one-hot matrices, removing inte…
AdityaPandeyCN Feb 12, 2026
5587aed
make track_stats false
AdityaPandeyCN Feb 12, 2026
08974b4
update benchmark file
AdityaPandeyCN Feb 12, 2026
a0bfa18
simplify cpu sync logic
AdityaPandeyCN Feb 13, 2026
c8b14d5
clean MLJ.jl
AdityaPandeyCN Feb 13, 2026
2388bcd
up
jeremiedb Feb 13, 2026
1066ecd
fix imports
AdityaPandeyCN Feb 13, 2026
19f5c25
up
jeremiedb Feb 13, 2026
27fe14d
Merge branch 'reactantv1' of github.com:AdityaPandeyCN/NeuroTabModels…
jeremiedb Feb 13, 2026
9a984a2
fix data loader
jeremiedb Feb 13, 2026
0ba7aa3
Merge branch 'jdb/neuro-v3' of github.com:Evovest/NeuroTabModels.jl i…
jeremiedb Feb 13, 2026
94c667b
up
jeremiedb Feb 15, 2026
1acb975
up
jeremiedb Feb 15, 2026
5c6f97b
up
jeremiedb Feb 15, 2026
b8becdf
up
jeremiedb Feb 15, 2026
8c02e99
up
jeremiedb Feb 15, 2026
47bd880
up
jeremiedb Feb 15, 2026
e8a28e7
up
jeremiedb Feb 15, 2026
3ca7a83
up
jeremiedb Feb 15, 2026
c7fa43a
fix inference, callbacks, and classification
AdityaPandeyCN Feb 15, 2026
88cf9a7
simplify get_device
AdityaPandeyCN Feb 16, 2026
24f3fde
up
jeremiedb Feb 16, 2026
be881a3
model cleanup
AdityaPandeyCN Feb 16, 2026
9dad435
fix imports
AdityaPandeyCN Feb 16, 2026
b6b47f4
fix imports
AdityaPandeyCN Feb 16, 2026
674ef55
add gpu inference support
AdityaPandeyCN Feb 16, 2026
2a63954
revert test
AdityaPandeyCN Feb 16, 2026
8f0f70f
cleanup
AdityaPandeyCN Feb 17, 2026
0d70551
merge conflicts
AdityaPandeyCN Feb 17, 2026
d4984a2
merge conflicts
AdityaPandeyCN Feb 17, 2026
c0799f1
merge conflicts
AdityaPandeyCN Feb 17, 2026
75c6b53
Switch to lazy data loading to reduce memory usage and replace custom…
AdityaPandeyCN Feb 19, 2026
a329289
revert Project.toml dependency
AdityaPandeyCN Feb 19, 2026
e704a38
revert comments
AdityaPandeyCN Feb 19, 2026
20c9b74
revert y_pred to p
AdityaPandeyCN Feb 19, 2026
7082adb
operate callback directly on active object
AdityaPandeyCN Feb 20, 2026
dc129e4
cleanup callback.jl
AdityaPandeyCN Feb 20, 2026
8abf99b
handle MLogLoss with Lux.CrossEntropyLoss
AdityaPandeyCN Feb 20, 2026
6ed2364
use train state for metrics
AdityaPandeyCN Feb 20, 2026
f72db19
fix data loader
AdityaPandeyCN Feb 20, 2026
9fc003f
fix fit function arguments
AdityaPandeyCN Feb 20, 2026
52a662d
fix dimension flow
AdityaPandeyCN Feb 20, 2026
162094d
file changes cleanup
AdityaPandeyCN Feb 20, 2026
916de8b
fix array typing
AdityaPandeyCN Feb 20, 2026
b749fa8
sync
jeremiedb Feb 20, 2026
62f0a00
sync
jeremiedb Feb 20, 2026
2385451
up
jeremiedb Feb 20, 2026
13c106d
up
jeremiedb Feb 20, 2026
9ab274d
up
jeremiedb Feb 20, 2026
ea2b5de
apply partial=false and fix infer
AdityaPandeyCN Feb 21, 2026
7461b0f
revert project.toml
AdityaPandeyCN Feb 21, 2026
e0d7da7
up
jeremiedb Feb 22, 2026
685375d
Jdb/grant review (#24)
jeremiedb Feb 22, 2026
f3ea6f0
Merge branch 'main' into jdb/neuro-v3
jeremiedb Feb 22, 2026
efdc333
Merge branch 'jdb/neuro-v3' into lux-new
jeremiedb Feb 22, 2026
309408f
up
jeremiedb Feb 22, 2026
efded44
Merge branch 'aditya/lux' of github.com:AdityaPandeyCN/NeuroTabModels…
jeremiedb Feb 22, 2026
491e68d
cleanup
jeremiedb Feb 22, 2026
873eb2d
up
jeremiedb Feb 26, 2026
8784a1f
up
jeremiedb Feb 26, 2026
7db9d34
Fix inference and crash from use of deval (#27)
AdityaPandeyCN Feb 27, 2026
69641bc
up
jeremiedb Feb 28, 2026
21 changes: 9 additions & 12 deletions Project.toml
@@ -1,43 +1,40 @@
name = "NeuroTabModels"
uuid = "f03403ce-56d7-46f9-9b5e-ff6add8ca7b3"
authors = ["jeremie.db <jeremie.db@evovest.com>"]
version = "0.3.0"
version = "0.4.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
CUDA = "5"
CategoricalArrays = "1"
ChainRulesCore = "1"
DataFrames = "1.3"
Flux = "0.16"
Functors = "0.5"
Lux = "1"
MLJModelInterface = "1.2.1"
MLUtils = "0.4"
NNlib = "0.9"
Optimisers = "0.4"
Random = "1"
Statistics = "1"
StatsBase = "0.34"
Tables = "1.9"
cuDNN = "1"
julia = "1.10"

[extras]
# MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
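The dependency swap above (CUDA, cuDNN, Flux, Functors, and ChainRulesCore on the way out; Enzyme, Lux, LuxCore, and Reactant on the way in) moves the package onto the Lux 1.x stack. A minimal smoke test of that stack (hypothetical, not part of this diff) could look like:

using Lux, Reactant, Random

rng = Random.Xoshiro(0)
model = Dense(4 => 1)                    # any Lux layer stands in for NeuroTree here
dev = reactant_device()                  # falls back to CPU if no accelerator is available
ps, st = Lux.setup(rng, model) |> dev
x = rand(rng, Float32, 4, 8) |> dev
y, _ = @jit Lux.apply(model, x, ps, st)  # Reactant traces and compiles on this call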
10 changes: 5 additions & 5 deletions benchmarks/YEAR-regression.jl
@@ -65,13 +65,13 @@ arch = NeuroTabModels.NeuroTreeConfig(;
# MLE_tree_split=false
# )

device = :gpu
device = :cpu
loss = :mse # :mse :gaussian_mle :tweedie

learner = NeuroTabRegressor(
arch;
loss,
nrounds=200,
nrounds=20,
early_stopping_rounds=2,
lr=1e-3,
batchsize=1024,
@@ -81,18 +81,18 @@ learner = NeuroTabRegressor(
m = NeuroTabModels.fit(
learner,
dtrain;
deval,
# deval,
target_name,
feature_names,
print_every_n=5,
)

p_eval = m(deval; device);
p_eval = m(deval; device=:cpu);
p_eval = p_eval[:, 1]
mse_eval = mean((p_eval .- deval.y_norm) .^ 2)
@info "MSE - deval" mse_eval

p_test = m(dtest; device);
p_test = m(dtest; device=:cpu);
p_test = p_test[:, 1]
mse_test = mean((p_test .- dtest.y_norm) .^ 2) * std(df_tot.y_raw)^2
@info "MSE - dtest" mse_test
15 changes: 8 additions & 7 deletions benchmarks/benchmark_mse.jl
@@ -42,18 +42,19 @@ learner = NeuroTabRegressor(
device=:gpu
)

# desktop gpu - no-eval: 11.239302 seconds (30.63 M allocations: 6.101 GiB, 3.60% gc time)
# desktop gpu - eval: 15.708276 seconds (23.57 k allocations: 13.156 GiB, 2.02% gc time)
# Reactant GPU: 5.970480 seconds (2.33 M allocations: 5.242 GiB, 3.80% gc time, 0.00% compilation time)
# Zygote GPU: 9.855853 seconds (27.92 M allocations: 6.005 GiB, 3.58% gc time)
# 13.557744 seconds (26.40 M allocations: 5.989 GiB, 9.60% gc time)
@time m = NeuroTabModels.fit(
learner,
dtrain;
# deval=dtrain,
# deval=dtrain, # FIXME: very slow when deval is used / crashes on GPU
target_name,
feature_names,
print_every_n=10,
print_every_n=2,
);

# desktop gpu: 0.947362 seconds (484.31 k allocations: 1.526 GiB, 19.83% gc time)
# desktop cpu: 15.708276 seconds (23.57 k allocations: 13.156 GiB, 2.02% gc time)
@time p_train = m(dtrain; device=:cpu);
# Reactant CPU: 0.952495 seconds (57.96 k allocations: 1.517 GiB, 0.23% gc time, 0.00% compilation time)
# Reactant CPU: 10.326071 seconds (29.30 k allocations: 13.145 GiB, 1.97% gc time)
# FIXME: need to adapt infer: returns only full batches: length of p_train must be == nrow(dtrain)
@time p_train = m(dtrain; device=:gpu);
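The partial-batch FIXME above comes from fixed compile-time shapes: with Reactant, every batch must match the compiled shape, so a trailing incomplete batch is dropped when the loader is built with `partial=false`. A hedged sketch of the difference, assuming the loaders wrap `MLUtils.DataLoader`:

using MLUtils

x = rand(Float32, 10, 1000)                          # 1000 observations, batchsize 128
drop = DataLoader(x; batchsize=128, partial=false)   # 7 batches, 896 observations
keep = DataLoader(x; batchsize=128, partial=true)    # 8 batches, 1000 observations
sum(size(b, 2) for b in drop), sum(size(b, 2) for b in keep)  # (896, 1000)

For inference, either `partial=true` or padding the last batch to the compiled shape and trimming the output would restore `length(p_train) == nrow(dtrain)`.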
29 changes: 5 additions & 24 deletions benchmarks/titanic-logloss.jl
@@ -41,48 +41,29 @@ arch = NeuroTabModels.NeuroTreeConfig(;
stack_size=1,
hidden_size=1,
actA=:identity,
MLE_tree_split=false,
)
# arch = NeuroTabModels.MLPConfig(;
# act=:relu,
# stack_size=1,
# hidden_size=64,
# )

learner = NeuroTabRegressor(
arch;
loss=:logloss,
nrounds=200,
loss=:logloss, # FIXME: gaussian_mle doesn't train
nrounds=100,
early_stopping_rounds=2,
lr=3e-2,
device=:cpu
)

# learner = NeuroTabRegressor(;
# arch_name="NeuroTreeConfig",
# arch_config=Dict(
# :actA => :identity,
# :init_scale => 1.0,
# :depth => 4,
# :ntrees => 32,
# :stack_size => 1,
# :hidden_size => 1),
# loss=:logloss,
# nrounds=400,
# early_stopping_rounds=2,
# lr=1e-2,
# )

@time m = NeuroTabModels.fit(
learner,
dtrain;
deval,
# deval, # FIXME: significant slowdown when deval is used
target_name,
feature_names,
print_every_n=10,
);

# commented out - inference not yet adapted
p_train = m(dtrain)
p_eval = m(deval)

@info mean((p_train .> 0.5) .== (dtrain[!, target_name] .> 0.5))
@info mean((p_eval .> 0.5) .== (deval[!, target_name] .> 0.5))
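Accuracy at a 0.5 threshold is checked inline above; the logloss itself can be spot-checked the same way. A minimal sketch, assuming `p_eval` holds probabilities (the `vec` guards against an n×1 output) and the target column is 0/1:

using Statistics

# binary cross-entropy with an eps guard, matching the :logloss objective in spirit
logloss(p, y; eps=1f-7) = -mean(@. y * log(p + eps) + (1 - y) * log(1 - p + eps))
logloss(vec(p_eval), deval[!, target_name])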
27 changes: 20 additions & 7 deletions experiments/dataloader.jl
@@ -1,6 +1,7 @@
using NeuroTabModels
using DataFrames
using CategoricalArrays
using Lux

#################################
# vanilla DataFrame
@@ -9,7 +10,8 @@ nobs = 100
nfeats = 10
x = rand(nobs, nfeats);
df = DataFrame(x, :auto);
df.y = rand(nobs);
y = rand(nobs);
df.y = y;

target_name = "y"
feature_names = Symbol.(setdiff(names(df), [target_name]))
@@ -19,24 +21,35 @@ batchsize = 32
# CPU
###################################
device = :cpu
dtrain = NeuroTabModels.get_df_loader_train(df; feature_names, target_name, batchsize, device)

dtrain = NeuroTabModels.Data.get_df_loader_train(df; feature_names, target_name, batchsize, device)
for d in dtrain
@info length(d)
@info size(d[1])
@info size(d[1]), size(d[2])
end

deval = NeuroTabModels.get_df_loader_infer(df; feature_names, batchsize=32)
deval = NeuroTabModels.Data.get_df_loader_infer(df; feature_names, batchsize=32)
for d in deval
@info size(d)
end

###################################
# LuxDevice
###################################
# dev = reactant_device()
dev = cpu_device()
# dev = gpu_device()
dtrain = NeuroTabModels.Data.get_df_loader_train(df; feature_names, target_name, batchsize) |> dev
for d in dtrain
@info length(d)
@info size(d[1])
@info typeof(d[1])
end

###################################
# GPU
###################################
device = :gpu
dtrain = NeuroTabModels.get_df_loader_train(df; feature_names, target_name, batchsize, device)

dtrain = NeuroTabModels.Data.get_df_loader_train(df; feature_names, target_name, batchsize, device)
for d in dtrain
@info length(d)
@info size(d[1])
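The new `LuxDevice` section in this file pipes the loader itself through a device. A hedged sketch of the behavior it exercises, assuming a plain `MLUtils.DataLoader` (Lux wraps it in a `DeviceIterator` that moves each batch on demand):

using Lux, MLUtils

data = (rand(Float32, 10, 100), rand(Float32, 1, 100))
loader = DataLoader(data; batchsize=32) |> cpu_device()
for (xb, yb) in loader
    @info typeof(xb) size(xb) size(yb)  # batches come out already on the device
end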
65 changes: 65 additions & 0 deletions experiments/lux.jl
@@ -0,0 +1,65 @@
using NeuroTabModels
using Lux, LuxCore
using Random

using NeuroTabModels.Models.NeuroTrees: get_logits_mask, get_softplus_mask
using NNlib: softplus
using Reactant
using Enzyme
using Optimisers

rng = Random.Xoshiro(123)
nobs = 1000
nfeats = 10

# m = NeuroTabModels.Models.NeuroTrees.NeuroTree(nfeats => 1; depth=4, trees=64)
m = Chain(
NeuroTabModels.Models.NeuroTrees.NeuroTree(nfeats => 1; depth=4, trees=64)
)
x = randn(rng, Float32, nfeats, nobs)
y = randn(rng, Float32, 1, nobs)
ps, st = LuxCore.setup(rng, m)
p, st = m(x, ps, st)

# Get the device determined by Lux
Reactant.set_default_backend("gpu")
dev = reactant_device()
# dev = gpu_device()
# dev = cpu_device()

# Parameter and State Variables
ps, st = Lux.setup(rng, m) |> dev

# Dummy input and target (both moved to the device so compute_gradients can trace them)
x = rand(rng, Float32, nfeats, nobs) |> dev
y = y |> dev

# Run the model
## We need to use @jit to compile and run the model with Reactant
@time p, st = @jit Lux.apply(m, x, ps, st)
# @time Lux.apply(m, x, ps, st)

## For best performance, first compile the model with Reactant and then run it
@time apply_compiled = @compile Lux.apply(m, x, ps, st)
@time apply_compiled(m, x, ps, st)

# Gradients
ts = Training.TrainState(m, ps, st, Adam(0.001f0))
gs, loss, stats, ts = Lux.Training.compute_gradients(
AutoEnzyme(),
MSELoss(),
(x, y),
ts
)

## Optimization
ts = Training.apply_gradients!(ts, gs) # or Training.apply_gradients (no `!` at the end)

# Both these steps can be combined into a single call (preferred approach)
@time gs, loss, stats, ts = Training.single_train_step!(
AutoEnzyme(),
MSELoss(),
(x, y),
ts
);
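Extending the single training step above to a full loop is mechanical. A minimal sketch under the same setup, wrapped in a function so the train state threads through cleanly (the one-batch iterator is a stand-in for a real `DataLoader`):

function train!(ts, data; epochs=10)
    local loss
    for epoch in 1:epochs
        for (xb, yb) in data
            _, loss, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (xb, yb), ts)
        end
        @info "epoch $epoch" loss
    end
    return ts
end

ts = train!(ts, [(x, y)]; epochs=10)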