#!/usr/bin/env python3

import argparse

import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Select matplotlib's 'svg' backend for the whole process (presumably so the
# script can render without a display -- confirm).
# NOTE(review): this call runs after `matplotlib.pyplot` has already been
# imported above; older matplotlib releases expected `use()` to be called
# before that import -- confirm against the minimum supported version.
matplotlib.use('svg')

def layer_name(parameter_name):
    """Derive a short legend label from a layer-weight parameter name.

    The name is split on "|" and the second component is used; a trailing
    "_classifier" is removed from it when present.

    Unlike an unconditional slice, a component that does not end in
    "_classifier" is returned intact, and a name without any "|" falls back
    to the whole name instead of raising ``IndexError``.
    """
    components = parameter_name.split("|")
    component = components[1] if len(components) > 1 else components[0]

    suffix = "_classifier"
    if component.endswith(suffix):
        component = component[:-len(suffix)]
    return component


def collect_layer_weights(model, normalized=False):
    """Collect the layer-weight parameters of *model*.

    Every parameter whose qualified name contains "layer_weights" is
    selected. Results are ordered by the full parameter name.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs (e.g. a ``torch.nn.Module``).
        normalized: when true, apply a softmax over the last dimension of
            each tensor before converting it.

    Returns:
        A ``(names, weights)`` pair of equally long lists: the legend label
        for each parameter (see ``layer_name``) and its values as a NumPy
        array.
    """
    # Keep the tensors yielded by named_parameters() rather than looking them
    # up again with getattr(): a qualified name such as "encoder.foo" is not
    # an attribute of the top-level module, so getattr() would fail there.
    layer_parameters = {
        name: tensor
        for name, tensor in model.named_parameters()
        if "layer_weights" in name
    }

    names = []
    weights = []
    for parameter_name in sorted(layer_parameters):
        tensor = layer_parameters[parameter_name]
        if normalized:
            tensor = tensor.softmax(-1)
        names.append(layer_name(parameter_name))
        # detach() drops the autograd graph so the tensor can be converted.
        weights.append(tensor.detach().numpy())

    return names, weights


def main():
    """Parse the command line, load the model on CPU and save the plot."""
    parser = argparse.ArgumentParser(description='Plot layer weights.')
    parser.add_argument('model', metavar='MODEL', help='model')
    parser.add_argument('output', metavar='OUTPUT', help='output image')
    parser.add_argument(
        '-n',
        '--normalized',
        action='store_true',
        help='plot normalized layer weights')
    parser.add_argument(
        '-f',
        '--format',
        default='svg',
        help='format as matplotlib backend')
    parser.add_argument(
        '-t',
        '--title',
        default=None,
        help='plot title')

    args = parser.parse_args()

    model = torch.jit.load(args.model, map_location="cpu")

    names, weights = collect_layer_weights(model, args.normalized)
    if not weights:
        # Fail loudly instead of silently writing an empty plot.
        parser.error(
            "no layer weight parameters found in '{}'".format(args.model))

    # Transpose so that each parameter becomes one line in the plot, with
    # the position inside the parameter on the x axis.
    plt.plot(np.array(weights).transpose())
    plt.legend(names)
    if args.title:
        plt.title(args.title)
    plt.xlabel('Layer')
    plt.ylabel('Layer weight')
    plt.savefig(args.output, format=args.format)


if __name__ == "__main__":
    main()
