40 changes: 24 additions & 16 deletions test/test_autograd.py
@@ -738,28 +738,36 @@ def scope():

def test_no_grad(self):
x = torch.ones(5, 5, requires_grad=True)
y = Variable(torch.ones(5, 5) * 4)
with torch.no_grad():
w = x + y
y = torch.ones(5, 5) * 4

@torch.no_grad()
def adder(x, y):
return x + y

z = adder(x, y)
def viewer(x, y):
return x[1]

self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
self.assertIsNone(w.grad_fn)
self.assertFalse(z.requires_grad)
self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
self.assertIsNone(z.grad_fn)
for binary_op in (adder, viewer):
with torch.no_grad():
w = binary_op(x, y)

# test nested decorator and with-statement on no_grad
with torch.no_grad():
self.assertFalse(torch.is_grad_enabled())
w = adder(x, y)
self.assertFalse(torch.is_grad_enabled())
@torch.no_grad()
def decorated(x, y):
return binary_op(x, y)

z = decorated(x, y)

self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
self.assertIsNone(w.grad_fn)
self.assertFalse(z.requires_grad)
self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
self.assertIsNone(z.grad_fn)

# test nested decorator and with-statement on no_grad
with torch.no_grad():
self.assertFalse(torch.is_grad_enabled())
w = binary_op(x, y)
self.assertFalse(torch.is_grad_enabled())

def test_no_grad_python_function(self):
"""Python Functions should respect grad mode."""
1 change: 1 addition & 0 deletions tools/autograd/gen_autograd.py
@@ -14,6 +14,7 @@
from collections import defaultdict
from .utils import YamlLoader, split_name_params

# See NOTE [ Autograd View Variables ] in variable.h for details.
VIEW_FUNCTIONS = {
'alias', 'as_strided', 'diagonal', 'expand', 'narrow', 'permute', 'select', 'slice',
'squeeze', 't', 'transpose', 'unfold', 'unsqueeze', 'view', 'unbind',
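
The functions listed in VIEW_FUNCTIONS are the ATen ops whose outputs alias their input's storage. A quick Python sketch (not part of this PR) of the view behavior for `t`, one of the listed entries:

```python
import torch

x = torch.arange(6.).view(2, 3)
v = x.t()                             # 't' is in VIEW_FUNCTIONS, so v is a view of x
v[0, 0] = 100.                        # a write through the view ...
print(x[0, 0])                        # ... shows up in the base: tensor(100.)
print(v.data_ptr() == x.data_ptr())   # True: same underlying storage
```
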
5 changes: 4 additions & 1 deletion tools/autograd/gen_variable_type.py
@@ -450,7 +450,10 @@ def wrap_output(call):
if 'Tensor' not in declaration['return_type']:
return call
elif is_view:
return 'as_view(self, {})'.format(call)
# If `GradMode::is_enabled()` is False, this is a non-differentiable
# view. Gradients should not flow through.
# See NOTE [ Autograd View Variables ] in variable.h for details.
return 'as_view(self, {}, GradMode::is_enabled())'.format(call)
else:
return 'as_variable({})'.format(call)

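With this change the generated code passes `GradMode::is_enabled()` into `as_view`, so a view taken while grad mode is off is registered as a non-differentiable view. A minimal Python sketch of the intended behavior (illustration only, not part of the PR's tests):

```python
import torch

x = torch.ones(5, 5, requires_grad=True)
with torch.no_grad():
    v = x[1]             # generated code calls as_view(self, ..., GradMode::is_enabled())

# v still aliases x's storage, but no view relation was recorded in autograd,
# so no gradient can flow from v back to x.
print(v.requires_grad)   # False
print(v.grad_fn)         # None
```
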
11 changes: 7 additions & 4 deletions torch/csrc/autograd/VariableTypeUtils.h
@@ -103,21 +103,24 @@ template<typename... Args> inline variable_list flatten_tensor_args(Args&&... ar
return out; // RVO
}

inline Tensor as_view(const Tensor & base, Tensor tensor) {
// See NOTE [ Autograd View Variables ] for details.
inline Tensor as_view(const Tensor & base, Tensor tensor, bool is_differentiable = true) {
auto base_var = Variable(base);
if (base_var.is_view()) {
base_var = base_var.base();
}
return make_variable_view(std::move(base_var), std::move(tensor));
return make_variable_view(std::move(base_var), std::move(tensor), is_differentiable);
}

inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor> tensors) {
// See NOTE [ Autograd View Variables ] for details.
inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor> tensors,
bool is_differentiable = true) {
auto base_var = Variable(base);
if (base_var.is_view()) {
base_var = base_var.base();
}
for(Tensor &tensor : tensors) {
tensor = make_variable_view(base_var, std::move(tensor));
tensor = make_variable_view(base_var, std::move(tensor), is_differentiable);
}
return tensors;
}
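
Note that the `is_differentiable=false` path still ends up sharing the version counter between base and view (see `make_variable_view` in variable.h), so in-place writes through a non-differentiable view are still caught by the saved-variable checks. A hedged Python sketch of that interaction (the choice of `exp` and the exact error wording are assumptions):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.exp()              # exp saves its output for the backward pass
with torch.no_grad():
    v = y[:]             # non-differentiable view of y; shares y's version counter

v.zero_()                # bumps the shared version counter
y.sum().backward()       # RuntimeError: a tensor needed for backward was modified in-place
```
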
4 changes: 4 additions & 0 deletions torch/csrc/autograd/functions/tensor.h
@@ -22,6 +22,10 @@ struct CopyBackwards : public Function {
// Performs grad[idx] = fn(grad[idx]), but out-of-place. The slicing operation
// grad[idx] is defined by the relative sizes, strides, and offset of base and
// view.
// When an in-place operation is done on a differentiable view, the base's
// grad_fn is updated to become a `CopySlice` wrapping the backward of the
// in-place operation.
// See NOTE [ Autograd View Variables ].
struct CopySlices : public Function {
CopySlices(
const Variable& base_var,
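
This is the machinery behind example (1) of NOTE [ Autograd View Variables ] in variable.h below: an in-place copy into a slice of `base` swaps `base`'s grad_fn for a `CopySlices` node that wraps the copy's backward, so gradients reach the copied-in tensor. A runnable sketch (the printed grad_fn name is an assumption):

```python
import torch

base = torch.zeros(5, 5)                      # requires_grad=False
var = torch.ones(5, requires_grad=True)

base[1] = var                                 # in-place copy_ into a view of base
print(base.grad_fn)                           # a CopySlices node wrapping the copy's backward

grad, = torch.autograd.grad(base.sum(), var)
print(grad)                                   # tensor([1., 1., 1., 1., 1.])
```
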
16 changes: 5 additions & 11 deletions torch/csrc/autograd/variable.cpp
@@ -115,12 +115,6 @@ std::shared_ptr<Function> Variable::Impl::get_grad_accumulator() {
return result;
}

Variable Variable::Impl::detach() const {
auto detached = make_variable(data_, /*requires_grad=*/false);
detached.set_version_counter(version_counter_);
return detached;
}

void Variable::Impl::detach_() {
if (is_view_) {
AT_ERROR("Can't detach views in-place. Use detach() instead");
@@ -172,7 +166,7 @@ void Variable::Impl::release_resources() {
hooks_.clear();
}

Variable::ViewImpl::ViewImpl(Variable base, at::Tensor data, Edge gradient_edge)
Variable::DifferentiableViewImpl::DifferentiableViewImpl(Variable base, at::Tensor data, Edge gradient_edge)

: Variable::Impl(std::move(data), false, std::move(gradient_edge)),
base_(std::move(base)) {
AT_CHECK(base_.defined(), "base is undefined");
@@ -184,7 +178,7 @@ Variable::ViewImpl::ViewImpl(Variable base, at::Tensor data, Edge gradient_edge)
attr_version = version_counter_.current_version();
}

std::shared_ptr<Function>& Variable::ViewImpl::get_grad_fn() {
std::shared_ptr<Function>& Variable::DifferentiableViewImpl::get_grad_fn() {
std::lock_guard<std::mutex> lock(mutex_);
if (!grad_fn_ && !base_.requires_grad()) {
return grad_fn_;
@@ -208,7 +202,7 @@ std::shared_ptr<Function>& Variable::ViewImpl::get_grad_fn() {
return grad_fn_;
}

void Variable::ViewImpl::rebase_history(Edge gradient_edge) {
void Variable::DifferentiableViewImpl::rebase_history(Edge gradient_edge) {
AT_ASSERT(gradient_edge.input_nr == 0);
AT_ASSERT(gradient_edge.function);
AT_CHECK(
@@ -221,15 +215,15 @@ void Variable::ViewImpl::rebase_history(Edge gradient_edge) {
get_grad_fn(); // trigger an update to the view's grad_fn
}

void Variable::ViewImpl::release_resources() {
void Variable::DifferentiableViewImpl::release_resources() {
Variable::Impl::release_resources();
base_.reset();
}

void Variable::rebase_history(Edge gradient_edge) {
AT_ASSERT(gradient_edge.function != nullptr);
if (is_view()) {
auto& impl = static_cast<Variable::ViewImpl&>(*get());
auto& impl = static_cast<Variable::DifferentiableViewImpl&>(*get());
impl.rebase_history(std::move(gradient_edge));
} else {
set_gradient_edge(std::move(gradient_edge));
105 changes: 91 additions & 14 deletions torch/csrc/autograd/variable.h
@@ -61,8 +61,10 @@ struct Function;
/// `Variable`. You can determine whether `Variable` is in fact a view by
/// probing its `is_view()` method. Note that the *view* semantics are only
/// meaningful for `Variable` relations that are relevant to autograd. For
/// example, if you hide your code from autograd using `.data`, the `Variable`s
/// will not be registered as having view relations, even if they share storage.
/// example, if you hide your code from autograd using `.no_grad()`, the
/// `Variable`s will not be registered as having view relations, even if they
/// share storage.
/// See NOTE [ Autograd View Variables ] for more details.
///
///
/// Interface
@@ -92,9 +94,13 @@ struct TORCH_API Variable : public at::Tensor {

/// Creates a `Variable` that is a *view* of another (*base*) variable.
/// The `gradient_edge` is an optional (gradient_function, input_number) pair.
/// `is_differentiable` is a bool that specifies whether this view is
/// differentiable, i.e., whether the relation should be tracked by autograd.
/// See NOTE [ Autograd View Variables ] for details.
friend Variable make_variable_view(
Variable base,
at::Tensor data,
bool is_differentiable,
Edge gradient_edge);

/// Creates a `Variable` from the given `Tensor`. `requires_grad` should be
@@ -260,7 +266,7 @@ struct TORCH_API Variable : public at::Tensor {
/// and the `get()` method which exposes it shall forever remain private and
/// never be exposed to the public interface of this class.
struct Impl;
struct ViewImpl;
struct DifferentiableViewImpl;

// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -327,7 +333,6 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
return grad_;
}

Variable detach() const;
void detach_();

void set_data(Tensor new_data);
@@ -369,15 +374,76 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
};

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable::ViewImpl
// Variable::DifferentiableViewImpl
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A Variable that is a view on another Variable. The base and view share the
/// same version_counter. The grad_fn field of the Variable may become stale
/// due to in-place modifications of the shared data. Accesses should go
/// through get_grad_fn(). All other fields are always valid.
struct TORCH_API Variable::ViewImpl : public Variable::Impl {
ViewImpl(Variable base, at::Tensor data, Edge gradient_edge);
/// NOTE [ Autograd View Variables ]
///
/// Many operations return Variable that shares storage with an input Variable.
/// The returned Variable is called a **view** Variable on the input **base**
/// Variable.
///
/// In PyTorch, we have two types of views: differentiable views, and
/// non-differentiable views. In either type, to support proper version
/// checking, the base and view Variables must always share the same
/// version_counter.
///
///
/// Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Differentiable views are the view variables where you want gradients to flow
/// back to the base variables. Out-of-place operations on views are quite
/// straightforward, but in-place ones are very tricky. Even if the base
/// variable may not require grad when we create the view, we still need to
/// track the view relation because future in-place ops may require back-proping
/// through it. For example, we need to support
///
/// (1) in-place operation on view, e.g.,
///
/// # Have:
/// # base.requires_grad = False
/// # var.requires_grad = True
/// base[1] = var # i.e., base[1].copy_(var)
/// torch.autograd.grad(base.sum(), var) <- should return an all ones tensor
///
/// (2) in-place operation on base after view is created, e.g.,
///
/// # Have:
/// # base.requires_grad = False
/// # var.requires_grad = True
/// view = base[1]
/// base.copy_(var)
/// torch.autograd.grad(view.sum(), var) <- should return a tensor with
/// var[1] filled with all ones and
/// zeros everywhere else
///
/// Variable::DifferentiableViewImpl is created to support gradient tracking of
/// such **in-place** operations. In particular,
/// + if an in-place op is done on base, the grad_fn field of the view may
/// become stale. So accesses should always go through get_grad_fn(), which
/// reconstructs an updated grad_fn if the version_counter has incremented.
/// All other fields are always valid.
/// + if an in-place op is done on view, in rebase_history() of view, which is
/// called after every in-place op in VariableType.cpp, the grad_fn of base
/// is updated.
///
///
/// Non-Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// In certain cases, although function outputs share storage with inputs, they
/// will **never** require gradient history tracking. Instead of registering the
/// view relation via DifferentiableViewImpl in autograd, the views will be
/// using usual Variable::Impl and just share the version counters with the base
/// Variables.
/// Some examples are:
/// 1. Views created from .detach(),
/// 2. Views created when GradMode::is_enabled() = false.
/// These are called non-differentiable views as the gradients do not flow
/// through the view relation.
/// Relevant logic for non-differentiable views is implemented in
/// make_variable_view below, and wrap_output of gen_variable_type.py.
struct TORCH_API Variable::DifferentiableViewImpl : public Variable::Impl {
DifferentiableViewImpl(Variable base, at::Tensor data, Edge gradient_edge);

/// Gets the up-to-date grad_fn. If the shared data or base was modified, we
/// re-create the grad_fn to express the up-to-date view relationship between
@@ -411,13 +477,24 @@ struct TORCH_API Variable::ViewImpl : public Variable::Impl {
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// See NOTE [ Autograd View Variables ] for details.
inline Variable make_variable_view(
Variable base,
at::Tensor data,
bool is_differentiable = true,
Edge gradient_edge = Edge()) {
if (data.defined()) {
return Variable(c10::make_intrusive<Variable::ViewImpl>(
std::move(base), std::move(data), std::move(gradient_edge)));
if (is_differentiable) {
/// Differentiable view. Track history with DifferentiableViewImpl.
return Variable(c10::make_intrusive<Variable::DifferentiableViewImpl>(
std::move(base), std::move(data), std::move(gradient_edge)));
} else {
/// Non-differentiable view. Just share version counter.
auto var = Variable(c10::make_intrusive<Variable::Impl>(
std::move(data), false, std::move(gradient_edge)));
var.set_version_counter(base.version_counter());
return var;
}
}
return Variable();
}
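
A runnable version of example (2) from the NOTE above: after `base` is modified in-place, the view's grad_fn is rebuilt lazily through `get_grad_fn()`, so the gradient only covers the viewed slice. Sketch under the same setup as the NOTE:

```python
import torch

base = torch.zeros(5, 5)                      # requires_grad=False
var = torch.ones(5, 5, requires_grad=True)

view = base[1]                                # differentiable view, created before the in-place op
base.copy_(var)                               # in-place op on base after the view exists

grad, = torch.autograd.grad(view.sum(), var)
print(grad[1])                                # row 1 is all ones: gradient flows through the view
print(grad[0])                                # every other row is all zeros
```
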
@@ -497,7 +574,7 @@ inline std::shared_ptr<Function> Variable::grad_accumulator() const {
}

inline Variable Variable::detach() const {
return get()->detach();
return make_variable_view(*this, get()->data_, /*is_differentiable=*/false);
}

inline void Variable::detach_() {
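
`detach()` now goes through `make_variable_view` with `is_differentiable=false` instead of the removed `Variable::Impl::detach()`: the result is cut off from the graph but still aliases the original tensor and its version counter. A small sketch of the user-visible behavior (illustration only):

```python
import torch

x = torch.randn(4, requires_grad=True)
y = x * 2
d = y.detach()            # non-differentiable view of y

print(d.requires_grad)    # False: detached from the graph
d.fill_(0.0)              # storage is shared, so y's values are zeroed as well
print(y)                  # all zeros; y keeps its original grad_fn
```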