This project focuses on EEG classification using deep learning models. Various architectures were explored, including simple MLPs, CNN-based EEGNet, LSTMs, and residual MLPs. The final model chosen was RobustEEGClassifier, which leverages residual blocks to improve feature learning and stability.
The dataset used is EEG Eye State, provided in ARFF format. It contains 14 EEG channel readings and a binary classification label indicating whether the subject's eyes were open or closed.
- Python 3.8+
- PyTorch
- scikit-learn
- pandas
- matplotlib
- seaborn
- tqdm
- shap
- Clone the repository:
  git clone https://github.com/yourusername/EEG-Eye-State.git
  cd EEG-Eye-State
- Data is read from eeg+eye+state/EEG Eye State.arff.
- Standard scaling is applied to the EEG channel values.
- The dataset is split into training, validation, and test sets (80-10-10 ratio).
- Data is loaded using torch.utils.data.DataLoader.
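
The preprocessing steps above could look roughly like the following. This is a minimal sketch, not the project's actual code: the helper names `load_arff` and `make_loaders` are hypothetical, the label is assumed to be the last column of the ARFF file, and fitting the scaler on the training split only is an assumption (the text does not say where the scaler is fitted). The demo at the bottom uses random stand-in data so it runs without the dataset.

```python
import numpy as np
import pandas as pd
import torch
from scipy.io import arff
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset


def load_arff(path):
    """Read an ARFF file; assumes the class label is the last column."""
    data, _meta = arff.loadarff(path)
    df = pd.DataFrame(data)
    X = df.iloc[:, :-1].to_numpy(dtype=np.float32)
    # Nominal ARFF attributes are loaded as bytes (e.g. b"0" / b"1").
    y = df.iloc[:, -1].str.decode("utf-8").astype(int).to_numpy()
    return X, y


def make_loaders(X, y, batch_size=64, seed=42):
    """Split 80/10/10, standard-scale, and wrap each split in a DataLoader."""
    # First carve off 20%, then halve it into validation and test sets.
    X_train, X_tmp, y_train, y_tmp = train_test_split(
        X, y, test_size=0.2, random_state=seed
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_tmp, y_tmp, test_size=0.5, random_state=seed
    )

    # Assumption: the scaler is fitted on the training split only.
    scaler = StandardScaler().fit(X_train)

    def to_loader(features, labels, shuffle):
        dataset = TensorDataset(
            torch.tensor(scaler.transform(features), dtype=torch.float32),
            torch.tensor(labels, dtype=torch.long),
        )
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return (
        to_loader(X_train, y_train, True),
        to_loader(X_val, y_val, False),
        to_loader(X_test, y_test, False),
    )


# Demo on random stand-in data with 14 channels (not the real dataset).
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 14)).astype(np.float32)
y_demo = rng.integers(0, 2, size=100)
train_loader, val_loader, test_loader = make_loaders(X_demo, y_demo, batch_size=16)
```

With the real file, `X, y = load_arff("eeg+eye+state/EEG Eye State.arff")` would replace the stand-in arrays.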
To better understand the data, the following visualizations were added:
- Class Distribution: A bar plot showing the number of samples per class (eyes open vs. closed).
- EEG Channel Correlation Matrix: A heatmap displaying correlations between different EEG channels.
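
Both visualizations can be produced in a few lines. The sketch below is illustrative only: `plot_data_overview` and the column names are hypothetical, and the demo uses random stand-in data rather than the real recordings.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; remove this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def plot_data_overview(df, label_col):
    """Draw the class distribution and the channel correlation heatmap."""
    fig, (ax_counts, ax_corr) = plt.subplots(1, 2, figsize=(12, 5))

    # Class distribution: number of samples per class value.
    counts = df[label_col].value_counts().sort_index()
    ax_counts.bar(counts.index.astype(str), counts.values)
    ax_counts.set_xlabel("class")
    ax_counts.set_ylabel("samples")
    ax_counts.set_title("Class distribution")

    # Pairwise Pearson correlation between the EEG channels.
    corr = df.drop(columns=[label_col]).corr()
    sns.heatmap(corr, cmap="coolwarm", center=0, ax=ax_corr)
    ax_corr.set_title("EEG channel correlation")

    fig.tight_layout()
    return fig, counts, corr


# Demo on random stand-in data: 14 channels plus a binary label column.
rng = np.random.default_rng(0)
demo_df = pd.DataFrame(
    rng.normal(size=(200, 14)), columns=[f"ch{i}" for i in range(14)]
)
demo_df["label"] = rng.integers(0, 2, size=200)
fig, counts, corr = plot_data_overview(demo_df, "label")
```

Calling `fig.savefig("overview.png")` afterwards writes the figure to disk.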
Several models were implemented and tested:
- Simple MLP: a feedforward network with batch normalization and dropout.
- EEGNet-style CNN: uses temporal and depthwise convolutions inspired by EEGNet for EEG signal processing.
- Residual MLP: incorporates skip connections to improve gradient flow in deeper MLP architectures.
- LSTM: captures temporal dependencies in EEG signals.
- RobustEEGClassifier (final model): based on a residual MLP architecture with layer normalization. It stacks multiple ResidualBlock layers for better feature representation and uses dropout to reduce overfitting.
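
To make the residual design concrete, here is one way a residual MLP with layer normalization and dropout could be written in PyTorch. It is a sketch under assumptions, not the repository's implementation: only the names `RobustEEGClassifier` and `ResidualBlock` come from the text, while the hidden width, the number of blocks, the dropout rate, and the layer ordering inside each block are illustrative choices.

```python
import torch
from torch import nn


class ResidualBlock(nn.Module):
    """LayerNorm -> Linear -> ReLU -> Dropout -> Linear, plus a skip connection."""

    def __init__(self, dim, dropout=0.3):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        out = self.fc1(self.norm(x))
        out = self.drop(self.act(out))
        out = self.fc2(out)
        # The skip connection lets gradients bypass the block.
        return x + out


class RobustEEGClassifier(nn.Module):
    # hidden_dim, num_blocks, and dropout are assumed values for illustration.
    def __init__(self, in_features=14, hidden_dim=128, num_blocks=3,
                 num_classes=2, dropout=0.3):
        super().__init__()
        self.input_proj = nn.Linear(in_features, hidden_dim)
        self.blocks = nn.Sequential(
            *[ResidualBlock(hidden_dim, dropout) for _ in range(num_blocks)]
        )
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        return self.head(self.blocks(self.input_proj(x)))


# Forward pass on a random batch of 8 samples with 14 channels each.
model = RobustEEGClassifier()
logits = model(torch.randn(8, 14))
```

The output has one row per sample and one logit per class, which is the shape `CrossEntropyLoss` expects.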
The training process was implemented in train_tqdm.py, which:
- Uses CrossEntropyLoss as the loss function.
- Optimizes the model using the Adam optimizer with a learning rate of 0.0021.
- Monitors progress using tqdm progress bars.
Hyperparameters:
- num_epochs = 200
- batch_size = 64
- learning_rate = 0.0021
- weight_decay = 1e-4
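
A training loop matching this description might look like the sketch below. It is not the contents of train_tqdm.py: the `train` function and its returned history format are hypothetical, and the demo trains a small stand-in network on random data for three epochs instead of the 200 listed above.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


def train(model, train_loader, num_epochs=200, learning_rate=0.0021,
          weight_decay=1e-4):
    """Train with CrossEntropyLoss and Adam; return per-epoch loss and accuracy."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    history = []
    for epoch in tqdm(range(num_epochs), desc="epochs"):
        model.train()
        running_loss, correct, seen = 0.0, 0, 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()
            # Weight the batch loss by batch size to get a true epoch mean.
            running_loss += loss.item() * targets.size(0)
            correct += (logits.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)
        history.append({
            "epoch": epoch + 1,
            "loss": running_loss / seen,
            "accuracy": correct / seen,
        })
    return history


# Demo: a tiny stand-in model on random data, 3 epochs only.
torch.manual_seed(0)
demo_loader = DataLoader(
    TensorDataset(torch.randn(64, 14), torch.randint(0, 2, (64,))),
    batch_size=64,
)
demo_model = nn.Sequential(nn.Linear(14, 16), nn.ReLU(), nn.Linear(16, 2))
history = train(demo_model, demo_loader, num_epochs=3)
```

The returned `history` list is enough to plot loss and accuracy per epoch.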
The evaluation function computes:
- Test loss
- Accuracy
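
As an illustration, test loss and accuracy can be computed as follows. The `evaluate` function here is a hypothetical sketch rather than the project's own; the toy check feeds fixed logits through `nn.Identity` so the result is easy to verify by hand.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model, loader):
    """Return (mean loss per sample, accuracy) over all batches in loader."""
    model.eval()
    # Sum the loss per batch so the final mean is exact for uneven batches.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss, correct, seen = 0.0, 0, 0
    for inputs, targets in loader:
        logits = model(inputs)
        total_loss += criterion(logits, targets).item()
        correct += (logits.argmax(dim=1) == targets).sum().item()
        seen += targets.size(0)
    return total_loss / seen, correct / seen


# Toy check: the "model" passes its inputs through unchanged as logits.
toy_logits = torch.tensor([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0], [0.0, 2.0]])
toy_labels = torch.tensor([0, 1, 1, 1])
toy_loader = DataLoader(TensorDataset(toy_logits, toy_labels), batch_size=2)
test_loss, test_acc = evaluate(nn.Identity(), toy_loader)
# Predictions are [0, 1, 0, 1], so 3 of the 4 labels match.
```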
The trained RobustEEGClassifier outperformed the other models tested on classification accuracy.
To analyze model performance, the following plots were added:
- Training Loss & Accuracy Curve: Plots loss and accuracy over epochs to detect overfitting.
- Train vs Validation Accuracy Curve: Compares train and validation accuracy per epoch to check for overfitting.
- Confusion Matrix: Visualizes the classification performance.
- ROC Curve: Plots the true positive rate against the false positive rate across decision thresholds, which is useful when the classes are imbalanced.
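
The numbers behind the confusion matrix and ROC curve can be obtained from scikit-learn. The sketch below is illustrative: `confusion_and_roc` is a hypothetical helper, the 0.5 threshold is an assumption, and the scores are a small hand-made example rather than model output.

```python
import numpy as np
from sklearn.metrics import auc, confusion_matrix, roc_curve


def confusion_and_roc(y_true, y_score, threshold=0.5):
    """Compute the confusion matrix and ROC curve from positive-class scores."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    # Assumed decision rule: predict class 1 when the score reaches the threshold.
    y_pred = (y_score >= threshold).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    return cm, fpr, tpr, auc(fpr, tpr)


# Hand-made example: four samples with positive-class scores.
cm, fpr, tpr, roc_auc = confusion_and_roc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```

Rows of `cm` are true classes and columns are predicted classes; `fpr` and `tpr` can be passed straight to `matplotlib.pyplot.plot` to draw the curve.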
To train and evaluate the model, run:
python main.py

Ensure that all dependencies are installed before execution.