Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions examples/federated_learning_with_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@

import phe as paillier

seed = 42
seed = 43
np.random.seed(seed)


Expand All @@ -86,7 +86,6 @@ def get_data(n_clients):
diabetes = load_diabetes()
y = diabetes.target
X = diabetes.data

# Add constant to emulate intercept
X = np.c_[X, np.ones(X.shape[0])]

Expand Down Expand Up @@ -193,30 +192,21 @@ def encrypted_gradient(self, sum_to=None):
return encrypted_gradient


def federated_learning(n_iter, eta, n_clients, key_length):
def federated_learning(X, y, X_test, y_test, config):
n_clients = config['n_clients']
n_iter = config['n_iter']
names = ['Hospital {}'.format(i) for i in range(1, n_clients + 1)]

X, y, X_test, y_test = get_data(n_clients=n_clients)

# Instantiate the server and generate private and public keys
# NOTE: using smaller keys sizes wouldn't be cryptographically safe
server = Server(key_length=key_length)
server = Server(key_length=config['key_length'])

# Instantiate the clients.
# Each client gets the public key at creation and its own local dataset
clients = []
for i in range(n_clients):
clients.append(Client(names[i], X[i], y[i], server.pubkey))

# Each client trains a linear regressor on its own data
print('Error (MSE) that each client gets on test set by '
'training only on own local data:')
for c in clients:
c.fit(n_iter, eta)
y_pred = c.predict(X_test)
mse = mean_square_error(y_pred, y_test)
print('{:s}:\t{:.2f}'.format(c.name, mse))

# The federated learning with gradient descent
print('Running distributed gradient aggregation for {:d} iterations'
.format(n_iter))
Expand All @@ -232,14 +222,45 @@ def federated_learning(n_iter, eta, n_clients, key_length):

# Take gradient steps
for c in clients:
c.gradient_step(aggr, eta)
c.gradient_step(aggr, config['eta'])

print('Error (MSE) that each client gets after running the protocol:')
for c in clients:
y_pred = c.predict(X_test)
mse = mean_square_error(y_pred, y_test)
print('{:s}:\t{:.2f}'.format(c.name, mse))


def local_learning(X, y, X_test, y_test, config):
n_clients = config['n_clients']
names = ['Hospital {}'.format(i) for i in range(1, n_clients + 1)]

# Instantiate the clients.
# Each client gets the public key at creation and its own local dataset
clients = []
for i in range(n_clients):
clients.append(Client(names[i], X[i], y[i], None))

# Each client trains a linear regressor on its own data
print('Error (MSE) that each client gets on test set by '
'training only on own local data:')
for c in clients:
c.fit(config['n_iter'], config['eta'])
y_pred = c.predict(X_test)
mse = mean_square_error(y_pred, y_test)
print('{:s}:\t{:.2f}'.format(c.name, mse))


if __name__ == '__main__':
# Set learning, data split, and security params
federated_learning(n_iter=50, eta=0.01, n_clients=3, key_length=1024)
config = {
'n_clients': 3,
'key_length': 1024,
'n_iter': 50,
'eta': 0.01,
}
# load data, train/test split and split training data between clients
X, y, X_test, y_test = get_data(n_clients=config['n_clients'])
# first each hospital learns a model on its respective dataset for comparison.
local_learning(X, y, X_test, y_test, config)
# and now the full glory of federated learning
federated_learning(X, y, X_test, y_test, config)