Save self.numel() for backward computation instead of self #5747

Merged: soumith merged 1 commit into pytorch:master from zou3519:save-numel on Mar 13, 2018

Contributor

@zou3519 zou3519 commented Mar 13, 2018

Fixes #5741

The only operation that really benefits from this right now is tensor.mean().

cc @gchanan @colesbury

Test Plan

python test/test_autograd.py
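
For readers skimming the thread, here is a rough sketch of what the title change amounts to: a backward node for mean() can keep just the input's shape and element count rather than the input itself. `ToyTensor`, `SavesSelf`, `SavesNumel`, and `saved_bytes()` are hypothetical names made up for this illustration and are not PyTorch types; the presumed benefit (the input buffer no longer has to stay alive for backward) is an assumption drawn from the PR title, not something stated in the thread.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a tensor: a flat buffer plus its shape.
struct ToyTensor {
  std::vector<float> data;
  std::vector<int64_t> sizes;
  int64_t numel() const { return static_cast<int64_t>(data.size()); }
};

// Backward state that copies (and so keeps alive) the whole input.
struct SavesSelf {
  ToyTensor self;
  std::size_t saved_bytes() const {
    return self.data.size() * sizeof(float) +
           self.sizes.size() * sizeof(int64_t);
  }
};

// Backward state that keeps only the shape and the element count,
// mirroring the self_sizes / self_numel fields in the generated struct.
struct SavesNumel {
  std::vector<int64_t> self_sizes;
  int64_t self_numel;
  std::size_t saved_bytes() const {
    return self_sizes.size() * sizeof(int64_t) + sizeof(int64_t);
  }
};

SavesSelf save_self(const ToyTensor& t) { return SavesSelf{t}; }

SavesNumel save_numel(const ToyTensor& t) {
  return SavesNumel{t.sizes, t.numel()};
}
```

In this toy model the second form stores a handful of integers regardless of how large the input is, which is why an operation whose gradient depends only on the element count is a natural fit.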

Contributor

gchanan commented Mar 13, 2018

looks good, mind showing the generated code?

Contributor Author

zou3519 commented Mar 13, 2018

MeanBackward1 struct:

struct MeanBackward1 : public TraceableFunction {
  using TraceableFunction::TraceableFunction;
  variable_list apply(const variable_list& grads) override;
  std::string name() override { return "MeanBackward1"; }
  void release_variables() override {

  }

  std::vector<int64_t> self_sizes;
  int64_t self_numel;

};

forward:

Tensor VariableType::mean(const Tensor & self) const {
  profiler::RecordFunction profiler("mean");
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<MeanBackward1> grad_fn;
  if (compute_requires_grad( self )) {
    grad_fn = std::make_shared<MeanBackward1>();
    grad_fn->set_next_edges(collect_next_edges( self ));
    grad_fn->self_sizes = self.sizes();
    grad_fn->self_numel = self.numel();
  }
  jit::tracer::PreTraceInfo trace_info;
  if (jit::tracer::isTracing( self )) {
    trace_info = jit::tracer::preRecordTrace( "mean", { self } );

  }
  auto result = as_variable(baseType->mean(self_));
  set_history(result, grad_fn);
  if (trace_info.state != nullptr) {
    jit::tracer::postRecordTrace( trace_info,  { result } );
  }
  return result;
}

backward:

variable_list MeanBackward1::apply(const variable_list& grads) {
  IndexRangeGenerator gen;
  auto self_ix = gen.range(1);
  variable_list grad_inputs(gen.size());
  auto& grad = grads[0];
  if (should_compute_output({ self_ix })) {
    auto grad_result = grad.expand(self_sizes) / self_numel;
    copy_range(grad_inputs, self_ix, grad_result);
  }
  return grad_inputs;
}
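
To make the arithmetic in the backward above concrete, the following is a minimal standalone sketch of the same computation on plain vectors: the scalar incoming gradient is broadcast to the input's shape and divided by the element count. `mean_backward` is a hypothetical helper written for this illustration; it is not the generated code and does not use ATen or autograd types.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical, simplified analogue of grad.expand(self_sizes) / self_numel.
// The result is returned as a flat buffer with self_numel entries.
std::vector<float> mean_backward(float grad,
                                 const std::vector<int64_t>& self_sizes,
                                 int64_t self_numel) {
  // Sanity check: the saved element count must match the saved shape.
  const int64_t from_sizes =
      std::accumulate(self_sizes.begin(), self_sizes.end(), int64_t{1},
                      std::multiplies<int64_t>());
  assert(from_sizes == self_numel);
  assert(self_numel > 0);

  // Every input element contributes 1/numel to the mean, so each one
  // receives the same share of the incoming gradient.
  return std::vector<float>(static_cast<std::size_t>(self_numel),
                            grad / static_cast<float>(self_numel));
}
```

For example, calling `mean_backward(2.0f, {2, 2}, 4)` yields four entries equal to 0.5. Note that nothing here reads the input's values, only its shape and size.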

@soumith soumith merged commit 11444a7 into pytorch:master Mar 13, 2018
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026