1+ ###############################################
2+ #
3+ # This file severs as an example of how to
4+ # train a MLP model using the SDD dataset using
5+ # the PAL library and a spline-distribution.
6+ #
7+ ###############################################
8+
9+ import pal .problem .sdd as csdd
10+ import pal .distribution .spline_distribution as spline
11+ from pal .wmi .compute_integral import integrate_distribution
12+ import torch
13+ import torch .nn as nn
14+ from typing import Callable , Any , Generic , TypeVar
15+ import numpy as np
16+ from torch .utils .data import DataLoader , TensorDataset
17+ from tqdm import tqdm
18+ import argparse
19+
20+ T = TypeVar ("T" )
21+
22+
23+ class SimpleFC (Generic [T ], nn .Module ):
24+ def __init__ (
25+ self ,
26+ input_size : int ,
27+ output_size : int ,
28+ hidden_sizes : list [int ],
29+ final_function : Callable [[torch .Tensor ], tuple [torch .Tensor , ...]] | None = None ,
30+ final_module : nn .Module | None = None ,
31+ ) -> None :
32+ super ().__init__ ()
33+ self .fcs = []
34+ for i in range (len (hidden_sizes )):
35+ if i == 0 :
36+ self .fcs .append (nn .Linear (input_size , hidden_sizes [i ]))
37+ else :
38+ self .fcs .append (nn .Linear (hidden_sizes [i - 1 ], hidden_sizes [i ]))
39+ self .fcs .append (nn .Linear (hidden_sizes [- 1 ], output_size ))
40+ self .fcs = nn .ModuleList (self .fcs )
41+ self .final_function = final_function
42+ self .final_module = final_module
43+
44+ def network (self , x : torch .Tensor ) -> torch .Tensor :
45+ for i in range (len (self .fcs ) - 1 ):
46+ x = self .fcs [i ](x )
47+ x = nn .functional .relu (x )
48+ x = self .fcs [- 1 ](x )
49+ return x
50+
51+ def forward (self , x ) -> T :
52+ x = self .network (x )
53+ if self .final_function is not None :
54+ x = self .final_function (x )
55+ if self .final_module is not None :
56+ x = self .final_module (* x )
57+ return x
58+
59+ def __call__ (self , * args , ** kwds ) -> T :
60+ return super ().__call__ (* args , ** kwds )
61+
62+
63+ def main (args : argparse .Namespace ) -> None :
64+ sdd = csdd .SDDSingleImageTrajectory (
65+ img_id = args .img_id ,
66+ path = "./data/sdd" ,
67+ )
68+
69+ # load the constraints
70+ lra_problem = sdd .create_constraints ()
71+
72+ spline_distribution_builder = spline .SplineSQ2DBuilder (
73+ constraints = lra_problem ,
74+ var_positions = sdd .get_y_vars (),
75+ num_knots = args .num_knots ,
76+ num_mixtures = args .num_mixtures ,
77+ )
78+
79+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
80+
81+ # create the distribution
82+ conditional_spline_dist = integrate_distribution (
83+ d = spline_distribution_builder ,
84+ device = device ,
85+ precision = torch .float64 ,
86+ )
87+
88+ shape_value , shape_derivative , shape_mixture_weights = (
89+ conditional_spline_dist .parameter_shape ()
90+ )
91+
92+ # create the model
93+ net_size = args .net_size
94+ if net_size == "small" :
95+ hidden_size = [512 , 512 ]
96+ elif net_size == "medium" :
97+ hidden_size = [1024 , 1024 ]
98+ elif net_size == "large" :
99+ hidden_size = [2048 , 2048 ]
100+
101+ input_size = np .prod (sdd .get_x_shape ())
102+
103+ total_output_size = (
104+ np .prod (shape_value )
105+ + np .prod (shape_derivative )
106+ + np .prod (shape_mixture_weights )
107+ )
108+
109+ def reparam (out_nn : torch .Tensor ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
110+ param_mixture_weights = out_nn [:, :num_mixture_param ].softmax (dim = - 1 )
111+
112+ param_dens_value = out_nn [:, num_mixture_param :(num_mixture_param + num_dens_knots_values )]
113+ param_dens_value = param_dens_value .reshape (- 1 , * shape_value )
114+
115+ param_dens_derivative = out_nn [:, (num_mixture_param + num_dens_knots_values ):]
116+ param_dens_derivative = param_deriv_dens_scale * param_dens_derivative .reshape (- 1 , * shape_derivative )
117+
118+ return param_dens_value , param_dens_derivative , param_mixture_weights
119+
120+ model = SimpleFC [spline .SplineSQ2D ](
121+ input_size = input_size ,
122+ output_size = total_output_size ,
123+ hidden_sizes = hidden_size ,
124+ final_function = reparam ,
125+ final_module = conditional_spline_dist ,
126+ ).to (device )
127+
128+ # create the optimizer
129+ optimizer = torch .optim .Adam (model .parameters (), lr = args .lr )
130+
131+ print ("Loading dataset" )
132+ dataset = sdd .load_dataset ()
133+ dataset_train = dataset .train
134+ dataset_val = dataset .val
135+ dataset_test = dataset .test
136+ print (
137+ f"Train size: { len (dataset_train )} , Val size: { len (dataset_val )} , Test size: { len (dataset_test )} "
138+ )
139+
140+ # check if random seed is set
141+ if args .seed is not None :
142+ torch .manual_seed (args .seed )
143+ np .random .seed (args .seed )
144+
145+ model .to (device )
146+ conditional_spline_dist .to (device )
147+ precision = torch .float64 if args .use_float64 else torch .float32
148+ model .to (precision )
149+
150+ batch_size = args .batch_size
151+
152+ loader = DataLoader (
153+ dataset_train ,
154+ batch_size = batch_size ,
155+ shuffle = True ,
156+ pin_memory = False ,
157+ num_workers = 10 ,
158+ )
159+
160+ loader_val = DataLoader (dataset_val , batch_size = batch_size )
161+
162+ loader_test = DataLoader (dataset_test , batch_size = batch_size )
163+
164+ if args .init_last_layer_positive :
165+ num_mixture_param = np .prod (shape_mixture_weights )
166+ num_dens_knots_values = np .prod (shape_value )
167+
168+ with torch .no_grad ():
169+ last_layer = model .fcs [- 1 ]
170+ pos_sub = 0.1 * torch .abs (
171+ last_layer .weight .data [
172+ num_mixture_param :(num_mixture_param + num_dens_knots_values )
173+ ]
174+ )
175+ last_layer .weight .data [
176+ num_mixture_param :(num_mixture_param + num_dens_knots_values )
177+ ] = pos_sub
178+ last_layer .bias .data = torch .zeros_like (last_layer .bias .data )
179+
180+ epochs = args .epochs
181+
182+ param_deriv_dens_scale = 0.1
183+
184+ for epoch in tqdm (range (epochs ), desc = "Epochs" ):
185+ model .train ()
186+ for i , (x , y ) in enumerate (tqdm (loader , desc = "Training" , leave = False )):
187+ x = x .to (device ).to (precision )
188+ y = y .to (device ).to (precision )
189+
190+ # forward pass
191+ log_dens = model (x ).log_dens (y )
192+
193+ loss = - log_dens .mean ()
194+ # backward pass
195+ optimizer .zero_grad ()
196+ loss .backward ()
197+ optimizer .step ()
198+
199+ # validation
200+ model .eval ()
201+ with torch .no_grad ():
202+ val_ll = []
203+ for i , (x , y ) in enumerate (tqdm (loader_val , desc = "Validation" , leave = False )):
204+ x = x .to (device ).to (precision )
205+ y = y .to (device ).to (precision )
206+
207+ log_dens = model (x ).log_dens (y )
208+ val_ll .append (log_dens .to ("cpu" ))
209+ val_ll = torch .cat (val_ll ).mean ()
210+ print (f"Epoch { epoch } : Validation log-likelihood: { val_ll :.4f} " )
211+
212+ # test
213+ with torch .no_grad ():
214+ test_ll = []
215+ for i , (x , y ) in enumerate (tqdm (loader_test , desc = "Test" , leave = False )):
216+ x = x .to (device ).to (precision )
217+ y = y .to (device ).to (precision )
218+
219+ log_dens = model (x ).log_dens (y )
220+ test_ll .append (log_dens .to ("cpu" ))
221+ test_ll = torch .cat (test_ll ).mean ()
222+ print (f"Test log-likelihood: { test_ll :.4f} " )
223+
224+
225+ def args ():
226+ parser = argparse .ArgumentParser (description = "Train a MLP model using the SDD dataset" )
227+ parser .add_argument (
228+ "--img_id" ,
229+ type = int ,
230+ default = 12 ,
231+ help = "Image ID to use for the SDD dataset" ,
232+ )
233+ parser .add_argument (
234+ "--num_knots" ,
235+ type = int ,
236+ default = 10 ,
237+ help = "Number of knots to use for the spline distribution" ,
238+ )
239+ parser .add_argument (
240+ "--num_mixtures" ,
241+ type = int ,
242+ default = 5 ,
243+ help = "Number of mixtures to use for the spline distribution" ,
244+ )
245+ parser .add_argument (
246+ "--net_size" ,
247+ type = str ,
248+ choices = ["small" , "medium" , "large" ],
249+ default = "medium" ,
250+ help = "Size of the neural network" ,
251+ )
252+ parser .add_argument (
253+ "--batch_size" ,
254+ type = int ,
255+ default = 64 ,
256+ help = "Batch size for training" ,
257+ )
258+ parser .add_argument (
259+ "--lr" ,
260+ type = float ,
261+ default = 1e-3 ,
262+ help = "Learning rate for the optimizer" ,
263+ )
264+ parser .add_argument (
265+ "--epochs" ,
266+ type = int ,
267+ default = 20 ,
268+ help = "Number of epochs to train for" ,
269+ )
270+ parser .add_argument (
271+ "--seed" ,
272+ type = int ,
273+ default = None ,
274+ help = "Random seed for reproducibility" ,
275+ )
276+ parser .add_argument (
277+ "--use_float64" ,
278+ action = "store_true" ,
279+ help = "Use float64 precision instead of float32" ,
280+ )
281+ parser .add_argument (
282+ "--init_last_layer_positive" ,
283+ action = "store_true" ,
284+ help = "Initialize the last layer of the network to be positive" ,
285+ )
286+
287+ return parser .parse_args ()
288+
289+
290+ if __name__ == "__main__" :
291+ args = args ()
292+ main (args )
0 commit comments