import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping
# Load and preprocess the dataset
(X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()
X_train = X_train.astype("float32") / 255.0
X_test = X_test.astype("float32") / 255.0
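# Quick sanity check (illustrative): CIFAR-10 ships 50k train / 10k test RGB images of 32x32.
print("Train:", X_train.shape, "Test:", X_test.shape)  # (50000, 32, 32, 3) / (10000, 32, 32, 3)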
# Data augmentation
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)
# Note: datagen.fit() is only required when featurewise_center,
# featurewise_std_normalization, or zca_whitening is enabled; none are used
# here, so no fit call is needed.
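# Optional sanity check: datagen.flow yields batches indefinitely, so grab a
# single augmented batch with next() and confirm its shape.
sample_x, sample_y = next(datagen.flow(X_train, y_train, batch_size=9))
print("Augmented batch:", sample_x.shape, sample_y.shape)  # (9, 32, 32, 3) (9, 1)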
# Standardize the data (per-channel z-score normalization)
mean = np.mean(X_train, axis=(0, 1, 2))
std = np.std(X_train, axis=(0, 1, 2))
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
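# Sanity check: after z-scoring, per-channel means should be ~0 and stds ~1 on
# the training set (the test set reuses training statistics, so it will only
# be approximately centered).
print("Train mean/std per channel:", X_train.mean(axis=(0, 1, 2)), X_train.std(axis=(0, 1, 2)))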
# Learning rate schedule function
def lr_schedule(epoch):
    """Step decay: hold 1e-3 for the first 6 epochs, then drop 10x."""
    initial_lr = 0.001
    if epoch > 5:
        return initial_lr * 0.1
    return initial_lr
lr_scheduler = LearningRateScheduler(lr_schedule)
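# Illustrative check of the schedule around the step boundary at epoch 6.
for e in (0, 5, 6, 10):
    print(f"epoch {e}: lr = {lr_schedule(e):.4g}")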
# Define the CNN architecture
model = keras.Sequential([
    layers.Conv2D(32, (3, 3), activation="relu", input_shape=(32, 32, 3)),
    layers.BatchNormalization(),
    layers.Conv2D(32, (3, 3), activation="relu"),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(64, (3, 3), activation="relu"),
    layers.BatchNormalization(),
    layers.Conv2D(64, (3, 3), activation="relu"),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.BatchNormalization(),
    layers.Dropout(0.5),
    layers.Dense(10),  # No activation: logits output for SparseCategoricalCrossentropy(from_logits=True)
])
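# Print the layer stack to verify output shapes and parameter counts before training.
model.summary()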
# Compile the model
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
# Early stopping to prevent overfitting
early_stopping = EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)
# Train the model with data augmentation
history = model.fit(
    datagen.flow(X_train, y_train, batch_size=64),
    epochs=20,
    validation_data=(X_test, y_test),
    callbacks=[early_stopping, lr_scheduler],
)
# Evaluate the model
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=1)
print(f"Test Accuracy: {test_acc:.2f}")
print(f"Test Loss: {test_loss:.2f}")
# Save the model
model.save("cifar10_model.h5")
print("Model saved as cifar10_model.h5")
# Plot training results
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history["accuracy"], label="Train")
plt.plot(history.history["val_accuracy"], label="Validation")
plt.xlabel("Epoch")
plt.legend()
plt.title("Accuracy")
plt.subplot(1, 2, 2)
plt.plot(history.history["loss"], label="Train")
plt.plot(history.history["val_loss"], label="Validation")
plt.xlabel("Epoch")
plt.legend()
plt.title("Loss")
plt.tight_layout()
plt.show()