File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -4995,11 +4995,13 @@ def fn(x):
49954995 else :
49964996 return 0
49974997
4998- fn (None )
4998+ res = fn (None )
4999+ self .assertEqual (res , 1 )
49995000 g = fn .graph_for (None )
5000- self .assertEqual (list (g .inputs ())[0 ].type ().str (), 'UndefinedTensor ' )
5001+ self .assertEqual (list (g .inputs ())[0 ].type ().kind (), 'NoneType ' )
50015002 t = torch .ones (1 )
5002- fn (t )
5003+ res = fn (t )
5004+ self .assertEqual (res , 0 )
50035005 g = fn .graph_for (t )
50045006 self .assertEqual (list (g .inputs ())[0 ].type ().kind (), 'DimensionedTensorType' )
50055007
Original file line number Diff line number Diff line change @@ -156,8 +156,13 @@ struct ArgumentSpec {
156156 if (original->isSubtypeOf (TensorType::get ())
157157 || original->isSubtypeOf (OptionalType::ofTensor ())) {
158158 auto & arg = args.at (offset++);
159- if (!arg.defined ())
160- return AutogradZeroTensorType::get ();
159+ if (!arg.defined ()) {
160+ if (original->isSubtypeOf (OptionalType::ofTensor ())) {
161+ return NoneType::get ();
162+ } else {
163+ return AutogradZeroTensorType::get ();
164+ }
165+ }
161166 return DimensionedTensorType::create (
162167 arg.type (),
163168 ConvertIntToCPUOrCUDA (arg.device ()),
Original file line number Diff line number Diff line change @@ -457,10 +457,20 @@ bool Operator::matches(const Node* node) const {
457457 const MatchTypeReturn matched_type =
458458 matchTypeVariables (formals[i].type (), actuals[i]->type (), type_env);
459459 if (!matched_type.type ) {
460+ if (actuals[i]->type () == NoneType::get () &&
461+ formals[i].type ()->kind () == TypeKind::OptionalType) {
462+ // when looking for a match, None is actually OK here
463+ continue ;
464+ }
460465 return false ;
461466 }
462467 TypePtr formal = *matched_type.type ;
463468 if (!actuals[i]->type ()->isSubtypeOf (formal)) {
469+ if (actuals[i]->type () == NoneType::get () &&
470+ formals[i].type ()->kind () == TypeKind::OptionalType) {
471+ // when looking for a match, None is actually OK here
472+ continue ;
473+ }
464474 return false ;
465475 }
466476 }
Original file line number Diff line number Diff line change @@ -514,14 +514,13 @@ class ShapePropagator {
514514 return ;
515515 }
516516 case prim::unchecked_unwrap_optional: {
517- // we know we cannot have None as input, so we can always pass
518- // on the type.
519- if (auto ot = node->input ()->type ()->cast <OptionalType>()) {
517+ // If we have None as input, we need to leave the output type alone
518+ if (auto ot = node->input ()->type ()->cast <OptionalType>()) {
520519 node->output ()->setType (ot->getElementType ());
521- } else {
522- node->output ()->setType (node->input ()->type ());
523- }
524- return ;
520+ } else if (!node-> input ()-> type ()-> isSubtypeOf ( NoneType::get ())) {
521+ node->output ()->setType (node->input ()->type ());
522+ }
523+ return ;
525524 }
526525 case prim::ConstantChunk: {
527526 Value* tensor = node->input ();
Original file line number Diff line number Diff line change @@ -16,7 +16,8 @@ void specializeAutogradZero(Graph& g) {
1616
1717 for (Value* input : g.inputs ()) {
1818 const auto & tp = input->type ();
19- if (tp->isSubtypeOf (AutogradZeroTensorType::get ())) {
19+ if (tp->isSubtypeOf (AutogradZeroTensorType::get ()) ||
20+ tp->isSubtypeOf (NoneType::get ())) {
2021 state[input] = State::Zero;
2122 } else if (tp->isSubtypeOf (TensorType::get ())) {
2223 state[input] = State::Nonzero;
You can’t perform that action at this time.
0 commit comments