Commit f2637e7

Updating Gluon code to use numpy ndarray instead of MxNet NDArray, which is being deprecated
1 parent 0da2dbb commit f2637e7


72 files changed, +675 -357 lines
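The recurring pattern across these 72 files is the same swap. An illustrative before/after sketch, using only names taken from the hunks below (the variable x and its values are not part of the commit):

    using System;
    using MxNet.Numpy;

    // Before: legacy MxNet NDArray (being deprecated)
    // var x = new NDArray(new float[] { 1, 2, 3, 4 }).Reshape(2, 2);

    // After: numpy-style ndarray from MxNet.Numpy, with lowercase members
    var x = new ndarray(new float[] { 1, 2, 3, 4 }).reshape(new Shape(2, 2));
    Console.WriteLine(x.shape);   // expected to print the shape (2, 2)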

examples/BasicExamples/Program.cs

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ class Program
         static void Main(string[] args)
         {
             //Console.WriteLine("Runnin XOR Example......");
-            //XORGate.Run();
+            XORGate.Run();
             //CrashCourse_NN.Run();
             //LogisticRegressionExplained.Run();
             //var methods = mx.GetAllRegisteredCApiOperators();

examples/BasicExamples/XORGate.cs

Lines changed: 7 additions & 6 deletions

@@ -8,6 +8,7 @@
 using MxNet.Optimizers;
 using System;
 using System.Linq;
+using MxNet.Numpy;

 namespace BasicExamples
 {
@@ -16,8 +17,8 @@ public class XORGate
         public static void Run()
         {
             // Create
-            var trainX = new NDArray(new float[] { 0, 0, 0, 1, 1, 0, 1, 1 }).Reshape(4, 2);
-            var trainY = new NDArray(new float[] { 0, 1, 1, 0 });
+            var trainX = new ndarray(new float[] { 0, 0, 0, 1, 1, 0, 1, 1 }).reshape(new Shape(4, 2));
+            var trainY = new ndarray(new float[] { 0, 1, 1, 0 });

             var batch_size = 2;
             var train_data = new NDArrayIter(trainX, trainY, batch_size);
@@ -43,17 +44,17 @@ public static void Run()
             while (!train_data.End())
             {
                 var batch = train_data.Next();
-                var data = Utils.SplitAndLoad(batch.Data[0], ctxList);
-                var label = Utils.SplitAndLoad(batch.Label[0], ctxList);
+                var data = MxNet.Gluon.Utils.SplitAndLoad(batch.Data[0], ctxList);
+                var label = MxNet.Gluon.Utils.SplitAndLoad(batch.Label[0], ctxList);
                 NDArrayList outputs = null;
                 using (var ag = Autograd.Record())
                 {
                     outputs = Enumerable.Zip(data, label, (x, y) =>
                     {
                         var z = net.Call(x);
-                        NDArray loss = binary_crossentropy.Call(z, y);
+                        ndarray loss = binary_crossentropy.Call(z, y);
                         loss.Backward();
-                        lossVal += loss.Mean();
+                        lossVal += loss.mean().AsScalar<float>();
                         return z;
                     }).ToList();
                 }

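The non-obvious change in this hunk is the loss accumulation: mean() now returns an ndarray, so the scalar has to be pulled out explicitly before it can be added to the float accumulator. A minimal sketch of that pattern, assuming loss is the ndarray produced by the loss block above:

    float lossVal = 0f;
    // mean() reduces the per-sample losses to a 0-d ndarray;
    // AsScalar<float>() copies that single value back into a CLR float.
    lossVal += loss.mean().AsScalar<float>();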
src/MxNet/AMP/Amp.cs

Lines changed: 0 additions & 10 deletions

@@ -10,15 +10,5 @@ public static NDArrayOrSymbol CastSymbolNDArray(NDArrayOrSymbol s, DType dtype,
         {
             throw new NotImplementedException();
         }
-
-        public static (string, Type[]) GetNdFuncToWrap(string name, Type module, Dictionary<string, object> submodule_dict)
-        {
-            throw new NotImplementedException();
-        }
-
-        public static (string, Type[]) GetNpFuncToWrap(string name, Type module, Dictionary<string, object> submodule_dict)
-        {
-            throw new NotImplementedException();
-        }
     }
 }

src/MxNet/F.cs

Lines changed: 15 additions & 5 deletions

@@ -112,7 +112,7 @@ public static NDArrayOrSymbol cumsum(NDArrayOrSymbol a, int? axis, DType dtype,
         }
         return sym_np_ops.cumsum(a, axis, dtype, @out);
     }
-    public static NDArrayOrSymbol reshape(NDArrayOrSymbol a, Shape newshape, bool reverse, String order)
+    public static NDArrayOrSymbol reshape(NDArrayOrSymbol a, Shape newshape, bool reverse = false, String order = "C")
     {
         if (a.IsNDArray)
         {
@@ -176,7 +176,7 @@ public static NDArrayOrSymbol diagonal(NDArrayOrSymbol a, int offset, int axis1,
         }
         return sym_np_ops.diagonal(a, offset, axis1, axis2);
     }
-    public static NDArrayOrSymbol sum(NDArrayOrSymbol a, int? axis, DType dtype, NDArrayOrSymbol @out, bool keepdims, float? initial)
+    public static NDArrayOrSymbol sum(NDArrayOrSymbol a, int? axis = null, DType dtype = null, NDArrayOrSymbol @out = null, bool keepdims = false, float? initial = null)
     {
         if (a.IsNDArray)
         {
@@ -328,7 +328,7 @@ public static NDArrayOrSymbol softmax(NDArrayOrSymbol data, int axis, NDArrayOrS
         }
         return Sym.Numpy.npx.softmax(data, axis, length, temperature, use_length, dtype);
     }
-    public static NDArrayOrSymbol log_softmax(NDArrayOrSymbol data, int axis, NDArrayOrSymbol length, Double? temperature, bool use_length, DType dtype)
+    public static NDArrayOrSymbol log_softmax(NDArrayOrSymbol data, int axis, NDArrayOrSymbol length = null, Double? temperature = null, bool use_length = false, DType dtype = null)
     {
         if (data.IsNDArray)
         {
@@ -360,7 +360,7 @@ public static NDArrayOrSymbol pick(NDArrayOrSymbol data, NDArrayOrSymbol index,
         }
         return Sym.Numpy.npx.pick(data, index, axis, mode, keepdims);
     }
-    public static NDArrayOrSymbol reshape_like(NDArrayOrSymbol lhs, NDArrayOrSymbol rhs, int? lhs_begin, int? lhs_end, int? rhs_begin, int? rhs_end)
+    public static NDArrayOrSymbol reshape_like(NDArrayOrSymbol lhs, NDArrayOrSymbol rhs, int? lhs_begin = null, int? lhs_end = null, int? rhs_begin = null, int? rhs_end = null)
     {
         if (lhs.IsNDArray)
         {
@@ -563,7 +563,7 @@ public static NDArrayOrSymbol average(NDArrayOrSymbol a, int? axis, NDArrayOrSym
         }
         return sym_np_ops.average(a, axis, weights, returned, @out);
     }
-    public static NDArrayOrSymbol mean(NDArrayOrSymbol a, int? axis, DType dtype, NDArrayOrSymbol @out, bool keepdims)
+    public static NDArrayOrSymbol mean(NDArrayOrSymbol a, int? axis = null, DType dtype = null, NDArrayOrSymbol @out = null, bool keepdims = false)
    {
         if (a.IsNDArray)
         {
@@ -1575,5 +1575,15 @@ public static NDArrayOrSymbolList split(NDArrayOrSymbol ary, int[] indices_or_se
         }
         return sym_np_ops.split(ary, indices_or_sections, axis);
     }
+
+    public static NDArrayOrSymbol norm(NDArrayOrSymbol x, string ord = null, Shape axis = null, bool keepdims = false)
+    {
+        if (x.IsNDArray)
+        {
+            return nd_np_ops.linalg.norm(x, ord, axis, keepdims);
+        }
+
+        return sym_np_ops.linalg.norm(x, ord, axis, keepdims);
+    }
     }
 }

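These hunks mostly add default argument values so the F helpers can be called tersely, plus a new norm wrapper that dispatches to nd_np_ops.linalg.norm or sym_np_ops.linalg.norm depending on the input. A hedged calling sketch (x is assumed to be an NDArrayOrSymbol already in scope; the defaults are the ones shown above):

    // axis, dtype, @out, keepdims and initial no longer need to be spelled out.
    var total = F.sum(x);
    var avg   = F.mean(x, axis: 0);

    // New wrapper: with ord and axis left null it should mirror numpy's default
    // (2-norm / Frobenius), routed to the ndarray or symbol backend as appropriate.
    var l2 = F.norm(x);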
src/MxNet/Gluon/Block/Block.cs

Lines changed: 2 additions & 1 deletion

@@ -50,6 +50,7 @@ public Block()
             _reg_params = new Dictionary<string, Parameter>();
             _forward_hooks = new Dictionary<int, Hook>();
             _forward_pre_hooks = new Dictionary<int, Hook>();
+            Params = new ParameterDict();
         }

         public virtual ParameterDict Params { get; set; }
@@ -591,7 +592,7 @@ internal static string CommonPrefix(string[] names)
             return prefix;
         }

-        internal static (DType[], DType[]) InferParamTypes(SymbolList in_params, Symbol out_params, string[] arg_params,
+        internal static (DType[], DType[]) InferParamTypes(SymbolList in_params, _Symbol out_params, string[] arg_params,
             string[] aux_params, DType default_dtype = null)
         {
             DType[] arg_types = null;

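Two effects here: Params is an empty ParameterDict immediately after the base Block constructor runs (instead of possibly null), and InferParamTypes now takes the numpy-front-end _Symbol. A small hypothetical sketch of the first point (MyBlock is illustrative, not part of the commit):

    using MxNet.Gluon;

    public class MyBlock : Block
    {
        public MyBlock()
        {
            // After this commit the base constructor has already created Params,
            // so derived blocks can rely on it being non-null at this point.
            System.Diagnostics.Debug.Assert(Params != null);
        }
    }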
src/MxNet/Gluon/Block/HybridBlock.cs

Lines changed: 8 additions & 8 deletions

@@ -371,7 +371,7 @@ private void BuildCache(NDArrayOrSymbolList args, bool update_graph = true)
             for (var i = 0; i < input_names.Length; i++)
             {
                 Parameter param = null;
-                NDArray param_data;
+                ndarray param_data;
                 var name = input_names[i];
                 (bool, string, Parameter) triple = (false, "", param);
                 if (data_names.ContainsKey(name))
@@ -405,10 +405,10 @@ private void BuildCache(NDArrayOrSymbolList args, bool update_graph = true)
                     throw new Exception("A parameter was added to the graph during optimization but it was not added to the parameter dicts.\nPlease check the backend.");
                 }

-                param = new Parameter(name, dtype: param_data.DataType);
+                param = new Parameter(name, dtype: param_data.dtype);
                 param._var_name = name;
                 serialization_name = name;
-                param.LoadInit(param_data, new Context[] { param_data.Context });
+                param.LoadInit(param_data, new Context[] { param_data.ctx });
             }

             triple = (false, serialization_name, param);
@@ -511,7 +511,7 @@ internal NDArrayOrSymbolList CallCachedOp(NDArrayOrSymbolList args)
             return Regroup(new List<NDArrayOrSymbol[]> { @out.NDArrayOrSymbols }, _out_format).Item1;
         }

-        public void OptimizeFor(NDArray x, string backend = null, bool clear = false, bool partition_if_dynamic = true, bool static_alloc = false,
+        public void OptimizeFor(ndarray x, string backend = null, bool clear = false, bool partition_if_dynamic = true, bool static_alloc = false,
             bool static_shape = false, int inline_limit = 2, int? forward_bulk_size = null, int? backward_bulk_size = null, Dictionary<string, string> backend_opts = null, NDArrayList args = null)
         {
             this._backend = backend;
@@ -538,7 +538,7 @@ public void OptimizeFor(NDArray x, string backend = null, bool clear = false, bo
             var ctx_set = _tup_1.Item3;
             if (!has_symbol && !has_ndarray)
             {
-                throw new Exception("In HybridBlock, there must be one NDArray or one Symbol in the input. Please check the type of the args.\n");
+                throw new Exception("In HybridBlock, there must be one ndarray or one _Symbol in the input. Please check the type of the args.\n");
             }
             if (ctx_set.Length > 1)
             {
@@ -745,14 +745,14 @@ public void InferType(NDArrayOrSymbolList args)
             var params_filename = String.Format("%s-%04d.params", path != null ? path : "", epoch);
             if (path != null)
             {
-                NDArray.Save(params_filename, arg_dict);
+                ndarray.Save(params_filename, arg_dict);
                 return (sym_filename, arg_dict.Count > 0 ? params_filename : null);
             }

             return ("", "");
         }

-        public (Symbol, NDArrayDict) Export(int epoch = 0, bool remove_amp_cast = true)
+        public (_Symbol, NDArrayDict) Export(int epoch = 0, bool remove_amp_cast = true)
         {
             if (_cached_graph == null)
                 throw new Exception("Please first call block.hybridize() and then run forward with " +
@@ -836,7 +836,7 @@ public override NDArrayOrSymbol Forward(NDArrayOrSymbol x, params NDArrayOrSymbo

             if (!has_symbol && !has_ndarray)
             {
-                throw new Exception("In HybridBlock, there must be one NDArray or one Symbol in the input. Please check the type of the args.\n");
+                throw new Exception("In HybridBlock, there must be one ndarray or one _Symbol in the input. Please check the type of the args.\n");
             }

             if (has_ndarray)

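For callers, the visible changes are the signatures: OptimizeFor now takes a numpy-style ndarray sample input, and Export returns the numpy-front-end _Symbol alongside the parameter dict. A hedged usage sketch (net is assumed to be a HybridBlock that has been hybridized and run forward once; the backend name is illustrative):

    // Optimize the cached graph for a backend using a sample ndarray input.
    var sample = new ndarray(new float[] { 0, 0 }).reshape(new Shape(1, 2));
    net.OptimizeFor(sample, backend: "MKLDNN");

    // Export now yields (_Symbol, NDArrayDict) instead of (Symbol, NDArrayDict).
    (_Symbol sym, NDArrayDict args) = net.Export(epoch: 0);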
src/MxNet/Gluon/Constant.cs

Lines changed: 6 additions & 5 deletions

@@ -15,12 +15,13 @@ limitations under the License.
 ******************************************************************************/
 using System;
 using MxNet.Initializers;
+using MxNet.Numpy;

 namespace MxNet.Gluon
 {
     public class Constant : Parameter
     {
-        public Constant(string name, NDArray value) : base(name, OpGradReq.Null, value.Shape, value.DataType,
+        public Constant(string name, ndarray value) : base(name, OpGradReq.Null, value.shape, value.dtype,
             init: new CInit(value))
         {
             Value = value;
@@ -37,20 +38,20 @@ public override OpGradReq GradReg
             }
         }

-        public NDArray Value { get; set; }
+        public ndarray Value { get; set; }

         public string InitName { get; set; }

         public class CInit : Initializer
         {
-            private readonly NDArray _value;
+            private readonly ndarray _value;

-            public CInit(NDArray value)
+            public CInit(ndarray value)
             {
                 _value = value;
             }

-            public override void InitWeight(string name, ref NDArray arr)
+            public override void InitWeight(string name, ref ndarray arr)
             {
                 _value.CopyTo(arr);
             }

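Since the constructor now takes an ndarray and reads shape and dtype from its lowercase properties, a constant can be built directly from numpy-style data. A minimal sketch (the name "xor_bias" and the values are illustrative):

    using MxNet.Gluon;
    using MxNet.Numpy;

    // OpGradReq.Null in the base call means the value never receives gradients;
    // CInit simply copies the stored ndarray into the parameter on initialization.
    var bias = new Constant("xor_bias", new ndarray(new float[] { 0.5f, -0.5f }));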
src/MxNet/Gluon/Contrib/Data/Vision/Dataloader/DataLoader.cs

Lines changed: 3 additions & 2 deletions

@@ -1,5 +1,6 @@
 using MxNet.Gluon.NN;
 using MxNet.Image;
+using MxNet.Numpy;
 using System;
 using System.Collections.Generic;
 using System.Text;
@@ -9,14 +10,14 @@ namespace MxNet.Gluon.Contrib.Data.Vision
     public class DataLoader
     {
         public static HybridSequential CreateImageAugment(Shape data_shape, int resize= 0, bool rand_crop= false, bool rand_resize= false,
-            bool rand_mirror= false, NDArray mean= null, NDArray std= null, float brightness= 0, float contrast= 0, float saturation= 0, float hue= 0,
+            bool rand_mirror= false, ndarray mean= null, ndarray std= null, float brightness= 0, float contrast= 0, float saturation= 0, float hue= 0,
             float pca_noise= 0, float rand_gray= 0, int inter_method= 2, DType dtype= null)
         {
             throw new NotImplementedRelease1Exception();
         }

         public static Sequential CreateBboxAugment(Shape data_shape, float rand_crop= 0, float rand_pad= 0, float rand_gray= 0,
-            bool rand_mirror= false, NDArray mean= null, NDArray std= null, float brightness= 0, float contrast= 0,
+            bool rand_mirror= false, ndarray mean= null, ndarray std= null, float brightness= 0, float contrast= 0,
             float saturation= 0, float pca_noise= 0, float hue= 0, int inter_method= 2,
             float max_aspect_ratio= 2, (float, float)? area_range= null,
             int max_attempts= 50, (int, int, int)? pad_val= null, DType dtype = null)

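Only the mean/std parameter types change here; both factory methods still throw NotImplementedRelease1Exception in this release. A hypothetical call once they are implemented, with common ImageNet-style statistics as illustrative values:

    var mean = new ndarray(new float[] { 0.485f, 0.456f, 0.406f });
    var std  = new ndarray(new float[] { 0.229f, 0.224f, 0.225f });

    // data_shape is (channels, height, width); all other augmentation knobs keep their defaults.
    var aug = DataLoader.CreateImageAugment(new Shape(3, 224, 224), mean: mean, std: std);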
src/MxNet/Gluon/Contrib/Data/Vision/Dataloader/ImageBboxDataLoader.cs

Lines changed: 4 additions & 3 deletions

@@ -1,26 +1,27 @@
 using MxNet.Gluon.Data;
 using MxNet.Image;
+using MxNet.Numpy;
 using System;
 using System.Collections;
 using System.Collections.Generic;
 using System.Text;

 namespace MxNet.Gluon.Contrib.Data.Vision
 {
-    public class ImageBboxDataLoader : IEnumerable<(NDArray, NDArray)>
+    public class ImageBboxDataLoader : IEnumerable<(ndarray, ndarray)>
     {
         public int Length => throw new NotImplementedRelease1Exception();

         public ImageBboxDataLoader(int batch_size, Shape data_shape, string path_imgrec= null, string path_imglist= null, string path_root= ".",
             int part_index= 0, int num_parts= 1, Augmenter[] aug_list= null, List<(string, float[])> imglist= null, bool coord_normalized = false,
             DType dtype= null, bool shuffle= false, Sampler<int> sampler= null, string last_batch= null, BatchSampler batch_sampler= null,
-            Func<(NDArray, NDArray)[], (NDArray, NDArray)> batchify_fn = null, int num_workers= 0, bool pin_memory= false,
+            Func<(ndarray, ndarray)[], (ndarray, ndarray)> batchify_fn = null, int num_workers= 0, bool pin_memory= false,
             int pin_device_id= 0, int? prefetch= null, bool thread_pool= false, int timeout= 120)
         {
             throw new NotImplementedRelease1Exception();
         }

-        public IEnumerator<(NDArray, NDArray)> GetEnumerator()
+        public IEnumerator<(ndarray, ndarray)> GetEnumerator()
         {
             throw new NotImplementedRelease1Exception();
         }

src/MxNet/Gluon/Contrib/Data/Vision/Dataloader/ImageDataLoader.cs

Lines changed: 4 additions & 3 deletions

@@ -1,26 +1,27 @@
 using MxNet.Gluon.Data;
 using MxNet.Image;
+using MxNet.Numpy;
 using System;
 using System.Collections;
 using System.Collections.Generic;
 using System.Text;

 namespace MxNet.Gluon.Contrib.Data.Vision
 {
-    public class ImageDataLoader : IEnumerable<(NDArray, NDArray)>
+    public class ImageDataLoader : IEnumerable<(ndarray, ndarray)>
     {
         public int Length => throw new NotImplementedRelease1Exception();

         public ImageDataLoader(int batch_size, Shape data_shape, string path_imgrec= null, string path_imglist= null, string path_root= ".",
             int part_index= 0, int num_parts= 1, Augmenter[] aug_list= null, List<(string, float[])> imglist= null,
             DType dtype= null, bool shuffle= false, Sampler<int> sampler= null, string last_batch= null, BatchSampler batch_sampler= null,
-            Func<(NDArray, NDArray)[], (NDArray, NDArray)> batchify_fn = null, int num_workers= 0, bool pin_memory= false,
+            Func<(ndarray, ndarray)[], (ndarray, ndarray)> batchify_fn = null, int num_workers= 0, bool pin_memory= false,
             int pin_device_id= 0, int? prefetch= null, bool thread_pool= false, int timeout= 120)
         {
             throw new NotImplementedRelease1Exception();
         }

-        public IEnumerator<(NDArray, NDArray)> GetEnumerator()
+        public IEnumerator<(ndarray, ndarray)> GetEnumerator()
         {
             throw new NotImplementedRelease1Exception();
         }

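Both image loaders now enumerate (ndarray, ndarray) pairs rather than NDArray pairs, which is what the class, constructor and GetEnumerator changes above amount to. A hypothetical iteration sketch (the record file name is illustrative, and the constructors currently throw NotImplementedRelease1Exception):

    var loader = new ImageDataLoader(batch_size: 32, data_shape: new Shape(3, 224, 224),
                                     path_imgrec: "train.rec");

    foreach ((ndarray data, ndarray label) in loader)
    {
        // data holds one batch of images and label the matching labels,
        // both as numpy-style ndarrays.
    }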