This tutorial walks you through building the code you shared, line by line. We’ll start from zero and grow the script incrementally so you understand why each part exists and how it works. By the end, you’ll have the full working code and know how to modify it.
Prerequisites (Do these first)
Install Python (if you don’t have it): Download from python.org (version 3.8 or newer).
Set up the project structure (open your terminal/command prompt and run):
mkdir python-neuralnet
cd python-neuralnet
python -m venv venv
# OR
# python3 -m venv venv
# Activate:
# macOS/Linux → source venv/bin/activate
# Windows → venv\Scripts\activate
Install the required libraries (open your terminal/command prompt and run):
pip install pandas numpy matplotlib
Download the MNIST test dataset (mnist_test.csv):
- Easiest option: Direct download → https://python-course.eu/data/mnist/mnist_test.csv
- Alternative: the popular “MNIST in CSV” dataset on Kaggle, or the GitHub repository https://github.com/phoebetronic/mnist (download the mnist_test.csv file).
- Save the file in the same folder where you will create your Python script (or note the full path).
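Format reminder: The CSV used here has no header row. The first column is the digit label (0–9), and the next 784 columns are the pixel values (0–255) of a 28×28 image. If your copy begins with a header row of column names (some CSV versions of MNIST include one), read it with header=0 instead of header=None in Step 2.
Recommended environment: Use Jupyter Notebook or VS Code with a .py file. Jupyter displays plots inline, directly below the cell, which is convenient for this kind of exploration.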
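If you want to confirm that a downloaded file matches this layout before going further, a quick check along these lines may help. It is a minimal sketch: the two-row CSV built in memory with io.StringIO is made-up stand-in data, so swap in your real file (for example pd.read_csv("mnist_test.csv", header=None)) when you use it.

```python
import io

import pandas as pd

# Made-up stand-in for mnist_test.csv: each row is a label followed by
# 784 pixel values. Replace io.StringIO(fake_csv) with your real path.
rows = [
    [7] + [0] * 784,
    [2] + [255] * 784,
]
fake_csv = "\n".join(",".join(str(v) for v in row) for row in rows)

df = pd.read_csv(io.StringIO(fake_csv), header=None)

# Expect 1 label column + 784 pixel columns.
assert df.shape[1] == 785, f"expected 785 columns, got {df.shape[1]}"

labels = df[0]            # with header=None, columns are named 0, 1, 2, ...
pixels = df.iloc[:, 1:]   # everything after the label
assert labels.between(0, 9).all(), "labels should be digits 0-9"
assert pixels.to_numpy().min() >= 0 and pixels.to_numpy().max() <= 255

print(df.shape)  # (2, 785) for this toy file
```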
Task 1: Read and Display Data
You are given a CSV file containing MNIST digit data, where each row starts with a label (a digit 0–9) followed by 784 pixel values (a 28×28 image). Write a Python program that:
- Reads the CSV file into a suitable data structure.
- For each digit from 1 to 9, finds three different images of that digit.
- Plots these images in a grid with 3 rows and 9 columns, where each column corresponds to a digit (1–9) and each row shows a different instance.
- The plot should not display axis ticks, and each column should be titled with the corresponding digit.
Solution Task 1: Read and Display Data
Step 1: Create a new Python file and add the imports
Create a file called display_data.py (or a new notebook cell).
import pandas as pd # For reading the CSV easily
import numpy as np # For reshaping the pixel list into a 28x28 image
import matplotlib.pyplot as plt # For creating the grid of images
Why these libraries?
- pandas: Reads the large CSV into a table.
- numpy: Converts the flat list of 784 numbers into a 2D image array.
- matplotlib: Draws the actual pictures in a grid.
Step 2: Load the data
path = r"mnist_test.csv" # Change if your file is in a different folder
df = pd.read_csv(path, header=None) # header=None because the CSV has no column names
datalist = df.values.tolist() # Convert to list of lists (easier to loop through rows)
What just happened?
- pd.read_csv loads all 10,000 test rows.
- .values.tolist() turns the DataFrame into a plain Python list (each inner list = one handwritten digit + its pixels). This makes the next step simpler.
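Converting to a list of lists is one option. Below is a sketch of an alternative that keeps NumPy arrays and separates labels from pixels once, shown on a made-up three-row CSV rather than the real file:

```python
import io

import pandas as pd

# Made-up CSV in the same layout as mnist_test.csv (label + 784 pixels).
# In the tutorial you would call pd.read_csv(path, header=None) instead.
fake_csv = "\n".join(",".join([str(label)] + ["0"] * 784) for label in [5, 0, 4])
df = pd.read_csv(io.StringIO(fake_csv), header=None)

# Option used in the tutorial: a list of lists, one inner list per image.
datalist = df.values.tolist()
print(len(datalist), len(datalist[0]))  # 3 785

# Alternative: keep NumPy arrays and split labels from pixels up front.
labels = df[0].to_numpy()            # shape (3,)
pixels = df.iloc[:, 1:].to_numpy()   # shape (3, 784)
first_image = pixels[0].reshape(28, 28)
print(labels, pixels.shape, first_image.shape)
```

With arrays, a boolean mask such as pixels[labels == 4] selects every image of one digit at once, which can replace the manual loop in Step 4.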
Step 3: Prepare a dictionary to collect up to 3 examples per digit (1–9)
# Prepare a dictionary to collect up to 3 instances for each digit 1-9
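instances = {digit: [] for digit in range(1, 10)}
Why only 1–9 and not 0? The task asks for digits 1–9 only, so the code deliberately skips 0. To include 0, change range(1,10) to range(10) here and also adjust the plotting code (10 columns instead of 9, and a column index that no longer subtracts 1).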
Step 4: Loop through the data and collect the samples
# Search for up to 3 instances of each digit in the dataset
for row in datalist:
    label = row[0] # First number in the row is the digit label
    if label in instances and len(instances[label]) < 3:
        instances[label].append(row[1:]) # Store only the 784 pixel values (remove label)
    # Stop early if all digits have 3 instances
    if all(len(v) == 3 for v in instances.values()):
        break
Why this loop?
- We only need 27 images total (3 × 9 digits).
- The early-stop check makes the loop faster: it exits as soon as every digit has 3 images instead of scanning all 10,000 rows.
- row[1:] throws away the label because we already know which digit it is.
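The same selection can also be written without an explicit loop, using pandas groupby(...).head(n), which keeps the first n rows of each group in file order. The sketch below uses a small hypothetical DataFrame with only two "pixel" columns; with the real data you would apply it to df from Step 2. Unlike the loop, a digit that never appears is simply missing from the resulting dictionary, so Step 6 would need instances.get(digit, []) if that can happen.

```python
import pandas as pd

# Hypothetical toy data: column 0 is the label, columns 1-2 stand in for pixels
# (real MNIST rows have 784 pixel columns).
df = pd.DataFrame([[d, d * 10, d * 20] for d in [1, 2, 1, 3, 1, 1, 2, 0, 2, 3, 3]])

# Keep only digits 1-9, then the first 3 rows of each label in file order.
subset = df[df[0].isin(list(range(1, 10)))].groupby(0).head(3)

# Same shape as the tutorial's dictionary: digit -> list of pixel lists.
instances = {
    int(digit): group.iloc[:, 1:].values.tolist()
    for digit, group in subset.groupby(0)
}
print({d: len(v) for d, v in instances.items()})  # {1: 3, 2: 3, 3: 3}
```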
Step 5: Create the 3-row × 9-column grid of images
# Create a figure with 3 rows (instances) and 9 columns (digits 1-9)
fig, axis = plt.subplots(3, 9, figsize=(18, 6))
Explanation:
- 3, 9 → 3 examples high, 9 digits wide.
- figsize=(18, 6) makes the whole plot wide and not too tall.
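To see what plt.subplots actually returns, the sketch below creates the same 3×9 grid and inspects it. It assumes the non-interactive "Agg" backend so it runs without a display; leave out the matplotlib.use line in Jupyter.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend; omit in Jupyter
import matplotlib.pyplot as plt

fig, axis = plt.subplots(3, 9, figsize=(18, 6))

# axis is a 2D NumPy array of Axes objects, indexed as axis[row, column].
print(type(axis).__name__, axis.shape)  # ndarray (3, 9)

# Row j picks the example, column digit - 1 picks the digit (1 -> column 0).
digit, j = 4, 2
ax = axis[j, digit - 1]
ax.set_title(f"example {j}, digit {digit}")
print(ax is axis[2, 3])  # True

plt.close(fig)  # free the figure; nothing is displayed here
```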
Step 6: Fill the grid with the actual digit images
for digit in range(1, 10): # Loop through digits 1 to 9
    for j in range(3): # Loop through the 3 examples
        ax = axis[j, digit-1] # Pick the correct subplot (column = digit-1)
        if j < len(instances[digit]):
            # Reshape the flat list of 784 numbers into a 28x28 image
            image = np.array(instances[digit][j]).reshape(28, 28)
            ax.imshow(image, cmap='gray') # Show as grayscale image
            if j == 0:
                ax.set_title(f"{digit}") # Title only on the top row
        else:
            ax.axis('off') # If we somehow don't have 3 images, hide the empty spot
        ax.axis('off') # Hide axis numbers and ticks for a clean look
Key tricks explained:
- reshape(28, 28): Turns the 784 pixels back into the original square image.
- cmap='gray': Makes the handwritten digits look like classic MNIST (black background, white digits).
- if j == 0: Only the first row gets the digit number as a title.
- ax.axis('off'): Hides the axis lines, ticks, and tick labels (the title still shows).
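The reshape(28, 28) step relies on NumPy's default row-major order: the first 28 values become the top row of the image, the next 28 the second row, and so on. A small check with a made-up "image" whose values equal their own positions:

```python
import numpy as np

# Made-up flat "image": values 0..783, so each value records its own index.
flat = list(range(784))
image = np.array(flat).reshape(28, 28)

print(image.shape)  # (28, 28)

# Row-major (C) order: flat index k lands at row k // 28, column k % 28.
k = 100
print(image[k // 28, k % 28])  # 100
print(image[1, 0])             # 28 (first pixel of the second row)
```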
Step 7: Finalize and display the plot
plt.tight_layout() # Automatically adjusts spacing so nothing overlaps
plt.show() # Shows the grid of images
Full Final Code (Copy-Paste Ready)
import pandas as pd # Import pandas for data manipulation
import numpy as np # Import numpy for numerical operations
import matplotlib.pyplot as plt # Import matplotlib for plotting images
path = r"mnist_test.csv" # Path to the MNIST CSV data file
df = pd.read_csv(path, header=None) # Read the CSV file into a pandas DataFrame without headers
datalist = df.values.tolist() # Convert the DataFrame to a list of lists
# Prepare a dictionary to collect up to 3 instances for each digit 1-9
instances = {digit: [] for digit in range(1, 10)}
# Search for up to 3 instances of each digit in the dataset
for row in datalist:
    label = row[0]
    if label in instances and len(instances[label]) < 3:
        instances[label].append(row[1:])
    # Stop early if all digits have 3 instances
    if all(len(v) == 3 for v in instances.values()):
        break
# Create a figure with 3 rows (instances) and 9 columns (digits 1-9)
fig, axis = plt.subplots(3, 9, figsize=(18, 6))
for digit in range(1, 10):
    for j in range(3):
        ax = axis[j, digit-1]
        if j < len(instances[digit]):
            image = np.array(instances[digit][j]).reshape(28, 28)
            ax.imshow(image, cmap='gray')
            if j == 0:
                ax.set_title(f"{digit}") # Set column title only for the first row
        else:
            ax.axis('off') # Hide axis if no image
        ax.axis('off') # Hide axis ticks
plt.tight_layout()
plt.show() # Display the plot with the images
How to run it
- Save the code in a .py file (or Jupyter notebook cell).
- Make sure mnist_test.csv is in the same folder.
- Run the script with the command below. You should see a 3×9 grid showing three examples of each digit 1–9 (inline in Jupyter, or in a separate window when run from the terminal).
python display_data.py
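On a machine without a display (a remote server or an SSH session, for example), plt.show() may not open a window. One workaround, sketched here with placeholder black images and a hypothetical output name digits_grid.png, is to switch to the "Agg" backend and save the grid to a file with savefig:

```python
import matplotlib

matplotlib.use("Agg")  # render to files only; no window is opened
import matplotlib.pyplot as plt
import numpy as np

fig, axis = plt.subplots(3, 9, figsize=(18, 6))
for ax in axis.flat:
    # Placeholder image; in the tutorial this is the reshaped 28x28 digit.
    ax.imshow(np.zeros((28, 28)), cmap="gray")
    ax.axis("off")

plt.tight_layout()
# Hypothetical filename; dpi sets resolution (18 in x 100 dpi = 1800 px wide).
fig.savefig("digits_grid.png", dpi=100)
plt.close(fig)
```

In the real script you would keep the data-loading and plotting steps unchanged and replace plt.show() with the savefig call.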
Next-level customizations you can try:
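- Change range(1, 10) to range(10) in both the dictionary and the plotting loop to include digit 0. Also switch to plt.subplots(3, 10, …) and use axis[j, digit] instead of axis[j, digit-1] so each digit gets its own column in order.
- Increase to 5 examples: change every 3 that counts examples (len(instances[label]) < 3, len(v) == 3, range(3)) to 5 and use plt.subplots(5, 9, …). A taller figsize such as (18, 10) keeps each image square.
- Add color: cmap='viridis' or any other matplotlib colormap.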
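The first two customizations touch several hard-coded numbers at once. One way to turn them into a single change is to wrap Steps 4 to 6 in a function. plot_digit_grid below is a hypothetical helper, not part of the original code, and the usage example runs it on random synthetic data. Looping with enumerate gives each digit its own column, which removes the digit-1 offset that breaks when 0 is included.

```python
import matplotlib

matplotlib.use("Agg")  # drop this line in Jupyter or when you want a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_digit_grid(df, digits=range(1, 10), n_per_digit=3, cmap="gray"):
    """Hypothetical helper: plot n_per_digit examples of each digit in `digits`.

    df is laid out like mnist_test.csv (column 0 = label, columns 1..784 =
    pixels). Returns the (fig, axis) pair.
    """
    digits = list(digits)
    n_cols = len(digits)
    # 2 inches per cell, so 9 digits x 3 examples gives the original (18, 6).
    fig, axis = plt.subplots(
        n_per_digit, n_cols, figsize=(2 * n_cols, 2 * n_per_digit), squeeze=False
    )
    for col, digit in enumerate(digits):
        # First n_per_digit rows with this label, as an (n, 784) array.
        samples = df[df[0] == digit].iloc[:n_per_digit, 1:].to_numpy()
        for row in range(n_per_digit):
            ax = axis[row, col]
            if row < len(samples):
                ax.imshow(samples[row].reshape(28, 28), cmap=cmap)
            if row == 0:
                ax.set_title(str(digit))  # column title on the top row
            ax.axis("off")  # hide ticks; empty cells stay blank
    fig.tight_layout()
    return fig, axis


# Usage with synthetic data: two random "images" for each digit 0-9.
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 2)
pixels = rng.integers(0, 256, size=(len(labels), 784))
toy_df = pd.DataFrame(np.column_stack([labels, pixels]))

fig, axis = plot_digit_grid(toy_df, digits=range(10), n_per_digit=5, cmap="viridis")
print(axis.shape)  # (5, 10)
plt.close(fig)
```

With the real data, plot_digit_grid(df) reproduces the original 3×9 grid, and plt.show() (instead of plt.close) displays it.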