diff --git a/quantecon/_lqnash.py b/quantecon/_lqnash.py index 56f7b8a4..5698c7f6 100644 --- a/quantecon/_lqnash.py +++ b/quantecon/_lqnash.py @@ -4,7 +4,7 @@ def nnash(A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2, beta=1.0, tol=1e-8, max_iter=1000): - r""" + """ Compute the limit of a Nash linear quadratic dynamic game. In this problem, player i minimizes @@ -79,7 +79,7 @@ def nnash(A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2, """ # == Unload parameters and make sure everything is an array == # params = A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2 - params = map(np.asarray, params) + params = tuple(map(np.asarray, params)) A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2 = params # == Multiply A, B1, B2 by sqrt(beta) to enforce discounting == # @@ -99,6 +99,15 @@ def nnash(A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2, else: k_2 = B2.shape[1] + + # Precompute transposes that are reused + B1T = B1.T + B2T = B2.T + W1T = W1.T + W2T = W2.T + M1T = M1.T + M2T = M2.T + v1 = np.eye(k_1) v2 = np.eye(k_2) P1 = np.zeros((n, n)) @@ -111,28 +120,70 @@ def nnash(A, B1, B2, R1, R2, Q1, Q2, S1, S2, W1, W2, M1, M2, F10 = F1 F20 = F2 - G2 = solve((B2.T @ P2 @ B2)+Q2, v2) - G1 = solve((B1.T @ P1 @ B1)+Q1, v1) - H2 = G2 @ B2.T @ P2 - H1 = G1 @ B1.T @ P1 + # Solve small k x k systems for G1, G2 (inverse-like) + S2mat = (B2T @ P2 @ B2) + Q2 + S1mat = (B1T @ P1 @ B1) + Q1 + G2 = solve(S2mat, v2) + G1 = solve(S1mat, v1) + + # Compute H terms with a cheaper multiplication order + # H1 = G1 @ (B1.T @ P1) + B1T_P1 = B1T @ P1 + B2T_P2 = B2T @ P2 + H1 = G1 @ B1T_P1 + H2 = G2 @ B2T_P2 + + # Reusable intermediate products + H1_B2 = H1 @ B2 + H2_B1 = H2 @ B1 + G1_M1T = G1.dot(M1T) + G2_M2T = G2.dot(M2T) + H1_B2_plus = H1_B2 + G1_M1T + H2_B1_plus = H2_B1 + G2_M2T + + H1_A = H1 @ A + H2_A = H2 @ A + G1_W1T = G1.dot(W1T) + G2_W2T = G2.dot(W2T) + H2_A_plus = H2_A + G2_W2T + H1_A_plus = H1_A + G1_W1T # break up the computation of F1, F2 - F1_left = v1 - ((H1 @ B2 + G1.dot(M1.T)) @ - (H2 @ B1 + G2.dot(M2.T))) - F1_right = H1 @ A + G1.dot(W1.T) - ((H1 @ B2 + G1.dot(M1.T)) @ - (H2 @ A + G2.dot(W2.T))) + F1_left = v1 - (H1_B2_plus @ H2_B1_plus) + F1_right = H1_A_plus - (H1_B2_plus @ H2_A_plus) F1 = solve(F1_left, F1_right) - F2 = H2 @ A + G2.dot(W2.T) - ((H2 @ B1 + G2.dot(M2.T)) @ F1) + F2 = H2_A + G2_W2T - (H2_B1_plus @ F1) + + # Update Lambdas Lambda1 = A - B2 @ F2 Lambda2 = A - B1 @ F1 Pi1 = R1 + (F2.T @ S1.dot(F2)) Pi2 = R2 + (F1.T @ S2.dot(F1)) - P1 = (Lambda1.T @ P1 @ Lambda1) + Pi1 - \ - ((Lambda1.T @ P1 @ B1) + W1 - F2.T.dot(M1)) @ F1 - P2 = (Lambda2.T @ P2 @ Lambda2) + Pi2 - \ - ((Lambda2.T @ P2 @ B2) + W2 - F1.T.dot(M2)) @ F2 + # Update P1 using temporaries to avoid repeated computations + Lambda1T = Lambda1.T + Lambda2T = Lambda2.T + + LT_P1 = Lambda1T @ P1 + LT_P2 = Lambda2T @ P2 + + LT_P1_L = LT_P1 @ Lambda1 + LT_P2_L = LT_P2 @ Lambda2 + + LT_P1_B1 = LT_P1 @ B1 + LT_P2_B2 = LT_P2 @ B2 + + F2T_M1 = F2.T.dot(M1) + F1T_M2 = F1.T.dot(M2) + + # ((Lambda1.T @ P1 @ B1) + W1 - F2.T.dot(M1)) @ F1 + term1 = (LT_P1_B1 + W1 - F2T_M1) @ F1 + term2 = (LT_P2_B2 + W2 - F1T_M2) @ F2 + + P1 = LT_P1_L + Pi1 - term1 + P2 = LT_P2_L + Pi2 - term2 + dd = np.max(np.abs(F10 - F1)) + np.max(np.abs(F20 - F2))