From e0ebc998ced2f69ce0a134a57054bb3b40c0f836 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Wed, 19 Apr 2023 16:52:25 +0800 Subject: [PATCH] Fix the error when loading VGG19. --- .../ArgsDefinition/AutoSerializeLayerArgs.cs | 2 +- .../Common/CustomizedShapeJsonConverter.cs | 34 +++++++------------ .../Tensorflow.Binding.csproj | 2 +- .../Training/Saving/SavedModel/loader.cs | 11 ++++-- .../Utils/generic_utils.cs | 6 ++++ .../SaveModel/SequentialModelLoad.cs | 25 +++++++++++++- 6 files changed, 54 insertions(+), 26 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs index 1a97b0135..59dc51b8e 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs @@ -9,7 +9,7 @@ namespace Tensorflow.Keras.ArgsDefinition /// This class has nothing but the attributes different from `LayerArgs`. /// It's used to serialize the model to `tf` format. /// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`, - /// then the Arg definition should inherit `utoSerializeLayerArgs` instead of `LayerArgs`. + /// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`. /// public class AutoSerializeLayerArgs: LayerArgs { diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs index 198662afe..722e0a75e 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs @@ -7,6 +7,11 @@ namespace Tensorflow.Keras.Common { + class ShapeInfoFromPython + { + public string class_name { get; set; } + public long?[] items { get; set; } + } public class CustomizedShapeJsonConverter: JsonConverter { public override bool CanConvert(Type objectType) @@ -44,36 +49,23 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer dims[i] = shape.dims[i]; } } - var token = JToken.FromObject(dims); + var token = JToken.FromObject(new ShapeInfoFromPython() + { + class_name = "__tuple__", + items = dims + }); token.WriteTo(writer); } } public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { - long?[] dims; - try - { - dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; - } - catch (JsonSerializationException ex) - { - if (reader.Value.Equals("class_name")) - { - reader.Read(); - reader.Read(); - reader.Read(); - dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; - } - else - { - throw ex; - } - } - if (dims is null) + var shape_info_from_python = serializer.Deserialize(reader); + if (shape_info_from_python is null) { return null; } + long ?[]dims = shape_info_from_python.items; long[] convertedDims = new long[dims.Length]; for(int i = 0; i < dims.Length; i++) { diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 935e5545a..3a6bcfa13 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -108,7 +108,7 @@ https://tensorflownet.readthedocs.io - + diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index 2eecfabfd..cad32c59d 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -563,7 +563,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id) return proto.KindCase switch { SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), - SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), + SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, dependencies), SavedObject.KindOneofCase.BareConcreteFunction => _recreate_bare_concrete_function(proto.BareConcreteFunction, dependencies), SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), @@ -626,7 +626,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id) } private (Function, Action) _recreate_function(SavedFunction proto, - Dictionary, Trackable> dependencies) + IDictionary, Trackable> dependencies) { var fn = function_deserialization.recreate_function(proto, _concrete_functions); foreach (var name in proto.ConcreteFunctions) @@ -644,6 +644,13 @@ private void _add_object_graph_edges(SavedObject proto, int node_id) return (fn, setattr); } + private (Tensor, Action) _get_tensor_from_fn(CapturedTensor proto) + { + var outer_graph = _concrete_functions[proto.ConcreteFunction].func_graph; + var captured_tensor = outer_graph.get_tensor_by_name(proto.Name); + return (captured_tensor, setattr); + } + // TODO: remove this to a common class. public static Action setattr = (x, y, z) => { diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 1194bebfe..672ac60e1 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -71,6 +71,9 @@ public static Layer deserialize_keras_object(string class_name, JToken config) var args = deserializationGenericMethod.Invoke(config, null); var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); Debug.Assert(layer is Layer); + + // TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj) + return layer as Layer; } @@ -82,6 +85,9 @@ public static Layer deserialize_keras_object(string class_name, LayerArgs args) return null; } Debug.Assert(layer is Layer); + + // TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj) + return layer as Layer; } diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index 3788e950f..806c4ece8 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -6,13 +6,13 @@ using Tensorflow.Keras.UnitTest.Helpers; using Tensorflow.NumPy; using static Tensorflow.Binding; +using static Tensorflow.KerasApi; namespace TensorFlowNET.Keras.UnitTest.SaveModel; [TestClass] public class SequentialModelLoad { - [Ignore] [TestMethod] public void SimpleModelFromAutoCompile() { @@ -80,4 +80,27 @@ public void ModelWithSelfDefinedModule() model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); } + + [Ignore] + [TestMethod] + public void VGG19() + { + var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19"); + model.summary(); + + var classify_model = keras.Sequential(new System.Collections.Generic.List() + { + model, + keras.layers.Flatten(), + keras.layers.Dense(10), + }); + classify_model.summary(); + + classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var x = np.random.uniform(0, 1, (8, 512, 512, 3)); + var y = np.ones((8)); + + classify_model.fit(x, y, batch_size: 4); + } }