forked from apoorvnandan/tensor.h
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: create_mnist_csv.py
More file actions
40 lines (31 loc) · 1.42 KB
/
create_mnist_csv.py
File metadata and controls
40 lines (31 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torchvision
import torchvision.transforms as transforms
import csv
import numpy as np
from torch.utils.data import DataLoader
# ToTensor converts each image to a float tensor; Normalize(mean=0.5, std=0.5)
# then applies (x - 0.5) / 0.5 to every pixel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Build the training split first, then the test split, both rooted at ./data
# with download=True and the shared transform above.
trainset, testset = (
    torchvision.datasets.MNIST(
        root='./data', train=split_is_train, download=True, transform=transform
    )
    for split_is_train in (True, False)
)
def dataset_to_csv(dataset, filename, *, num_classes=10, header=False):
    """Write an ``(image, label)`` dataset to a CSV file, one sample per row.

    Each row holds the flattened image values followed by a one-hot encoding
    of the label. For MNIST this gives 784 pixel columns + 10 label columns.

    Args:
        dataset: Iterable of ``(image, label)`` pairs. ``image`` may be any
            array-like accepted by ``np.asarray`` (e.g. a CPU tensor);
            ``label`` must be an integer-like class index.
        filename: Path of the CSV file to create (overwritten if it exists).
        num_classes: Length of the one-hot label vector. Defaults to 10.
        header: If True, write a ``pixel{i}`` / ``label_{j}`` header row
            before the first sample. The pixel count is taken from the first
            image, so an empty dataset produces an empty file.

    Raises:
        ValueError: If a label is outside ``[0, num_classes)``.
        TypeError: If a label is not integer-like.
    """
    import operator  # local import keeps this block self-contained

    with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        for index, (image, label) in enumerate(dataset):
            # Flatten the image into a 1-D vector of pixel values.
            pixels = np.asarray(image).reshape(-1)

            if header and index == 0:
                # Derive the column count from the data instead of assuming 28*28.
                writer.writerow(
                    [f'pixel{i}' for i in range(pixels.size)]
                    + [f'label_{j}' for j in range(num_classes)]
                )

            # operator.index accepts int / NumPy int / 0-d integer tensors but
            # rejects floats; the range check stops a negative label from
            # silently wrapping around to the last class.
            class_index = operator.index(label)
            if not 0 <= class_index < num_classes:
                raise ValueError(
                    f"label {class_index} out of range for num_classes={num_classes}"
                )

            # Use the pixel dtype: a float64 one-hot would promote the whole row
            # to float64, so each float32 pixel would be written with ~17 digits
            # (e.g. 0.10000000149011612 instead of 0.1), bloating the file.
            label_one_hot = np.zeros(num_classes, dtype=pixels.dtype)
            label_one_hot[class_index] = 1

            writer.writerow(np.concatenate([pixels, label_one_hot]))
# Export each split to its own CSV file (training split first), then report.
for _split, _csv_name in ((trainset, 'mnist_train.csv'), (testset, 'mnist_test.csv')):
    dataset_to_csv(_split, _csv_name)
print("Datasets have been converted to CSV format.")