-
Notifications
You must be signed in to change notification settings - Fork 1.4k
modified tensorflow parser from graph branch #522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
wangyida
wants to merge
1
commit into
tiny-dnn:master
Choose a base branch
from
wangyida:tensorflow_parser
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
protoc -I=.. --cpp_out=.. ../tensorflow/core/framework/*.proto |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#include <iostream> | ||
#include <fstream> | ||
#include <sstream> | ||
#include <string> | ||
#include <map> | ||
#include "tiny_dnn/io/tensorflow/proto_parser.h" | ||
using namespace std; | ||
using namespace tiny_cnn; | ||
using namespace tensorflow; | ||
|
||
// Main function: Reads the graph definition from a file and prints all | ||
// the information inside. | ||
int main(int argc, char* argv[]) { | ||
// Verify that the version of the library that we linked against is | ||
// compatible with the version of the headers we compiled against. | ||
GOOGLE_PROTOBUF_VERIFY_VERSION; | ||
|
||
if (argc != 2) { | ||
cerr << "Usage: " << argv[0] << " GRAPH_DEF_FILE" << endl; | ||
return -1; | ||
} | ||
|
||
tensorflow::GraphDef graph_def; | ||
|
||
{ | ||
// Read the existing graph. | ||
fstream input(argv[1], ios::in | ios::binary); | ||
if (!graph_def.ParseFromIstream(&input)) { | ||
cerr << "Failed to parse graph." << endl; | ||
return -1; | ||
} | ||
} | ||
|
||
list_nodes(graph_def); | ||
|
||
// Optional: Delete all global objects allocated by libprotobuf. | ||
google::protobuf::ShutdownProtobufLibrary(); | ||
|
||
return 0; | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
syntax = "proto3"; | ||
|
||
package tensorflow; | ||
option cc_enable_arenas = true; | ||
option java_outer_classname = "AllocationDescriptionProtos"; | ||
option java_multiple_files = true; | ||
option java_package = "org.tensorflow.framework"; | ||
|
||
message AllocationDescription { | ||
// Total number of bytes requested | ||
int64 requested_bytes = 1; | ||
|
||
// Total number of bytes allocated if known | ||
int64 allocated_bytes = 2; | ||
|
||
// Name of the allocator used | ||
string allocator_name = 3; | ||
|
||
// Identifier of the allocated buffer if known | ||
int64 allocation_id = 4; | ||
|
||
// Set if this tensor only has one remaining reference | ||
bool has_single_reference = 5; | ||
|
||
// Address of the allocation. | ||
uint64 ptr = 6; | ||
}; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
syntax = "proto3"; | ||
|
||
package tensorflow; | ||
option cc_enable_arenas = true; | ||
option java_outer_classname = "AttrValueProtos"; | ||
option java_multiple_files = true; | ||
option java_package = "org.tensorflow.framework"; | ||
|
||
import "tensorflow/core/framework/tensor.proto"; | ||
import "tensorflow/core/framework/tensor_shape.proto"; | ||
import "tensorflow/core/framework/types.proto"; | ||
|
||
// Protocol buffer representing the value for an attr used to configure an Op. | ||
// Comment indicates the corresponding attr type. Only the field matching the | ||
// attr type may be filled. | ||
message AttrValue { | ||
message ListValue { | ||
repeated bytes s = 2; // "list(string)" | ||
repeated int64 i = 3 [packed = true]; // "list(int)" | ||
repeated float f = 4 [packed = true]; // "list(float)" | ||
repeated bool b = 5 [packed = true]; // "list(bool)" | ||
repeated DataType type = 6 [packed = true]; // "list(type)" | ||
repeated TensorShapeProto shape = 7; // "list(shape)" | ||
repeated TensorProto tensor = 8; // "list(tensor)" | ||
// TODO(zhifengc/josh11b): implements list(func) if needed. | ||
} | ||
|
||
oneof value { | ||
bytes s = 2; // "string" | ||
int64 i = 3; // "int" | ||
float f = 4; // "float" | ||
bool b = 5; // "bool" | ||
DataType type = 6; // "type" | ||
TensorShapeProto shape = 7; // "shape" | ||
TensorProto tensor = 8; // "tensor" | ||
ListValue list = 1; // any "list(...)" | ||
|
||
// "func" represents a function. func.name is a function's name or | ||
// a primitive op's name. func.attr.first is the name of an attr | ||
// defined for that function. func.attr.second is the value for | ||
// that attr in the instantiation. | ||
NameAttrList func = 10; | ||
|
||
// This is a placeholder only used in nodes defined inside a | ||
// function. It indicates the attr value will be supplied when | ||
// the function is instantiated. For example, let us suppose a | ||
// node "N" in function "FN". "N" has an attr "A" with value | ||
// placeholder = "foo". When FN is instantiated with attr "foo" | ||
// set to "bar", the instantiated node N's attr A will have been | ||
// given the value "bar". | ||
string placeholder = 9; | ||
} | ||
} | ||
|
||
// A list of attr names and their values. The whole list is attached | ||
// with a string name. E.g., MatMul[T=float]. | ||
message NameAttrList { | ||
string name = 1; | ||
map<string, AttrValue> attr = 2; | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
syntax = "proto3"; | ||
|
||
package tensorflow; | ||
option cc_enable_arenas = true; | ||
option java_outer_classname = "CostGraphProtos"; | ||
option java_multiple_files = true; | ||
option java_package = "org.tensorflow.framework"; | ||
|
||
message CostGraphDef { | ||
message Node { | ||
// The name of the node. | ||
string name = 1; | ||
|
||
// The device of the node. | ||
string device = 2; | ||
|
||
// The id of the node. | ||
int32 id = 3; | ||
|
||
// Inputs of this node. They must be executed before this node can be | ||
// executed. An input is a particular output of another node, specified | ||
// by the node id and the output index. | ||
message InputInfo { | ||
int32 preceding_node = 1; | ||
int32 preceding_port = 2; | ||
} | ||
repeated InputInfo input_info = 4; | ||
|
||
// Outputs of this node. | ||
message OutputInfo { | ||
int64 size = 1; | ||
// If >= 0, the output is an alias of an input. Note that an alias input | ||
// may itself be an alias. The algorithm will therefore need to follow | ||
// those pointers. | ||
int64 alias_input_port = 2; | ||
} | ||
repeated OutputInfo output_info = 5; | ||
|
||
// Temporary memory used by this node. | ||
int64 temporary_memory_size = 6; | ||
|
||
// If true, the output is permanent: it can't be discarded, because this | ||
// node is part of the "final output". Nodes may depend on final nodes. | ||
bool is_final = 7; | ||
|
||
// Ids of the control inputs for this node. | ||
repeated int32 control_input = 8; | ||
} | ||
repeated Node node = 1; | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
syntax = "proto3"; | ||
|
||
package tensorflow; | ||
option cc_enable_arenas = true; | ||
option java_outer_classname = "DeviceAttributesProtos"; | ||
option java_multiple_files = true; | ||
option java_package = "org.tensorflow.framework"; | ||
|
||
// BusAdjacency identifies the ability of a device to participate in | ||
// maximally efficient DMA operations within the local context of a | ||
// process. | ||
// | ||
// This is currently ignored. | ||
enum BusAdjacency { | ||
BUS_0 = 0; | ||
BUS_1 = 1; | ||
BUS_ANY = 2; | ||
BUS_NUM_ADJACENCIES = 3; | ||
}; | ||
|
||
message DeviceAttributes { | ||
string name = 1; | ||
|
||
// String representation of device_type. | ||
string device_type = 2; | ||
|
||
// Memory capacity of device in bytes. | ||
int64 memory_limit = 4; | ||
|
||
BusAdjacency bus_adjacency = 5; | ||
|
||
// A device is assigned a global unique number each time it is | ||
// initialized. "incarnation" should never be 0. | ||
fixed64 incarnation = 6; | ||
|
||
// String representation of the physical device that this device maps to. | ||
string physical_device_desc = 7; | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
syntax = "proto3"; | ||
|
||
package tensorflow; | ||
option cc_enable_arenas = true; | ||
option java_outer_classname = "FunctionProtos"; | ||
option java_multiple_files = true; | ||
option java_package = "org.tensorflow.framework"; | ||
|
||
import "tensorflow/core/framework/attr_value.proto"; | ||
import "tensorflow/core/framework/op_def.proto"; | ||
|
||
// A library is a set of named functions. | ||
message FunctionDefLibrary { | ||
repeated FunctionDef function = 1; | ||
repeated GradientDef gradient = 2; | ||
} | ||
|
||
// A function can be instantiated when the runtime can bind every attr | ||
// with a value. When a GraphDef has a call to a function, it must | ||
// have binding for every attr defined in the signature. | ||
// | ||
// TODO(zhifengc): | ||
// * device spec, etc. | ||
message FunctionDef { | ||
// The definition of the function's name, arguments, return values, | ||
// attrs etc. | ||
OpDef signature = 1; | ||
|
||
// The body of the function. | ||
repeated Node node = 2; // function.node.ret[*] are unique. | ||
|
||
// A node is a multi-value assignment: | ||
// (ret[0], ret[1], ...) = func(arg[0], arg[1], ...) | ||
// | ||
// By convention, "func" is resolved by consulting with a user-defined | ||
// library first. If not resolved, "func" is assumed to be a builtin op. | ||
message Node { | ||
// This node produces multiple outputs. They are named ret[0], | ||
// ret[1], ..., etc. | ||
// | ||
// REQUIRES: function.node.ret[*] are unique across all nodes. | ||
// REQUIRES: ret.size == func/op def's number of output args. | ||
repeated string ret = 1; | ||
|
||
// The op/function name. | ||
string op = 2; | ||
|
||
// Arguments passed to this func/op. | ||
// | ||
// arg[i] must be either one of | ||
// function.signature.input_args[*].name or one of | ||
// function.node[*].ret[*]. | ||
// | ||
// REQUIRES: arg.size == func/op def's number of input args. | ||
repeated string arg = 3; | ||
|
||
// Control dependencies. | ||
// | ||
// dep[i] must be one of function.node[*].ret[*] or one of | ||
// function.signature.input_args[*].name. | ||
repeated string dep = 4; | ||
|
||
// Attrs. | ||
// | ||
// 'attr' maps names defined by 'func's attr defs to attr values. | ||
// attr values may have placeholders which are substituted | ||
// recursively by concrete values when this node is instantiated. | ||
// These placeholders must name an attr listed in the FunctionDef's | ||
// signature. | ||
map<string, AttrValue> attr = 5; | ||
} | ||
} | ||
|
||
// GradientDef defines the gradient function of a function defined in | ||
// a function library. | ||
// | ||
// A gradient function g (specified by gradient_func) for a function f | ||
// (specified by function_name) must follow the following: | ||
// | ||
// The function 'f' must be a numerical function which takes N inputs | ||
// and produces M outputs. Its gradient function 'g', which is a | ||
// function taking N + M inputs and produces N outputs. | ||
// | ||
// I.e. if we have | ||
// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | ||
// then, g is | ||
// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | ||
// dL/dy1, dL/dy2, ..., dL/dy_M), | ||
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | ||
// loss function). dL/dx_i is the partial derivative of L with respect | ||
// to x_i. | ||
message GradientDef { | ||
string function_name = 1; // The function name. | ||
string gradient_func = 2; // The gradient function's name. | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wangyida this files should go to
tiny_dnn/io/tensorflow