import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
import matplotlib.pyplot as plt
# Load dataset
dataset_name = 'cats_vs_dogs'
(data_train, data_test), dataset_info = tfds.load(
dataset_name,
split=['train[:80%]', 'train[80%:]'],
as_supervised=True, # Include labels
with_info=True # Include dataset info
# Data preprocessing
IMG_SIZE = 150
def preprocess_image(image, label):
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
image = image / 255.0 # Normalize pixel values
return image, label
train_dataset = data_train.map(preprocess_image).shuffle(1000).batch(32).prefetch(1)
test_dataset = data_test.map(preprocess_image).batch(32).prefetch(1)
# Visualize a few samples
def plot_samples(dataset, n_samples=5):
plt.figure(figsize=(12, 8))
for i, (image, label) in enumerate(dataset.take(n_samples)):
ax = plt.subplot(1, n_samples, i + 1)
plt.imshow(image.numpy())
plt.title('Cat' if label.numpy() == 0 else 'Dog')
plt.axis('off')
plt.show()
plot_samples(data_train.map(preprocess_image))
# Build CNN model
model = Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 3)),
MaxPooling2D(2, 2),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Conv2D(128, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Flatten(),
Dense(512, activation='relu'),
Dropout(0.5),
Dense(1, activation='sigmoid') # Binary classification
])
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])
# Train the model
history = model.fit(
train_dataset,
validation_data=test_dataset,
epochs=10
)
# Evaluate the model
loss, accuracy = model.evaluate(test_dataset)
print(f"Test Accuracy: {accuracy:.2f}")
# Plot training history
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
plt.figure(figsize=(12, 8))
plt.plot(epochs, acc, 'r', label='Training Accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.figure(figsize=(12, 8))
plt.plot(epochs, loss, 'r', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()