Tools for visualizing neural nets.
⮕ renders how embeddings map from layer to layer in simple networks
⮕ supports 2D and 3D embeddings
⮕ renders the training process as a movie
⮕ outputs an image, a video, or an interactive webpage (plotly or threejs)
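As a rough illustration of what "embeddings mapped from layer to layer" means, the sketch below pushes a small 2D point cloud through a toy two-layer network and keeps the points after every layer. Those per-layer point sets are the kind of thing a tool like this draws. It uses plain NumPy with hand-picked weights; `trace_embeddings` and the layer layout are hypothetical and are not taken from plot-net.py or models.py.

```python
import numpy as np

def trace_embeddings(points, layers):
    """Return the input plus the output of every layer, in order.

    points: array of shape (n, d)
    layers: list of (W, b, activation), with W of shape (d, d) and b of shape (d,)
    """
    embeddings = [points]
    x = points
    for W, b, activation in layers:
        x = activation(x @ W + b)
        embeddings.append(x)
    return embeddings

# Hypothetical toy network of width d = 2 (not one of the repo's models).
rotate90 = np.array([[0.0, 1.0], [-1.0, 0.0]])
layers = [
    (rotate90, np.zeros(2), np.tanh),
    (2.0 * np.eye(2), np.array([1.0, 0.0]), lambda z: np.maximum(z, 0.0)),  # ReLU
]

points = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
embeddings = trace_embeddings(points, layers)
for depth, emb in enumerate(embeddings):
    print(f"layer {depth}: shape {emb.shape}")
```

Because every layer keeps the width at d, each entry of `embeddings` can be drawn as a scatter plot in the same 2D (or 3D) coordinate system.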
git clone git@github.com:phillipi/plot-net.git
cd plot-net
python3 -m venv venv
source venv/bin/activate
pip install torch matplotlib numpy plotly
python plot-net.py --which_dataset <dataset_name> --which_model <model_name> --d <2,3> --train <True, False> --viz_type <static, movie>
- dataset_name: See mk_dataset in datasets.py for options.
- model_name: See mk_model in models.py for options.
- d: dimensionality of the data (and width of the model); supported options are 2 or 3. Most datasets and models should automatically scale to the specified d.
- train: whether to train the net (True) or just run it from its initialization (False).
- viz_type: static runs one forward pass and outputs an image of the embeddings; movie trains the model on the data and outputs a movie of the embeddings over training iterations.
- renderer: matplotlib, plotly, or threejs.

See plot-net.py for additional command line arguments.
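For readers who want to wrap or script the command line, here is a minimal sketch of an argparse parser that accepts the flags described above. It is an assumption about how such a CLI could be built, not a copy of plot-net.py: the defaults, the `choices` lists, and the `str2bool` helper are all hypothetical. The helper is there because flags are passed as `--train True` / `--train False`, and a plain `type=bool` would treat the string "False" as true.

```python
import argparse

def str2bool(value):
    """Parse 'True'/'False' style strings into real booleans."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")

def build_parser():
    # Hypothetical parser mirroring the documented flags; defaults are guesses.
    parser = argparse.ArgumentParser(description="Sketch of a plot-net style CLI")
    parser.add_argument("--which_dataset", type=str, required=True)
    parser.add_argument("--which_model", type=str, required=True)
    parser.add_argument("--d", type=int, choices=[2, 3], default=2)
    parser.add_argument("--train", type=str2bool, default=False)
    parser.add_argument("--viz_type", choices=["static", "movie"], default="static")
    parser.add_argument("--renderer", choices=["matplotlib", "plotly", "threejs"],
                        default="matplotlib")
    return parser

args = build_parser().parse_args(
    ["--which_dataset", "gaussian_data", "--which_model", "linear",
     "--d", "2", "--train", "False"]
)
print(args)
```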
For the threejs renderer, the webpage files are saved to ./threejs/. You may need an HTTP server to view them; you can run one locally like this:
cd threejs
python3 -m http.server 8000
Then navigate your browser to http://localhost:8000/.
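If you would rather start the server from a Python script without changing directories, the standard library can serve a chosen folder directly. The sketch below is one way to do that; `make_server` is a hypothetical helper name, and the `threejs` directory and port 8000 are simply the values used in the commands above.

```python
import functools
import http.server

def make_server(directory="threejs", port=8000, host="127.0.0.1"):
    """Build an HTTP server that serves files from `directory`.

    The caller starts it with serve_forever() and stops it with shutdown().
    """
    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=directory
    )
    return http.server.ThreadingHTTPServer((host, port), handler)
```

Calling `make_server().serve_forever()` from the repository root should then make the page available at http://localhost:8000/, the same as the shell commands.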
1) Run a linear model on some 2D data and visualize the embeddings as a static image:
python plot-net.py --which_dataset gaussian_data --which_model linear --d 2 --viz_type static
2) Train a model on some 2D data and visualize the evolution of the embeddings over iters as a movie:
python plot-net.py --which_dataset binary_classification --which_model MySimpleNet --d 2 --train True --viz_type movie --N_viz_iter 60 --N_train_iter_per_viz 150
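The two iteration flags in this command presumably control the length of the movie. Assuming `--N_viz_iter` is the number of rendered frames and `--N_train_iter_per_viz` is the number of training iterations run before each frame (an interpretation of the flag names, not something confirmed by plot-net.py), the schedule can be sketched as follows. `movie_schedule` is a hypothetical helper.

```python
def movie_schedule(n_viz_iter, n_train_iter_per_viz):
    """Yield (frame_index, training_iters_completed) for each rendered frame.

    Assumes each frame is rendered after a block of training iterations.
    """
    for frame in range(n_viz_iter):
        yield frame, (frame + 1) * n_train_iter_per_viz

schedule = list(movie_schedule(60, 150))
total_training_iters = schedule[-1][1]
print(len(schedule), total_training_iters)  # 60 9000
```

Under that reading, the command above trains for 60 × 150 = 9000 iterations and produces a 60-frame movie.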
3) Train a model on some 3D data and visualize the evolution of the embeddings over iters as a movie, with rotating camera:
python plot-net.py --which_dataset ternary_classification --which_model SimpleResnet --d 3 --train True --viz_type movie --rotate_camera True
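4) Run a diffusion model (without training) on some 2D data and render the embeddings as an interactive threejs webpage:
python plot-net.py --which_dataset spiral --which_model diffusion --d 2 --train False --viz_type static --renderer threejs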