Skip to content

Commit 52a4036

Browse files
committed
Fix
1 parent ed9a92a commit 52a4036

7 files changed

Lines changed: 355 additions & 248 deletions

File tree

ext/ExaModelsGenOpt.jl

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,30 @@
1-
module ExaModelsIpopt
1+
module ExaModelsGenOpt
22

33
import ExaModels
44
import GenOpt
5+
import GenOpt: FunctionGenerator, SumGenerator, ContiguousArrayOfVariables, IteratorIndex, Iterator
6+
import MathOptInterface as MOI
57

6-
function copy_generator_constraints!(c, moim, cis, var_to_idx, con_to_idx, T, ::Type{F}, ::Type{S}) where {F<:GenOpt.FunctionGenerator}
7-
cis = MOI.get(moim, MOI.ListOfConstraintIndices{F,S}())
8+
# Mark GenOpt function types as extension types
9+
ExaModels.is_extension_type(::Type{<:FunctionGenerator}) = true
10+
ExaModels.is_extension_type(::Type{<:SumGenerator}) = true
11+
12+
# Handle SumGenerator in objective expressions
13+
function ExaModels.exafy_extension_obj_arg(m::SumGenerator)
14+
return _exagen(m.func, m.iterators)
15+
end
16+
17+
# Hook to process FunctionGenerator constraints after standard constraints
18+
function ExaModels.copy_extra_constraints!(c, moim, var_to_idx, con_to_idx, T)
19+
con_types = MOI.get(moim, MOI.ListOfConstraintTypesPresent())
20+
for (F, S) in con_types
21+
F <: FunctionGenerator || continue
22+
cis = MOI.get(moim, MOI.ListOfConstraintIndices{F,S}())
23+
_copy_generator_constraints!(c, moim, cis, var_to_idx, con_to_idx, T, S)
24+
end
25+
end
26+
27+
function _copy_generator_constraints!(c, moim, cis, var_to_idx, con_to_idx, T, ::Type{S}) where {S}
828
# FIXME we assume that `var_to_idx` is the identity
929
for ci in cis
1030
func = MOI.get(moim, MOI.ConstraintFunction(), ci)
@@ -15,4 +35,59 @@ function copy_generator_constraints!(c, moim, cis, var_to_idx, con_to_idx, T, ::
1535
end
1636
end
1737

38+
# Convert GenOpt expression trees to ExaModels format
39+
40+
exagen::Number, _) = α
41+
42+
function exagen(f::MOI.ScalarNonlinearFunction, offsets)
43+
if f.head == :getindex
44+
v = f.args[1]
45+
if v isa ContiguousArrayOfVariables
46+
idx = exagen(f.args[2], offsets)
47+
if !iszero(v.offset)
48+
idx = v.offset + idx
49+
end
50+
cp = cumprod(v.size)
51+
for i in 3:length(f.args)
52+
idx += cp[i-2] * (exagen(f.args[i], offsets) - 1)
53+
end
54+
return ExaModels.Var(idx)
55+
elseif v isa IteratorIndex
56+
@assert length(f.args) == 2
57+
@assert f.args[2] isa Integer
58+
if isnothing(offsets)
59+
@assert isone(f.args[2])
60+
return ExaModels.ParSource()
61+
else
62+
return ExaModels.ParIndexed(ExaModels.ParSource(), offsets[v.value] + f.args[2])
63+
end
64+
else
65+
error("Unexpected the first operand of `getindex` to be of type `$(typeof(v))`")
66+
end
67+
else
68+
return ExaModels.op(f.head)((exagen(e, offsets) for e in f.args)...)
69+
end
70+
end
71+
72+
function _exagen(func::MOI.ScalarNonlinearFunction, iterators)
73+
lengths = map(it -> length(first(it.values)), iterators)
74+
if length(lengths) == 1 && lengths[] == 1
75+
cs = nothing
76+
pars = only.(iterators[].values)
77+
else
78+
cs = [0; cumsum(lengths)[1:end-1]]
79+
pars = vec(map(Base.Iterators.ProductIterator(ntuple(i -> iterators[i].values, length(iterators)))) do I
80+
reduce((i, j) -> tuple(i..., j...), I)
81+
end)
82+
end
83+
expr = exagen(func, cs)
84+
return expr, pars
1885
end
86+
87+
# Bound helpers for vector sets used by FunctionGenerator constraints
88+
_lower_bounds(::Union{MOI.Zeros,MOI.Nonnegatives}, T) = zero(T)
89+
_lower_bounds(::MOI.Nonpositives, T) = typemin(T)
90+
_upper_bounds(::Union{MOI.Zeros,MOI.Nonpositives}, T) = zero(T)
91+
_upper_bounds(::MOI.Nonnegatives, T) = typemax(T)
92+
93+
end # module

0 commit comments

Comments
 (0)