Introduction to Neural Networks with PyTorch

This tutorial walks you through building the code line by line. We'll start from zero and build the script incrementally so you understand why each part exists and how it works. By the end, you'll have the full working code and the knowledge to modify it.

Prerequisites (Do these first)

Install Python (if you don’t have it): Download from python.org (version 3.8 or newer).

Set up the project structure (open your terminal/command prompt and run):

mkdir python-neuralnet
cd python-neuralnet
python -m venv venv
# OR
# python3 -m venv venv

# Activate:
#   macOS/Linux → source venv/bin/activate
#   Windows    → venv\Scripts\activate
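After activating, you can confirm that Python is running inside the virtual environment with a quick check. This is a minimal sketch: `in_virtualenv` is a hypothetical helper name, and it relies on the standard `sys.prefix` and `sys.base_prefix` attributes, which differ only when a venv is active.

```python
import sys


def in_virtualenv() -> bool:
    # Inside a venv, sys.prefix points at the venv folder,
    # while sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


print("Virtual environment active:", in_virtualenv())
```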

Install the required libraries (open your terminal/command prompt and run):

pip install pandas numpy matplotlib
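To confirm the installation worked, a small sketch like the one below imports each package and reports its version (`installed_versions` is our own helper name, not part of any library). An ImportError usually means pip installed into a different environment than the one that is active.

```python
import importlib


def installed_versions(names=("pandas", "numpy", "matplotlib")):
    """Import each package and return {name: version}; raises ImportError if one is missing."""
    versions = {}
    for name in names:
        module = importlib.import_module(name)
        versions[name] = module.__version__
    return versions


print(installed_versions())
```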

Download the MNIST test dataset (mnist_test.csv):

    Format reminder: The CSV has no header. First column = digit label (0–9), next 784 columns = pixel values (0–255) for a 28×28 image.
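The row format above can be made concrete with a small parser. This is a sketch that assumes exactly that layout (1 label + 784 pixels); `split_row` is a hypothetical helper, and the row is synthetic, not taken from the real file.

```python
import numpy as np


def split_row(row):
    """Split one MNIST CSV row into (label, 28x28 image)."""
    if len(row) != 785:
        raise ValueError(f"expected 785 values (1 label + 784 pixels), got {len(row)}")
    label = int(row[0])
    if not 0 <= label <= 9:
        raise ValueError(f"label must be 0-9, got {label}")
    pixels = np.asarray(row[1:], dtype=np.uint8)  # pixel values are 0-255
    return label, pixels.reshape(28, 28)


# Synthetic row for illustration only: label 7 followed by 784 zero pixels.
label, image = split_row([7] + [0] * 784)
print(label, image.shape)  # 7 (28, 28)
```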

    Recommended environment: Use Jupyter Notebook or VS Code with a .py file. Jupyter is especially convenient because plots appear inline, right below the cell that calls plt.show().

    Task 1: Read and Display Data

    You are given a CSV file containing MNIST digit data, where each row starts with a label (digit 0–9) followed by 784 pixel values (28×28 image). Write a Python program that:

    • Reads the CSV file into a suitable data structure.
    • For each digit from 1 to 9, finds three different images of that digit.
    • Plots these images in a grid with 3 rows and 9 columns, where each column corresponds to a digit (1–9) and each row shows a different instance.
    • The plot should not display axis ticks, and each column should be titled with the corresponding digit.

    Solution Task 1: Read and Display Data

    Step 1: Create a new Python file and add the imports

    Create a file called display_data.py (or a new notebook cell).

    import pandas as pd      # For reading the CSV easily
    import numpy as np       # For reshaping the pixel list into a 28x28 image
    import matplotlib.pyplot as plt   # For creating the grid of images

    Why these libraries?

    • pandas: Reads the large CSV into a table.
    • numpy: Converts the flat list of 784 numbers into a 2D image array.
    • matplotlib: Draws the actual pictures in a grid.
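To see how numpy turns the flat list into an image, here is a minimal sketch that uses the numbers 0–783 as stand-in pixels. reshape(28, 28) fills the grid row by row (NumPy's default C order), which is the order the tutorial's plots rely on.

```python
import numpy as np

flat = np.arange(784)         # stand-in for 784 pixel values: 0, 1, ..., 783
image = flat.reshape(28, 28)  # row-major: fills row 0 first, then row 1, ...

# Pixel k of the flat list lands at row k // 28, column k % 28.
k = 100
print(image[k // 28, k % 28] == flat[k])  # True
print(image[0, :3])                        # [0 1 2]
print(image[1, 0])                         # 28
```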

    Step 2: Load the data

    path = r"mnist_test.csv"          # Change if your file is in a different folder
    df = pd.read_csv(path, header=None)   # header=None because the CSV has no column names
    datalist = df.values.tolist()     # Convert to list of lists (easier to loop through rows)

    What just happened?

    • pd.read_csv loads all 10,000 test rows.
    • .values.tolist() turns the DataFrame into a plain Python list of lists (each inner list = one label followed by its 784 pixel values). This makes the next step simpler.
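To see why header=None matters, here is a small sketch that reads two made-up rows from an in-memory string (io.StringIO stands in for the real file; the values are not real MNIST data).

```python
import io

import pandas as pd

# Two tiny header-less rows (label + 4 "pixels") standing in for mnist_test.csv.
csv_text = "7,0,0,255,0\n2,0,128,0,0\n"

with_header_none = pd.read_csv(io.StringIO(csv_text), header=None)
default_header = pd.read_csv(io.StringIO(csv_text))  # first row is treated as column names

print(with_header_none.shape)            # (2, 5): both rows kept as data
print(default_header.shape)              # (1, 5): the first data row became the header
print(with_header_none.values.tolist())  # [[7, 0, 0, 255, 0], [2, 0, 128, 0, 0]]
```

Without header=None, the real file would silently lose its first image and get meaningless column names.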

    Step 3: Prepare a dictionary to collect up to 3 examples per digit (1–9)

    # Prepare a dictionary to collect up to 3 instances for each digit 1-9
    instances = {digit: [] for digit in range(1, 10)}

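    Why only 1–9 and not 0? The task asks for digits 1–9, so the code deliberately skips digit 0. To show 0–9, change range(1, 10) to range(10) here and in the plotting loop, and also adjust the plot itself (10 columns, with column index digit instead of digit-1).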

    Step 4: Loop through the data and collect the samples

    # Search for up to 3 instances of each digit in the dataset
    for row in datalist:
        label = row[0]                    # First number in the row is the digit label
        if label in instances and len(instances[label]) < 3:
            instances[label].append(row[1:])   # Store only the 784 pixel values (remove label)
        
        # Stop early if all digits have 3 instances
        if all(len(v) == 3 for v in instances.values()):
            break

    Why this loop?

    • We only need 27 images total (3 × 9 digits).
    • The early-stop check makes the loop faster: it stops as soon as every digit has 3 examples instead of scanning all 10,000 rows.
    • row[1:] throws away the label because we already know which digit it is.
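The same loop can be wrapped in a reusable function. This is a sketch, not part of the original script: `collect_instances`, `digits`, and `per_digit` are hypothetical names, and the rows are synthetic. Note that the early stop only triggers if every requested digit actually occurs often enough; otherwise the loop scans every row.

```python
def collect_instances(rows, digits=range(1, 10), per_digit=3):
    """Return {digit: [pixel_list, ...]} with up to `per_digit` rows per digit.

    `rows` is any iterable of [label, pixel, pixel, ...] lists, like `datalist`.
    """
    instances = {d: [] for d in digits}
    for row in rows:
        label = row[0]
        if label in instances and len(instances[label]) < per_digit:
            instances[label].append(row[1:])  # keep only the pixels
        # Early stop once every requested digit has enough examples.
        if all(len(v) == per_digit for v in instances.values()):
            break
    return instances


# Synthetic rows for illustration: [label, fake_pixel, fake_pixel].
rows = [[1, 0, 0], [2, 5, 5], [1, 9, 9], [0, 1, 1], [2, 7, 7]]
result = collect_instances(rows, digits=[1, 2], per_digit=2)
print(result)  # {1: [[0, 0], [9, 9]], 2: [[5, 5], [7, 7]]}
```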

    Step 5: Create the 3-row × 9-column grid of images

    # Create a figure with 3 rows (instances) and 9 columns (digits 1-9)
    fig, axis = plt.subplots(3, 9, figsize=(18, 6))

    Explanation:

    • 3, 9 → 3 examples high, 9 digits wide.
    • figsize=(18, 6) makes the whole plot wide and not too tall.
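A quick sketch of what plt.subplots returns here: with 3 rows and 9 columns, `axis` is a 2D NumPy array of Axes indexed as axis[row, column]. The Agg backend line is only there so the check runs without a display; leave it out of your own script.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen backend so this check runs without a window
import matplotlib.pyplot as plt

fig, axis = plt.subplots(3, 9, figsize=(18, 6))
print(axis.shape)  # (3, 9): axis[row, column]

# Column index = digit - 1, so digit 1 is column 0 and digit 9 is column 8.
top_left = axis[0, 1 - 1]
bottom_right = axis[2, 9 - 1]
print(top_left is axis[0, 0], bottom_right is axis[-1, -1])  # True True
plt.close(fig)
```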

    Step 6: Fill the grid with the actual digit images

    for digit in range(1, 10):           # Loop through digits 1 to 9
        for j in range(3):               # Loop through the 3 examples
            ax = axis[j, digit-1]        # Pick the correct subplot (column = digit-1)
            
            if j < len(instances[digit]):
                # Reshape the flat list of 784 numbers into a 28x28 image
                image = np.array(instances[digit][j]).reshape(28, 28)
                ax.imshow(image, cmap='gray')   # Show as grayscale image
                
                if j == 0:
                    ax.set_title(f"{digit}")    # Title only on the top row
            else:
                ax.axis('off')   # If we somehow don't have 3 images, hide the empty spot
            
            ax.axis('off')       # Hide axis numbers and ticks for a clean look

    Key tricks explained:

    • reshape(28, 28): Turns the 784 pixels back into the original square image.
    • cmap='gray': Maps low pixel values to black and high values to white, so the digits look like classic MNIST (black background, white digits).
    • if j == 0: Only the first row gets the digit number as a title.
    • ax.axis('off'): Removes the frame, ticks, and tick labels (the title stays visible).
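One detail worth knowing: unless you pass vmin and vmax, imshow stretches the colormap between each image's own minimum and maximum. For typical MNIST images this often makes little visible difference, but pinning the range to 0–255 keeps brightness comparable across images. A sketch with a synthetic faint image (not real MNIST data):

```python
import matplotlib

matplotlib.use("Agg")  # off-screen backend for this check; omit in your own script
import matplotlib.pyplot as plt
import numpy as np

# Synthetic faint "image" whose brightest pixel is only 100.
image = np.zeros((28, 28))
image[10:18, 12:16] = 100

fig, (ax_auto, ax_fixed) = plt.subplots(1, 2)
auto = ax_auto.imshow(image, cmap="gray")                      # limits taken from this image
fixed = ax_fixed.imshow(image, cmap="gray", vmin=0, vmax=255)  # limits pinned to 0-255

print([float(v) for v in auto.get_clim()])   # [0.0, 100.0]: 100 is drawn as pure white
print([float(v) for v in fixed.get_clim()])  # [0.0, 255.0]: 100 is drawn as dark gray
for ax in (ax_auto, ax_fixed):
    ax.axis("off")
plt.close(fig)
```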

    Step 7: Finalize and display the plot

    plt.tight_layout()   # Automatically adjusts spacing so nothing overlaps
    plt.show()           # Shows the grid of images
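If you run the script on a machine without a display (for example over SSH), plt.show() may not open a window. A sketch of the alternative, saving the figure to a PNG with fig.savefig; the file name digits_grid.png is an assumption, and the blank placeholder images stand in for the real digits.

```python
import matplotlib

matplotlib.use("Agg")  # no window needed when only saving to a file
import matplotlib.pyplot as plt
import numpy as np

fig, axis = plt.subplots(3, 9, figsize=(18, 6))
for ax in axis.flat:
    ax.imshow(np.zeros((28, 28)), cmap="gray")  # placeholder images for illustration
    ax.axis("off")
plt.tight_layout()
fig.savefig("digits_grid.png", dpi=100)  # hypothetical output name: 1800 x 600 pixels
plt.close(fig)
```

In the real script you would call fig.savefig(...) just before (or instead of) plt.show().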

    Full Final Code (Copy-Paste Ready)

    import pandas as pd  # Import pandas for data manipulation
    import numpy as np  # Import numpy for numerical operations
    import matplotlib.pyplot as plt  # Import matplotlib for plotting images
    
    path = r"mnist_test.csv"  # Path to the MNIST CSV data file
    df = pd.read_csv(path, header=None)  # Read the CSV file into a pandas DataFrame without headers
    
    datalist = df.values.tolist()  # Convert the DataFrame to a list of lists
    
    # Prepare a dictionary to collect up to 3 instances for each digit 1-9
    instances = {digit: [] for digit in range(1, 10)}
    
    # Search for up to 3 instances of each digit in the dataset
    for row in datalist:
        label = row[0]
        if label in instances and len(instances[label]) < 3:
            instances[label].append(row[1:])
        # Stop early if all digits have 3 instances
        if all(len(v) == 3 for v in instances.values()):
            break
    
    # Create a figure with 3 rows (instances) and 9 columns (digits 1-9)
    fig, axis = plt.subplots(3, 9, figsize=(18, 6))
    
    for digit in range(1, 10):
        for j in range(3):
            ax = axis[j, digit-1]
            if j < len(instances[digit]):
                image = np.array(instances[digit][j]).reshape(28, 28)
                ax.imshow(image, cmap='gray')
                if j == 0:
                    ax.set_title(f"{digit}")  # Set column title only for the first row
            else:
                ax.axis('off')  # Hide axis if no image
            ax.axis('off')  # Hide axis ticks
    
    plt.tight_layout()
    plt.show()  # Display the plot with the images

    How to run it

    1. Save the code in a .py file (or a Jupyter notebook cell).
    2. Make sure mnist_test.csv is in the folder you run the script from; the path is relative to the current working directory.
    3. Run the script from that folder (in Jupyter, just run the cell). You should see a 3×9 grid showing three examples of each digit 1–9.
    python display_data.py
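If the script stops with a FileNotFoundError, it was usually started from a different folder. A minimal sketch of an earlier, clearer check; `find_csv` is a hypothetical helper that resolves the name against the current working directory, just as pd.read_csv does for a relative path.

```python
from pathlib import Path


def find_csv(name="mnist_test.csv"):
    """Return the CSV path relative to the working directory, or fail with a clear hint."""
    path = Path(name)
    if not path.is_file():
        raise FileNotFoundError(
            f"{path.resolve()} not found; run the script from the folder that contains {name}"
        )
    return path
```

Usage: df = pd.read_csv(find_csv(), header=None). pd.read_csv accepts a Path object directly.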

    Next-level customizations you can try:

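    • Change range(1, 10) to range(10) (in both places) to include digit 0. Because the column index is computed as digit-1, also use plt.subplots(3, 10, …) and axis[j, digit] so that digit 0 lands in column 0.
    • Increase to 5 examples: change each 3 that counts examples per digit (the < 3 and == 3 checks and range(3)) to 5, and use plt.subplots(5, 9, …).
    • Add color: cmap='viridis' or any other matplotlib colormap.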
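To try several of these customizations at once, one option is to parameterize the plotting. This sketch is not the original script: `plot_digit_grid` is a hypothetical function, it picks examples with a pandas boolean filter (df[df[0] == digit].head(...)) instead of the list loop, and it takes the column index from enumerate so digit 0 needs no special case.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_digit_grid(df: pd.DataFrame, digits=range(10), per_digit=3, cmap="gray"):
    """Draw `per_digit` examples of each digit in `digits` in a per_digit x len(digits) grid.

    Assumes `df` came from pd.read_csv(path, header=None): column 0 is the label,
    columns 1..784 are the pixels.
    """
    digits = list(digits)
    fig, axis = plt.subplots(
        per_digit, len(digits),
        figsize=(2 * len(digits), 2 * per_digit),  # 2 inches per cell, like (18, 6) for 9 x 3
        squeeze=False,                              # always a 2D array, even for one row
    )
    for col, digit in enumerate(digits):               # column comes from enumerate, no digit-1
        examples = df[df[0] == digit].head(per_digit)  # first rows with this label
        for j in range(per_digit):
            ax = axis[j, col]
            if j < len(examples):
                image = examples.iloc[j, 1:].to_numpy().reshape(28, 28)
                ax.imshow(image, cmap=cmap, vmin=0, vmax=255)
            if j == 0:
                ax.set_title(str(digit))  # column title on the top row
            ax.axis("off")                # no ticks or frame; missing examples stay blank
    fig.tight_layout()
    return fig, axis
```

Usage, assuming mnist_test.csv is in the working directory: fig, axis = plot_digit_grid(pd.read_csv("mnist_test.csv", header=None), digits=range(10), per_digit=5), then plt.show(). The boolean filter scans the whole label column for each digit, which is fast for 10,000 rows but has no early stop.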

    Task 2: TODO
