Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 4ecb33a

Browse files
committed
Specialize Optional (Tensor) to None when executing graph
In pytorch#18360, we used undefined Tensor (aka AutogradZeroTensor), but this can be errorprone when the type or value is compared to None, e.g. as seen when comined with the (not yet landed) For this to work, we must allow None passed to functions taking Tensor?.
1 parent 7cc7ed1 commit 4ecb33a

5 files changed

Lines changed: 30 additions & 13 deletions

File tree

test/test_jit.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4995,11 +4995,13 @@ def fn(x):
49954995
else:
49964996
return 0
49974997

4998-
fn(None)
4998+
res = fn(None)
4999+
self.assertEqual(res, 1)
49995000
g = fn.graph_for(None)
5000-
self.assertEqual(list(g.inputs())[0].type().str(), 'UndefinedTensor')
5001+
self.assertEqual(list(g.inputs())[0].type().kind(), 'NoneType')
50015002
t = torch.ones(1)
5002-
fn(t)
5003+
res = fn(t)
5004+
self.assertEqual(res, 0)
50035005
g = fn.graph_for(t)
50045006
self.assertEqual(list(g.inputs())[0].type().kind(), 'DimensionedTensorType')
50055007

torch/csrc/jit/argument_spec.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,13 @@ struct ArgumentSpec {
156156
if (original->isSubtypeOf(TensorType::get())
157157
|| original->isSubtypeOf(OptionalType::ofTensor())) {
158158
auto& arg = args.at(offset++);
159-
if (!arg.defined())
160-
return AutogradZeroTensorType::get();
159+
if (!arg.defined()) {
160+
if (original->isSubtypeOf(OptionalType::ofTensor())) {
161+
return NoneType::get();
162+
} else {
163+
return AutogradZeroTensorType::get();
164+
}
165+
}
161166
return DimensionedTensorType::create(
162167
arg.type(),
163168
ConvertIntToCPUOrCUDA(arg.device()),

torch/csrc/jit/operator.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,10 +457,20 @@ bool Operator::matches(const Node* node) const {
457457
const MatchTypeReturn matched_type =
458458
matchTypeVariables(formals[i].type(), actuals[i]->type(), type_env);
459459
if (!matched_type.type) {
460+
if (actuals[i]->type() == NoneType::get() &&
461+
formals[i].type()->kind() == TypeKind::OptionalType) {
462+
// when looking for a match, None is actually OK here
463+
continue;
464+
}
460465
return false;
461466
}
462467
TypePtr formal = *matched_type.type;
463468
if (!actuals[i]->type()->isSubtypeOf(formal)) {
469+
if (actuals[i]->type() == NoneType::get() &&
470+
formals[i].type()->kind() == TypeKind::OptionalType) {
471+
// when looking for a match, None is actually OK here
472+
continue;
473+
}
464474
return false;
465475
}
466476
}

torch/csrc/jit/passes/shape_analysis.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -514,14 +514,13 @@ class ShapePropagator {
514514
return;
515515
}
516516
case prim::unchecked_unwrap_optional: {
517-
// we know we cannot have None as input, so we can always pass
518-
// on the type.
519-
if(auto ot = node->input()->type()->cast<OptionalType>()) {
517+
// If we have None as input, we need to leave the output type alone
518+
if(auto ot = node->input()->type()->cast<OptionalType>()) {
520519
node->output()->setType(ot->getElementType());
521-
} else {
522-
node->output()->setType(node->input()->type());
523-
}
524-
return;
520+
} else if (!node->input()->type()->isSubtypeOf(NoneType::get())) {
521+
node->output()->setType(node->input()->type());
522+
}
523+
return;
525524
}
526525
case prim::ConstantChunk: {
527526
Value* tensor = node->input();

torch/csrc/jit/passes/specialize_autogradzero.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ void specializeAutogradZero(Graph& g) {
1616

1717
for (Value* input : g.inputs()) {
1818
const auto& tp = input->type();
19-
if (tp->isSubtypeOf(AutogradZeroTensorType::get())) {
19+
if (tp->isSubtypeOf(AutogradZeroTensorType::get()) ||
20+
tp->isSubtypeOf(NoneType::get())) {
2021
state[input] = State::Zero;
2122
} else if (tp->isSubtypeOf(TensorType::get())) {
2223
state[input] = State::Nonzero;

0 commit comments

Comments
 (0)