Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 66 additions & 15 deletions quantecon/_lqnash.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def nnash(A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2,
beta=1.0, tol=1e-8, max_iter=1000):
r"""
"""
Compute the limit of a Nash linear quadratic dynamic game. In this
problem, player i minimizes

Expand Down Expand Up @@ -79,7 +79,7 @@ def nnash(A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2,
"""
# == Unload parameters and make sure everything is an array == #
params = A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2
params = map(np.asarray, params)
params = tuple(map(np.asarray, params))
A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2 = params

# == Multiply A, B1, B2 by sqrt(beta) to enforce discounting == #
Expand All @@ -99,6 +99,15 @@ def nnash(A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2,
else:
k_2 = B2.shape[1]


# Precompute transposes that are reused
B1T = B1.T
B2T = B2.T
W1T = W1.T
W2T = W2.T
M1T = M1.T
M2T = M2.T

v1 = np.eye(k_1)
v2 = np.eye(k_2)
P1 = np.zeros((n, n))
Expand All @@ -111,28 +120,70 @@ def nnash(A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2,
F10 = F1
F20 = F2

G2 = solve((B2.T @ P2 @ B2)+Q2, v2)
G1 = solve((B1.T @ P1 @ B1)+Q1, v1)
H2 = G2 @ B2.T @ P2
H1 = G1 @ B1.T @ P1
# Solve small k x k systems for G1, G2 (inverse-like)
S2mat = (B2T @ P2 @ B2) + Q2
S1mat = (B1T @ P1 @ B1) + Q1
G2 = solve(S2mat, v2)
G1 = solve(S1mat, v1)

# Compute H terms with a cheaper multiplication order
# H1 = G1 @ (B1.T @ P1)
B1T_P1 = B1T @ P1
B2T_P2 = B2T @ P2
H1 = G1 @ B1T_P1
H2 = G2 @ B2T_P2

# Reusable intermediate products
H1_B2 = H1 @ B2
H2_B1 = H2 @ B1
G1_M1T = G1.dot(M1T)
G2_M2T = G2.dot(M2T)
H1_B2_plus = H1_B2 + G1_M1T
H2_B1_plus = H2_B1 + G2_M2T

H1_A = H1 @ A
H2_A = H2 @ A
G1_W1T = G1.dot(W1T)
G2_W2T = G2.dot(W2T)
H2_A_plus = H2_A + G2_W2T
H1_A_plus = H1_A + G1_W1T

# break up the computation of F1, F2
F1_left = v1 - ((H1 @ B2 + G1.dot(M1.T)) @
(H2 @ B1 + G2.dot(M2.T)))
F1_right = H1 @ A + G1.dot(W1.T) - ((H1 @ B2 + G1.dot(M1.T)) @
(H2 @ A + G2.dot(W2.T)))
F1_left = v1 - (H1_B2_plus @ H2_B1_plus)
F1_right = H1_A_plus - (H1_B2_plus @ H2_A_plus)
F1 = solve(F1_left, F1_right)
F2 = H2 @ A + G2.dot(W2.T) - ((H2 @ B1 + G2.dot(M2.T)) @ F1)
F2 = H2_A + G2_W2T - (H2_B1_plus @ F1)

# Update Lambdas

Lambda1 = A - B2 @ F2
Lambda2 = A - B1 @ F1
Pi1 = R1 + (F2.T @ S1.dot(F2))
Pi2 = R2 + (F1.T @ S2.dot(F1))

P1 = (Lambda1.T @ P1 @ Lambda1) + Pi1 - \
((Lambda1.T @ P1 @ B1) + W1 - F2.T.dot(M1)) @ F1
P2 = (Lambda2.T @ P2 @ Lambda2) + Pi2 - \
((Lambda2.T @ P2 @ B2) + W2 - F1.T.dot(M2)) @ F2
# Update P1 using temporaries to avoid repeated computations
Lambda1T = Lambda1.T
Lambda2T = Lambda2.T

LT_P1 = Lambda1T @ P1
LT_P2 = Lambda2T @ P2

LT_P1_L = LT_P1 @ Lambda1
LT_P2_L = LT_P2 @ Lambda2

LT_P1_B1 = LT_P1 @ B1
LT_P2_B2 = LT_P2 @ B2

F2T_M1 = F2.T.dot(M1)
F1T_M2 = F1.T.dot(M2)

# ((Lambda1.T @ P1 @ B1) + W1 - F2.T.dot(M1)) @ F1
term1 = (LT_P1_B1 + W1 - F2T_M1) @ F1
term2 = (LT_P2_B2 + W2 - F1T_M2) @ F2

P1 = LT_P1_L + Pi1 - term1
P2 = LT_P2_L + Pi2 - term2


dd = np.max(np.abs(F10 - F1)) + np.max(np.abs(F20 - F2))

Expand Down