Skip to content

Commit c7b3fc3

Browse files
committed
progress towards working example
1 parent fc72232 commit c7b3fc3

5 files changed

Lines changed: 321 additions & 24 deletions

File tree

pal/distribution/spline_distribution.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,15 @@ def compute_reordering_of_parameter_positions2d(
2121
) -> Dict[int, int]:
2222
"""
2323
Compute the reordering of the parameter positions for the polynomial so that
24-
it conforms to the combination.
24+
it conforms to the combination of univariate polynomials.
25+
26+
Parameters:
27+
- deg (int): The degree of the polynomial.
28+
- powers (torch.Tensor): A tensor of shape [n_mon, 2] containing the powers of the polynomial terms.
29+
Each row represents a monomial with two variables.
30+
31+
Returns:
32+
- Dict[int, int]: A mapping from the current parameter index to the reordered parameter index.
2533
"""
2634
exponents_y0 = torch.arange(deg + 1)
2735
exponents_y1 = torch.arange(deg + 1)
@@ -478,7 +486,7 @@ def get_info_1d(
478486

479487
return get_info_2d(self.knots, x)
480488

481-
def log_dens(self, x, eps=-1, with_indicator=False):
489+
def log_dens(self, x, eps=-1, with_indicator=False) -> torch.Tensor:
482490
"""
483491
Essentially computes log p(x) for the distribution, where p(x) is
484492
composed of:

pal/distribution/torch_polynomial.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class TorchPolynomial(torch.nn.Module):
4444
using PyTorch tensors.
4545
4646
Args:
47-
coeffs (torch.Tensor): Tensor containing the coefficients of the polynomial.
48-
powers (torch.Tensor): Tensor containing the powers of the polynomial.
47+
coeffs (torch.Tensor): Tensor containing the coefficients of the polynomial. Shape: (num_terms,).
48+
powers (torch.Tensor): Tensor containing the powers of the polynomial. Shape: (num_terms, num_vars).
4949
variable_map_dict (Dict[str, int] | frozendict[str, int]): Dictionary mapping variable names to their indices.
5050
absolute (bool): Whether to take the absolute value of the polynomial. Defaults to True.
5151
"""
@@ -529,6 +529,16 @@ def calc_2d_spline_component(
529529
knot_idx: torch.Tensor, # (2)
530530
params: torch.Tensor, # (2, num_pieces, 4)
531531
) -> torch.Tensor:
532+
"""
533+
Calculate the 2D spline component for the given x and knot_idx.
534+
The spline is the product of two cubic polynomials, one for each dimension.
535+
Args:
536+
x: The input tensor.
537+
knot_idx: The indices of the knots.
538+
params: The parameters of the spline.
539+
Returns:
540+
The calculated spline component.
541+
"""
532542
def per_param(
533543
p: torch.Tensor, # (2, num_pieces)
534544
) -> torch.Tensor:
@@ -576,7 +586,7 @@ def calc_log_mixture(
576586
lambda p: calc_2d_spline_component(x, knot_idx, p)
577587
)(
578588
m_params
579-
)#.abs_() # (num_mixtures)
589+
) # .abs_() # (num_mixtures)
580590
m_normalization_coeff = m_normalization.sum(dim=-1).sum(
581591
dim=-1
582592
) # (num_mixtures)
@@ -586,17 +596,4 @@ def calc_log_mixture(
586596
torch.log(densities_components) + torch.log(m_weights), dim=0
587597
)
588598

589-
# if eps != -1:
590-
# with torch.no_grad():
591-
# # mixture_poly[mixture_poly < eps] = eps
592-
# mixture_poly.clamp_(min=eps)
593-
594-
# mixture_poly_log = 2 * torch.log(mixture_poly) - torch.log(
595-
# m_normalization_coeff
596-
# )
597-
598-
# mixture_poly_log = torch.log(mixture_poly ** 2)
599-
600-
# poly_log = torch.logsumexp(mixture_poly_log + torch.log(m_weights), dim=0)
601-
602599
return poly_log

pal/training/train_mlp_sdd.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
###############################################
2+
#
3+
# This file severs as an example of how to
4+
# train a MLP model using the SDD dataset using
5+
# the PAL library and a spline-distribution.
6+
#
7+
###############################################
8+
9+
import pal.problem.sdd as csdd
10+
import pal.distribution.spline_distribution as spline
11+
from pal.wmi.compute_integral import integrate_distribution
12+
import torch
13+
import torch.nn as nn
14+
from typing import Callable, Any, Generic, TypeVar
15+
import numpy as np
16+
from torch.utils.data import DataLoader, TensorDataset
17+
from tqdm import tqdm
18+
import argparse
19+
20+
T = TypeVar("T")
21+
22+
23+
class SimpleFC(Generic[T], nn.Module):
24+
def __init__(
25+
self,
26+
input_size: int,
27+
output_size: int,
28+
hidden_sizes: list[int],
29+
final_function: Callable[[torch.Tensor], tuple[torch.Tensor, ...]] | None = None,
30+
final_module: nn.Module | None = None,
31+
) -> None:
32+
super().__init__()
33+
self.fcs = []
34+
for i in range(len(hidden_sizes)):
35+
if i == 0:
36+
self.fcs.append(nn.Linear(input_size, hidden_sizes[i]))
37+
else:
38+
self.fcs.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
39+
self.fcs.append(nn.Linear(hidden_sizes[-1], output_size))
40+
self.fcs = nn.ModuleList(self.fcs)
41+
self.final_function = final_function
42+
self.final_module = final_module
43+
44+
def network(self, x: torch.Tensor) -> torch.Tensor:
45+
for i in range(len(self.fcs) - 1):
46+
x = self.fcs[i](x)
47+
x = nn.functional.relu(x)
48+
x = self.fcs[-1](x)
49+
return x
50+
51+
def forward(self, x) -> T:
52+
x = self.network(x)
53+
if self.final_function is not None:
54+
x = self.final_function(x)
55+
if self.final_module is not None:
56+
x = self.final_module(*x)
57+
return x
58+
59+
def __call__(self, *args, **kwds) -> T:
60+
return super().__call__(*args, **kwds)
61+
62+
63+
def main(args: argparse.Namespace) -> None:
64+
sdd = csdd.SDDSingleImageTrajectory(
65+
img_id=args.img_id,
66+
path="./data/sdd",
67+
)
68+
69+
# load the constraints
70+
lra_problem = sdd.create_constraints()
71+
72+
spline_distribution_builder = spline.SplineSQ2DBuilder(
73+
constraints=lra_problem,
74+
var_positions=sdd.get_y_vars(),
75+
num_knots=args.num_knots,
76+
num_mixtures=args.num_mixtures,
77+
)
78+
79+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80+
81+
# create the distribution
82+
conditional_spline_dist = integrate_distribution(
83+
d=spline_distribution_builder,
84+
device=device,
85+
precision=torch.float64,
86+
)
87+
88+
shape_value, shape_derivative, shape_mixture_weights = (
89+
conditional_spline_dist.parameter_shape()
90+
)
91+
92+
# create the model
93+
net_size = args.net_size
94+
if net_size == "small":
95+
hidden_size = [512, 512]
96+
elif net_size == "medium":
97+
hidden_size = [1024, 1024]
98+
elif net_size == "large":
99+
hidden_size = [2048, 2048]
100+
101+
input_size = np.prod(sdd.get_x_shape())
102+
103+
total_output_size = (
104+
np.prod(shape_value)
105+
+ np.prod(shape_derivative)
106+
+ np.prod(shape_mixture_weights)
107+
)
108+
109+
def reparam(out_nn: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
110+
param_mixture_weights = out_nn[:, :num_mixture_param].softmax(dim=-1)
111+
112+
param_dens_value = out_nn[:, num_mixture_param:(num_mixture_param + num_dens_knots_values)]
113+
param_dens_value = param_dens_value.reshape(-1, *shape_value)
114+
115+
param_dens_derivative = out_nn[:, (num_mixture_param + num_dens_knots_values):]
116+
param_dens_derivative = param_deriv_dens_scale * param_dens_derivative.reshape(-1, *shape_derivative)
117+
118+
return param_dens_value, param_dens_derivative, param_mixture_weights
119+
120+
model = SimpleFC[spline.SplineSQ2D](
121+
input_size=input_size,
122+
output_size=total_output_size,
123+
hidden_sizes=hidden_size,
124+
final_function=reparam,
125+
final_module=conditional_spline_dist,
126+
).to(device)
127+
128+
# create the optimizer
129+
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
130+
131+
print("Loading dataset")
132+
dataset = sdd.load_dataset()
133+
dataset_train = dataset.train
134+
dataset_val = dataset.val
135+
dataset_test = dataset.test
136+
print(
137+
f"Train size: {len(dataset_train)}, Val size: {len(dataset_val)}, Test size: {len(dataset_test)}"
138+
)
139+
140+
# check if random seed is set
141+
if args.seed is not None:
142+
torch.manual_seed(args.seed)
143+
np.random.seed(args.seed)
144+
145+
model.to(device)
146+
conditional_spline_dist.to(device)
147+
precision = torch.float64 if args.use_float64 else torch.float32
148+
model.to(precision)
149+
150+
batch_size = args.batch_size
151+
152+
loader = DataLoader(
153+
dataset_train,
154+
batch_size=batch_size,
155+
shuffle=True,
156+
pin_memory=False,
157+
num_workers=10,
158+
)
159+
160+
loader_val = DataLoader(dataset_val, batch_size=batch_size)
161+
162+
loader_test = DataLoader(dataset_test, batch_size=batch_size)
163+
164+
if args.init_last_layer_positive:
165+
num_mixture_param = np.prod(shape_mixture_weights)
166+
num_dens_knots_values = np.prod(shape_value)
167+
168+
with torch.no_grad():
169+
last_layer = model.fcs[-1]
170+
pos_sub = 0.1 * torch.abs(
171+
last_layer.weight.data[
172+
num_mixture_param:(num_mixture_param + num_dens_knots_values)
173+
]
174+
)
175+
last_layer.weight.data[
176+
num_mixture_param:(num_mixture_param + num_dens_knots_values)
177+
] = pos_sub
178+
last_layer.bias.data = torch.zeros_like(last_layer.bias.data)
179+
180+
epochs = args.epochs
181+
182+
param_deriv_dens_scale = 0.1
183+
184+
for epoch in tqdm(range(epochs), desc="Epochs"):
185+
model.train()
186+
for i, (x, y) in enumerate(tqdm(loader, desc="Training", leave=False)):
187+
x = x.to(device).to(precision)
188+
y = y.to(device).to(precision)
189+
190+
# forward pass
191+
log_dens = model(x).log_dens(y)
192+
193+
loss = - log_dens.mean()
194+
# backward pass
195+
optimizer.zero_grad()
196+
loss.backward()
197+
optimizer.step()
198+
199+
# validation
200+
model.eval()
201+
with torch.no_grad():
202+
val_ll = []
203+
for i, (x, y) in enumerate(tqdm(loader_val, desc="Validation", leave=False)):
204+
x = x.to(device).to(precision)
205+
y = y.to(device).to(precision)
206+
207+
log_dens = model(x).log_dens(y)
208+
val_ll.append(log_dens.to("cpu"))
209+
val_ll = torch.cat(val_ll).mean()
210+
print(f"Epoch {epoch}: Validation log-likelihood: {val_ll:.4f}")
211+
212+
# test
213+
with torch.no_grad():
214+
test_ll = []
215+
for i, (x, y) in enumerate(tqdm(loader_test, desc="Test", leave=False)):
216+
x = x.to(device).to(precision)
217+
y = y.to(device).to(precision)
218+
219+
log_dens = model(x).log_dens(y)
220+
test_ll.append(log_dens.to("cpu"))
221+
test_ll = torch.cat(test_ll).mean()
222+
print(f"Test log-likelihood: {test_ll:.4f}")
223+
224+
225+
def args():
226+
parser = argparse.ArgumentParser(description="Train a MLP model using the SDD dataset")
227+
parser.add_argument(
228+
"--img_id",
229+
type=int,
230+
default=12,
231+
help="Image ID to use for the SDD dataset",
232+
)
233+
parser.add_argument(
234+
"--num_knots",
235+
type=int,
236+
default=10,
237+
help="Number of knots to use for the spline distribution",
238+
)
239+
parser.add_argument(
240+
"--num_mixtures",
241+
type=int,
242+
default=5,
243+
help="Number of mixtures to use for the spline distribution",
244+
)
245+
parser.add_argument(
246+
"--net_size",
247+
type=str,
248+
choices=["small", "medium", "large"],
249+
default="medium",
250+
help="Size of the neural network",
251+
)
252+
parser.add_argument(
253+
"--batch_size",
254+
type=int,
255+
default=64,
256+
help="Batch size for training",
257+
)
258+
parser.add_argument(
259+
"--lr",
260+
type=float,
261+
default=1e-3,
262+
help="Learning rate for the optimizer",
263+
)
264+
parser.add_argument(
265+
"--epochs",
266+
type=int,
267+
default=20,
268+
help="Number of epochs to train for",
269+
)
270+
parser.add_argument(
271+
"--seed",
272+
type=int,
273+
default=None,
274+
help="Random seed for reproducibility",
275+
)
276+
parser.add_argument(
277+
"--use_float64",
278+
action="store_true",
279+
help="Use float64 precision instead of float32",
280+
)
281+
parser.add_argument(
282+
"--init_last_layer_positive",
283+
action="store_true",
284+
help="Initialize the last layer of the network to be positive",
285+
)
286+
287+
return parser.parse_args()
288+
289+
290+
if __name__ == "__main__":
291+
args = args()
292+
main(args)

pal/wmi/gasp/gasp/torch/wmipa/numerical_symb_integrator_pa.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def __init__(
4545
mode: IntegratorModes, # whether to integrate over the weights or the function
4646
total_degree: int, # total degree of the polynomial, so max over sum of exponents for each monomial.
4747
variable_map: dict[str, int], # name => index
48-
sum_seperately=False,
49-
with_sorting=False,
48+
sum_seperately=True,
49+
with_sorting=True,
5050
batch_size=None,
5151
monomials_lower_precision=True,
5252
n_workers=7,
@@ -58,8 +58,8 @@ def __init__(
5858
total_degree (int): The total degree of the polynomial, so max over sum of exponents for each monomial.
5959
Can be an estimate for WeightedFormulaMode.
6060
variable_map (dict[str, int]): A mapping from variable names to indices.
61-
sum_seperately (bool, optional): Whether to sum the results separately. Defaults to False.
62-
with_sorting (bool, optional): Whether to sort the results. Defaults to False.
61+
sum_seperately (bool, optional): Whether to sum the results separately. Defaults to True.
62+
with_sorting (bool, optional): Whether to sort the results. Defaults to True.
6363
batch_size (int, optional): The batch size for integration. Defaults to None.
6464
monomials_lower_precision (bool, optional): Whether to use lower precision for monomials. Defaults to True.
6565
n_workers (int, optional): The number of workers for parallel processing. Defaults to 7.

0 commit comments

Comments
 (0)