Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e40bbec

Browse files
Suharsh Sivakumartensorflower-gardener
authored andcommitted
When an invalid op is registered, ensure that tf.load_op_library returns
a nice error rather than silently generating invalid python code. Fixes: - On linux, a fatal log that should be triggered in C++ seems to be ignored causing the tf.load_op_library call to fail and exit. We fix this by propagating the status from op registration to raise an exception from python rather than print a fatal log. The reason why the fatal log is being ignored on linux is unclear, my theory is some sort of SWIG issue, but online research provided no results. - Op registrations that fail are still added to the OpDef registration. This is only an issue when the Fatal log is ignored, but we fix this anyways. Change: 127114085
1 parent b2bf45a commit e40bbec

7 files changed

Lines changed: 119 additions & 17 deletions

File tree

tensorflow/core/framework/load_library.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ Status LoadLibrary(const char* library_filename, void** result,
5151
std::unordered_set<string> seen_op_names;
5252
{
5353
mutex_lock lock(mu);
54-
OpRegistry::Global()->ProcessRegistrations();
54+
Status s = OpRegistry::Global()->ProcessRegistrations();
55+
if (!s.ok()) {
56+
return s;
57+
}
5558
TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(
5659
[&op_list, &seen_op_names](const Status& s,
5760
const OpDef& opdef) -> Status {
@@ -62,18 +65,22 @@ Status LoadLibrary(const char* library_filename, void** result,
6265
return Status::OK();
6366
}
6467
}
65-
*op_list.add_op() = opdef;
66-
seen_op_names.insert(opdef.name());
68+
if (s.ok()) {
69+
*op_list.add_op() = opdef;
70+
seen_op_names.insert(opdef.name());
71+
}
6772
return s;
6873
}));
6974
OpRegistry::Global()->DeferRegistrations();
70-
Status s = env->LoadLibrary(library_filename, &lib);
75+
s = env->LoadLibrary(library_filename, &lib);
76+
if (s.ok()) {
77+
s = OpRegistry::Global()->ProcessRegistrations();
78+
}
7179
if (!s.ok()) {
7280
OpRegistry::Global()->ClearDeferredRegistrations();
7381
TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
7482
return s;
7583
}
76-
OpRegistry::Global()->ProcessRegistrations();
7784
TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr));
7885
}
7986
string str;

tensorflow/core/framework/op.cc

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Status OpRegistry::LookUp(const string& op_type_name,
6464
bool first_call = false;
6565
{ // Scope for lock.
6666
mutex_lock lock(mu_);
67-
first_call = CallDeferred();
67+
first_call = MustCallDeferred();
6868
res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
6969
// Note: Can't hold mu_ while calling Export() below.
7070
}
@@ -93,7 +93,7 @@ Status OpRegistry::LookUp(const string& op_type_name,
9393

9494
void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
9595
mutex_lock lock(mu_);
96-
CallDeferred();
96+
MustCallDeferred();
9797
for (const auto& p : registry_) {
9898
op_defs->push_back(p.second->op_def);
9999
}
@@ -111,7 +111,7 @@ Status OpRegistry::SetWatcher(const Watcher& watcher) {
111111

112112
void OpRegistry::Export(bool include_internal, OpList* ops) const {
113113
mutex_lock lock(mu_);
114-
CallDeferred();
114+
MustCallDeferred();
115115

116116
std::vector<std::pair<string, const OpRegistrationData*>> sorted(
117117
registry_.begin(), registry_.end());
@@ -138,9 +138,9 @@ void OpRegistry::ClearDeferredRegistrations() {
138138
deferred_.clear();
139139
}
140140

141-
void OpRegistry::ProcessRegistrations() const {
141+
Status OpRegistry::ProcessRegistrations() const {
142142
mutex_lock lock(mu_);
143-
CallDeferred();
143+
return CallDeferred();
144144
}
145145

146146
string OpRegistry::DebugString(bool include_internal) const {
@@ -153,7 +153,7 @@ string OpRegistry::DebugString(bool include_internal) const {
153153
return ret;
154154
}
155155

156-
bool OpRegistry::CallDeferred() const {
156+
bool OpRegistry::MustCallDeferred() const {
157157
if (initialized_) return false;
158158
initialized_ = true;
159159
for (int i = 0; i < deferred_.size(); ++i) {
@@ -163,6 +163,19 @@ bool OpRegistry::CallDeferred() const {
163163
return true;
164164
}
165165

166+
Status OpRegistry::CallDeferred() const {
167+
if (initialized_) return Status::OK();
168+
initialized_ = true;
169+
for (int i = 0; i < deferred_.size(); ++i) {
170+
Status s = RegisterAlreadyLocked(deferred_[i]);
171+
if (!s.ok()) {
172+
return s;
173+
}
174+
}
175+
deferred_.clear();
176+
return Status::OK();
177+
}
178+
166179
Status OpRegistry::RegisterAlreadyLocked(
167180
OpRegistrationDataFactory op_data_factory) const {
168181
std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);

tensorflow/core/framework/op.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ class OpRegistry : public OpRegistryInterface {
111111

112112
// Process the current list of deferred registrations. Note that calls to
113113
// Export, LookUp and DebugString would also implicitly process the deferred
114-
// registrations.
115-
void ProcessRegistrations() const;
114+
// registrations. Returns the status of the first failed op registration or
115+
// Status::OK() otherwise.
116+
Status ProcessRegistrations() const;
116117

117118
// Defer the registrations until a later call to a function that processes
118119
// deferred registrations are made. Normally, registrations that happen after
@@ -126,8 +127,13 @@ class OpRegistry : public OpRegistryInterface {
126127
private:
127128
// Ensures that all the functions in deferred_ get called, their OpDef's
128129
// registered, and returns with deferred_ empty. Returns true the first
129-
// time it is called.
130-
bool CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
130+
// time it is called. Prints a fatal log if any op registration fails.
131+
bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
132+
133+
// Calls the functions in deferred_ and registers their OpDef's
134+
// It returns the Status of the first failed op registration or Status::OK()
135+
// otherwise.
136+
Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
131137

132138
// Add 'def' to the registry with additional data 'data'. On failure, or if
133139
// there is already an OpDef with that name registered, returns a non-okay

tensorflow/core/framework/op_registration_test.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,16 @@ TEST(OpRegistrationTest, TestBasic) {
4444
TEST(OpRegistrationTest, TestDuplicate) {
4545
std::unique_ptr<OpRegistry> registry(new OpRegistry);
4646
Register("Foo", registry.get());
47-
registry->ProcessRegistrations();
47+
Status s = registry->ProcessRegistrations();
48+
EXPECT_TRUE(s.ok());
4849

4950
registry->SetWatcher([](const Status& s, const OpDef& op_def) -> Status {
5051
EXPECT_TRUE(errors::IsAlreadyExists(s));
5152
return Status::OK();
5253
});
5354
Register("Foo", registry.get());
54-
registry->ProcessRegistrations();
55+
s = registry->ProcessRegistrations();
56+
EXPECT_TRUE(s.ok());
5557
}
5658

5759
} // namespace tensorflow

tensorflow/user_ops/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ py_tests(
4444
data = [":duplicate_op.so"],
4545
)
4646

47+
tf_custom_op_library(
48+
name = "invalid_op.so",
49+
srcs = ["invalid_op.cc"],
50+
)
51+
52+
py_tests(
53+
name = "invalid_op_test",
54+
size = "small",
55+
srcs = ["invalid_op_test.py"],
56+
data = [":invalid_op.so"],
57+
)
58+
4759
filegroup(
4860
name = "all_files",
4961
srcs = glob(

tensorflow/user_ops/invalid_op.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/op.h"
17+
#include "tensorflow/core/framework/op_kernel.h"
18+
19+
namespace tensorflow {
20+
21+
REGISTER_OP("Invalid")
22+
.Attr("invalid attr: int32") // invalid since the name has a space.
23+
.Doc(R"doc(
24+
An op to test that invalid ops do not successfully generate invalid python code.
25+
)doc");
26+
27+
} // namespace tensorflow
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for custom user ops."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import os.path
21+
22+
import tensorflow as tf
23+
24+
25+
class InvalidOpTest(tf.test.TestCase):
26+
27+
def testBasic(self):
28+
library_filename = os.path.join(tf.resource_loader.get_data_files_path(),
29+
'invalid_op.so')
30+
with self.assertRaises(tf.errors.InvalidArgumentError):
31+
tf.load_op_library(library_filename)
32+
33+
34+
if __name__ == '__main__':
35+
tf.test.main()

0 commit comments

Comments
 (0)