
Spectral Norm, Adaptive Softmax, faster CPU ops, anomaly detection (NaNs, etc.), lots of bug fixes, Python 3.7 and CUDA 9.2 support


@soumith soumith released this 26 Jul 19:09

Table of Contents

  • Breaking Changes
  • New Features
    • Neural Networks
      • Adaptive Softmax, Spectral Norm, etc.
    • Operators
      • torch.bincount, torch.as_tensor, ...
    • torch.distributions
      • Half Cauchy, Gamma Sampling, ...
    • Other
      • Automatic anomaly detection (detecting NaNs, etc.)
  • Performance
    • Faster CPU ops in a wide variety of cases
  • Other improvements
  • Bug Fixes
  • Documentation Improvements

Breaking Changes

  • torch.stft has changed its signature to be consistent with librosa #9497
    • Before: stft(signal, frame_length, hop, fft_size=None, normalized=False, onesided=True, window=None, pad_end=0)
    • After: stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', normalized=False, onesided=True)
    • torch.stft now also uses FFT internally and is much faster.
  • torch.slice is removed in favor of the tensor slicing notation #7924
  • torch.arange now does dtype inference: any floating-point argument is inferred to be the default dtype; all integer arguments are inferred to be int64. #7016
  • torch.nn.functional.embedding_bag's old signature embedding_bag(weight, input, ...) is deprecated; use embedding_bag(input, weight, ...) instead, which is consistent with torch.nn.functional.embedding
  • torch.nn.functional.sigmoid and torch.nn.functional.tanh are deprecated in favor of torch.sigmoid and torch.tanh #8748
  • Broadcasting behavior changed in a (very rare) edge case: [1] x [0] now broadcasts to [0] (it used to be [1]) #9209
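The breaking changes above can be exercised together in one short script. This is a sketch written against a recent PyTorch release, not 0.4.1 itself: in particular, passing `return_complex=True` to torch.stft is an assumption required by later releases and is not part of the 0.4.1 signature listed above. All variable names are illustrative.

```python
import torch
import torch.nn.functional as F

# torch.arange infers its dtype from its arguments
ints = torch.arange(5)               # all-integer arguments -> int64
floats = torch.arange(0, 1, 0.25)    # any float argument -> default dtype

# torch.stft with the librosa-style keywords; return_complex=True is an
# assumption: later PyTorch releases require it, 0.4.1 did not have it
signal = torch.randn(256)
spec = torch.stft(signal, n_fft=64, hop_length=16,
                  window=torch.hann_window(64), return_complex=True)
# onesided + center=True -> (n_fft // 2 + 1, 1 + len(signal) // hop_length)

# Tensor slicing replaces the removed torch.slice
x = torch.arange(10)
every_other = x[2:8:2]

# embedding_bag now takes (input, weight, ...), matching F.embedding
weight = torch.randn(10, 3)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])       # two bags: indices[0:4] and indices[4:]
bags = F.embedding_bag(indices, weight, offsets, mode='mean')

# Use torch.sigmoid / torch.tanh instead of the deprecated functional versions
probs = torch.sigmoid(torch.zeros(3))

# Broadcasting a size-[1] tensor with a size-[0] tensor yields size [0]
empty = torch.ones(1) + torch.ones(0)
```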

New Features

Neural Networks

  • Adaptive Softmax nn.AdaptiveLogSoftmaxWithLoss #5287

    >>> import torch
    >>> import torch.nn as nn
    >>> in_features = 1000
    >>> n_classes = 200
    >>> adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs=[20, 100, 150])
    >>> adaptive_softmax
    AdaptiveLogSoftmaxWithLoss(
      (head): Linear(in_features=1000, out_features=23, bias=False)
      (tail): ModuleList(
        (0): Sequential(
          (0): Linear(in_features=1000, out_features=250, bias=False)
          (1): Linear(in_features=250, out_features=80, bias=False)
        )
        (1): Sequential(
          (0): Linear(in_features=1000, out_features=62, bias=False)
          (1): Linear(in_features=62, out_features=50, bias=False)
        )
        (2): Sequential(
          (0): Linear(in_features=1000, out_features=15, bias=False)
          (1): Linear(in_features=15, out_features=50, bias=False)
        )
      )
    )
    >>> batch = 15
    >>> input = torch.randn(batch, in_features)
    >>> target = torch.randint(n_classes, (batch,), dtype=torch.long)
    >>> # get the log probabilities of target given input, and mean negative log probability loss
    >>> adaptive_softmax(input, target) 
    ASMoutput(output=tensor([-6.8270, -7.9465, -7.3479, -6.8511, -7.5613, -7.1154, -2.9478, -6.9885,
            -7.7484, -7.9102, -7.1660, -8.2843, -7.7903, -8.4459, -7.2371],
           grad_fn=<ThAddBackward>), loss=tensor(7.2112, grad_fn=<MeanBackward1>))
    >>> # get the log probabilities of all targets given input as a (batch x n_classes) tensor
    >>> adaptive_softmax.log_prob(input)  
    tensor([[-2.6533, -3.3957, -2.7069,  ..., -6.4749, -5.8867, -6.0611],
            [-3.4209, -3.2695, -2.9728,  ..., -7.6664, -7.5946, -7.9606],
            [-3.6789, -3.6317, -3.2098,  ..., -7.3722, -6.9006, -7.4314],
            ...,
            [-3.3150, -4.0957, -3.4335,  ..., -7.9572, -8.4603, -8.2080],
            [-3.8726, -3.7905, -4.3262,  ..., -8.0031, -7.8754, -8.7971],
            [-3.6082, -3.1969, -3.2719,  ..., -6.9769, -6.3158, -7.0805]],
           grad_fn=<CopySlices>)
    >>> # predict: get the class that maximizes the log probability for each input
    >>> adaptive_softmax.predict(input)  
    tensor([ 8,  6,  6, 16, 14, 16, 16,  9,  4,  7,  5,  7,  8, 14,  3])
  • Add spectral normalization nn.utils.spectral_norm #6929

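    >>> import torch.nn as nn
    >>> # Usage is similar to weight_norm
    >>> convT = nn.ConvTranspose2d(3, 64, kernel_size=3, padding=1)
    >>> # Can specify number of power iterations applied each time, or use default (1)
    >>> convT = nn.utils.spectral_norm(convT, n_power_iterations=2)
    >>>
    >>> # apply to every conv and conv transpose module in a model
    >>> def add_sn(m):
    ...     # recurse first, replacing each child with its (possibly wrapped) version
    ...     for name, c in m.named_children():
    ...         m.add_module(name, add_sn(c))
    ...     if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    ...         return nn.utils.spectral_norm(m)
    ...     else:
    ...         return m
    ...
    >>> # my_model stands for any existing nn.Module; this small one is hypothetical
    >>> my_model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.ConvTranspose2d(16, 3, 3))
    >>> my_model = add_sn(my_model)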
  • nn.ModuleDict and nn.ParameterDict containers #8463

  • Add nn.init.zeros_ and nn.init.ones_ #7488

  • Add sparse gradient option to pretrained embedding #7492

  • Add max pooling support to nn.EmbeddingBag #5725

  • Depthwise convolution support for MKLDNN #8782

  • Add nn.FeatureAlphaDropout (featurewise Alpha Dropout layer) #9073
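The smaller neural-network additions above can be sketched in one place. This assumes a recent PyTorch release; the dictionary keys, sizes and variable names are made up for illustration, and `sparse=True` on `nn.Embedding.from_pretrained` is taken to be the "sparse gradient option" mentioned above.

```python
import torch
import torch.nn as nn

# nn.ModuleDict / nn.ParameterDict register their contents as submodules/parameters
heads = nn.ModuleDict({'cls': nn.Linear(8, 3), 'reg': nn.Linear(8, 1)})
scales = nn.ParameterDict({'cls': nn.Parameter(torch.ones(1))})

# nn.init.zeros_ / nn.init.ones_ fill a tensor in place
nn.init.zeros_(heads['cls'].weight)
nn.init.ones_(heads['cls'].bias)

# Pretrained embeddings that produce sparse gradients
pretrained = torch.randn(10, 4)
emb = nn.Embedding.from_pretrained(pretrained, freeze=False, sparse=True)
emb(torch.tensor([1, 3])).sum().backward()

# EmbeddingBag with max pooling; a 2-D input means fixed-size bags (one per row)
bag = nn.EmbeddingBag(10, 4, mode='max')
pooled = bag(torch.tensor([[1, 2], [3, 4]]))

# FeatureAlphaDropout drops whole channels during training and is the identity in eval mode
fad = nn.FeatureAlphaDropout(p=0.2).eval()
feats = torch.randn(2, 3, 5, 5)
same = fad(feats)
```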

Operators

Distributions

  • Half Cauchy and Half Normal #8411
  • Gamma sampling for CUDA tensors #6855
  • Allow vectorized counts in Binomial Distribution #6720
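A short sketch of the new distributions, assuming a recent PyTorch release. The Gamma example only runs on CUDA when a GPU is present and otherwise falls back to the CPU; the parameter values are arbitrary.

```python
import torch
from torch.distributions import Binomial, Gamma, HalfCauchy, HalfNormal

torch.manual_seed(0)

# Half Cauchy / Half Normal are supported on [0, inf)
hc_samples = HalfCauchy(scale=torch.tensor([1.0])).sample((1000,))
hn_samples = HalfNormal(scale=torch.tensor([2.0])).sample((1000,))

# Gamma sampling on CUDA tensors when available, else CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gamma = Gamma(concentration=torch.full((5,), 2.0, device=device),
              rate=torch.ones(5, device=device))
g = gamma.sample()

# Binomial with a different number of trials per element (vectorized counts)
counts = torch.tensor([5.0, 10.0, 20.0])
binom = Binomial(total_count=counts, probs=torch.full((3,), 0.5))
draws = binom.sample()
```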

Misc

Performance

  • Accelerate Bernoulli random number generation on CPU #7171
  • Enable cuFFT plan caching (80% speed-up in certain cases) #8344
  • Fix unnecessary copying in bernoulli_ #8682
  • Fix unnecessary copying in broadcast #8222
  • Speed up multi-dimensional sum (2x to 6x speed-up in certain cases) #8992
  • Vectorize CPU sigmoid (>3x speed-up in most cases) #8612
  • Optimize CPU nn.LeakyReLU and nn.PReLU (2x speed-up) #9206
  • Vectorize softmax and logsoftmax (4.5x speed-up on single core and 1.8x on 10 threads) #7375
  • Speed up nn.init.sparse (10-20x speed-up) #6899

Improvements

Tensor printing

  • Tensor printing now includes requires_grad and grad_fn information #8211
  • Improve number formatting in tensor print #7632
  • Fix scale when printing some tensors #7189
  • Speed up printing of large tensors #6876
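A minimal illustration of the new autograd information in tensor printing, assuming a recent PyTorch release. The exact `grad_fn` name shown for `y` depends on the version, so only its presence is noted.

```python
import torch

x = torch.ones(2, requires_grad=True)
y = x * 2

# Leaf tensors that require grad show requires_grad=True
print(x)  # tensor([1., 1.], requires_grad=True)
# Non-leaf results show the function that created them as grad_fn=<...>
print(y)
```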

Neural Networks

  • NaN is now propagated through many activation functions #8033
  • Add non_blocking option to nn.Module.to #7312
  • Loss modules now allow target to require gradient #8460
  • Add pos_weight argument to nn.BCEWithLogitsLoss #6856
  • Support grad_clip for parameters on different devices #9302
  • Remove the requirement that input sequences to pad_sequence be sorted #7928
  • stride argument for max_unpool1d, max_unpool2d, max_unpool3d now defaults to kernel_size #7388
  • Allow grad mode context managers (e.g., torch.no_grad, torch.enable_grad) to be used as decorators #7737
  • torch.optim.lr_scheduler._LRScheduler's __getstate__ now includes optimizer info #7757
  • Allow clip_grad_* functions to accept a single Tensor as input #7769
  • Return NaN in max_pool/adaptive_max_pool for NaN inputs #7670
  • nn.EmbeddingBag can now handle empty bags in all modes #7389
  • torch.optim.lr_scheduler.ReduceLROnPlateau is now serializable #7201
  • Allow only tensors of floating point dtype to require gradients #7034 and #7185
  • Allow resetting of BatchNorm running stats and cumulative moving average #5766
  • Set the gradient of LP-Pooling to zero if the sum of all input elements to the power of p is zero #6766
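Several of these improvements can be seen in one sketch, assuming a recent PyTorch release; the tensors and the `double` helper are hypothetical. With all logits at zero, each element's loss is log 2, and pos_weight=3 triples it for the two positive targets, so the mean is 2 log 2.

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pad_sequence

# pos_weight > 1 up-weights positive examples in BCEWithLogitsLoss
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))
logits = torch.zeros(4, 1)
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
loss = criterion(logits, targets)   # (3 + 1 + 3 + 1) * log(2) / 4

# pad_sequence no longer requires sequences sorted by length
seqs = [torch.ones(2), torch.ones(5), torch.ones(3)]
padded = pad_sequence(seqs, batch_first=True)  # shape (3, 5), zero-padded

# Grad-mode context managers also work as decorators (hypothetical helper)
@torch.no_grad()
def double(t):
    return t * 2

doubled = double(torch.ones(3, requires_grad=True))

# clip_grad_norm_ accepts a single Tensor instead of an iterable of parameters
p = torch.ones(3, requires_grad=True)
(p * 10).sum().backward()           # gradient is [10, 10, 10]
clip_grad_norm_(p, max_norm=1.0)    # rescales p.grad to norm ~1

# BatchNorm running statistics can be reset
bn = nn.BatchNorm1d(4)
bn(torch.randn(8, 4))               # training mode updates the running stats
bn.reset_running_stats()            # back to mean 0, variance 1
```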

Operators

Distributions

  • Always enable grad when calculating lazy_property #7708

Sparse Tensor

  • Add log1p for sparse tensor #8969
  • Better support for adding zero-filled sparse tensors #7479
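A small sketch of the sparse-tensor changes, assuming a recent PyTorch release. Because log1p(0) = 0, applying it to a sparse tensor keeps the result sparse. The second part adds a sparse tensor with no stored entries; the indices and values are arbitrary.

```python
import torch

indices = torch.tensor([[0, 1, 2], [2, 0, 1]])
values = torch.tensor([1.0, 2.0, 3.0])
s = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()

# log1p on a sparse tensor only touches the stored (non-zero) values
s_log = torch.log1p(s)

# A zero-filled sparse tensor: correct shape, no stored entries
empty = torch.sparse_coo_tensor(torch.empty((2, 0), dtype=torch.long),
                                torch.empty(0), (3, 3))
total = s + empty
```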

Data Parallel

  • Allow modules that return scalars in nn.DataParallel #7973
  • Allow nn.parallel.parallel_apply to take in a list/tuple of tensors #8047

Misc

  • torch.Size can now accept PyTorch scalars #5676
  • Move torch.utils.data.dataset.random_split to torch.utils.data.random_split, and torch.utils.data.dataset.Subset to torch.utils.data.Subset #7816
  • Add serialization for torch.device #7713
  • Allow copy.deepcopy of torch.(int/float/...)* dtype objects #7699
  • torch.load can now take a torch.device as map location #7339
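The miscellaneous items above, sketched against a recent PyTorch release. The toy dataset and the in-memory buffer used for torch.save/torch.load are illustrative choices, not part of the release notes.

```python
import copy
import io
import pickle

import torch
from torch.utils.data import Subset, TensorDataset, random_split

# random_split and Subset are importable directly from torch.utils.data
data = TensorDataset(torch.arange(10).float())
train, val = random_split(data, [8, 2])
first_three = Subset(data, [0, 1, 2])

# torch.device objects can be pickled
dev = pickle.loads(pickle.dumps(torch.device('cpu')))

# dtype objects survive copy.deepcopy
dt = copy.deepcopy(torch.float32)

# torch.load accepts a torch.device as map_location
buf = io.BytesIO()
torch.save(torch.arange(3), buf)
buf.seek(0)
loaded = torch.load(buf, map_location=torch.device('cpu'))
```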

Bug Fixes

  • Fix nn.BCELoss sometimes returning negative results #8147
  • Fix tensor._indices on scalar sparse tensor giving wrong result #8197
  • Fix backward of tensor.as_strided not working properly when input has overlapping memory #8721
  • Fix x.pow(0) gradient when x contains 0 #8945
  • Fix CUDA torch.svd and torch.eig returning wrong results in certain cases #9082
  • Fix nn.MSELoss having low precision #9287
  • Fix segmentation fault when calling torch.Tensor.grad_fn #9292
  • Fix torch.topk returning wrong results when input isn't contiguous #9441
  • Fix segfault in convolution on CPU with large inputs / dilation #9274
  • Fix avg_pool2d/avg_pool3d count_include_pad defaulting to False (it should be True) #8645
  • Fix nn.EmbeddingBag's max_norm option #7959
  • Fix returning scalar input in Python autograd function #7934
  • Fix THCUNN SpatialDepthwiseConvolution assuming contiguity #7952
  • Fix bug in seeding random module in DataLoader #7886
  • Don't modify variables in-place for torch.einsum #7765
  • Make the return value of the LBFGS step consistent #7586
  • The return value of uniform.cdf() is now clamped to [0, 1] #7538
  • Fix advanced indexing with negative indices #7345
  • CUDAGenerator will not initialize on the current device anymore, which will avoid unnecessary memory allocation on GPU:0 #7392
  • Fix tensor.type(dtype) not preserving device #7474
  • Fix BatchSampler returning different results when used alone versus in a DataLoader with num_workers > 0 #7265
  • Fix broadcasting error in LogNormal, TransformedDistribution #7269
  • Fix torch.max and torch.min on CUDA in presence of NaN #7052
  • Fix torch.tensor device-type calculation when used with CUDA #6995
  • Fix a missing '=' in the nn.LPPoolNd repr #9629
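A few of the fixed behaviors can be checked directly. This is a sketch against a recent PyTorch release. The `validate_args=False` flag on Uniform is an assumption needed because later releases validate distribution arguments by default and would reject out-of-support values passed to cdf().

```python
import torch
from torch.distributions import Uniform

# The gradient of x.pow(0) is 0 everywhere, including at x == 0
x = torch.tensor([0.0, 2.0], requires_grad=True)
x.pow(0).sum().backward()

# topk on a non-contiguous (transposed) tensor matches the contiguous result
m = torch.randn(4, 6).t()            # .t() returns a non-contiguous view
vals, _ = m.topk(2, dim=1)
ref_vals, _ = m.contiguous().topk(2, dim=1)

# Uniform.cdf is clamped to [0, 1] for values outside the support
u = Uniform(0.0, 1.0, validate_args=False)
cdf = u.cdf(torch.tensor([-1.0, 0.5, 2.0]))

# Advanced indexing with negative indices counts from the end
v = torch.arange(5)
picked = v[torch.tensor([-1, -2])]
```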

Documentation