Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 0188bca

Browse files
CrashCourse NN example fixed
1 parent 923f3c1 commit 0188bca

File tree

14 files changed

+47
-152
lines changed

14 files changed

+47
-152
lines changed

examples/BasicExamples/CrashCourse-NDArray.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,20 @@
55
using MxNet.Gluon;
66
using MxNet.Gluon.NN;
77
using MxNet.Initializers;
8+
using MxNet.Numpy;
89

910
namespace BasicExamples
1011
{
1112
public class CrashCourse_NDArray
1213
{
13-
private static NDArray F(NDArray a)
14+
private static ndarray F(ndarray a)
1415
{
15-
NDArray c = null;
16+
ndarray c = null;
1617
var b = a * 2;
17-
while (b.Norm().AsScalar<float>() < 1000)
18+
while (np.linalg.norm(b).asscalar() < 1000)
1819
b = b * 2;
1920

20-
if (b.Sum() >= 0)
21+
if (b.sum().asscalar() >= 0)
2122
c = b[0];
2223
else
2324
c = b[1];
@@ -27,7 +28,6 @@ private static NDArray F(NDArray a)
2728

2829
public static void GetStarted()
2930
{
30-
var ctx = mx.Cpu();
3131
var net = new Sequential();
3232

3333
// Similar to Dense, it is not necessary to specify the input channels
@@ -49,9 +49,9 @@ public static void GetStarted()
4949

5050
net.Initialize();
5151
// Input shape is (batch_size, color_channels, height, width)
52-
var x = nd.Random.Uniform(shape: new Shape(4, 1, 28, 28));
53-
NDArray y = net.Call(x);
54-
Console.WriteLine(y.Shape);
52+
var x = np.random.uniform(size: new Shape(4, 1, 28, 28));
53+
ndarray y = net.Call(x);
54+
Console.WriteLine(y.shape);
5555

5656
Console.WriteLine(net[0].Params["weight"].Data().shape);
5757
Console.WriteLine(net[5].Params["bias"].Data().shape);
@@ -81,7 +81,7 @@ public MixMLP() : base()
8181

8282
public override NDArrayOrSymbolList Forward(NDArrayOrSymbolList args)
8383
{
84-
var y = nd.Relu(this.blk.Call(args)[0]);
84+
var y = npx.relu(this.blk.Call(args)[0]);
8585
Console.WriteLine(y);
8686
return this.dense.Call(y);
8787
}

examples/BasicExamples/CrashCourse-NN.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using MxNet.Gluon.Losses;
1010
using MxNet.Gluon.NN;
1111
using MxNet.Initializers;
12+
using MxNet.Numpy;
1213
using MxNet.Optimizers;
1314

1415
namespace BasicExamples
@@ -62,8 +63,8 @@ public static void Run()
6263

6364
foreach (var (data, label) in train_data)
6465
{
65-
NDArray loss = null;
66-
NDArray output = null;
66+
ndarray loss = null;
67+
ndarray output = null;
6768
// forward + backward
6869
using (Autograd.Record())
6970
{
@@ -77,7 +78,7 @@ public static void Run()
7778
trainer.Step(batch_size);
7879

7980
//calculate training metrics
80-
train_loss += loss.Mean();
81+
train_loss += loss.mean().asscalar();
8182
train_acc += Acc(output, label);
8283
}
8384

@@ -96,11 +97,11 @@ public static void Run()
9697
net.SaveParameters("net.params");
9798
}
9899

99-
public static float Acc(NDArray output, NDArray label)
100+
public static float Acc(ndarray output, ndarray label)
100101
{
101102
// output: (batch, num_output) float32 ndarray
102103
// label: (batch) int32 ndarray
103-
return nd.Equal(output.Argmax(axis: 1), label.AsType(DType.Float32)).Mean();
104+
return np.equal(output.argmax(axis: 1), label.astype(DType.Float32)).mean().asscalar();
104105
}
105106
}
106107
}

examples/BasicExamples/Program.cs

Lines changed: 2 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -13,126 +13,10 @@ class Program
1313
static void Main(string[] args)
1414
{
1515
//Console.WriteLine("Runnin XOR Example......");
16-
//XORGate.Run();
16+
XORGate.Run();
17+
//CrashCourse_NDArray.GetStarted();
1718
//CrashCourse_NN.Run();
1819
//LogisticRegressionExplained.Run();
19-
var methods = mx.GetAllRegisteredOperators();
20-
//var y = np.full(new Shape(3, 3), 0.6);
21-
//var x = np.random.power(y, new Shape(3, 3));
22-
//var fc = npx.fully_connected(x, y, null, 3);
23-
//var z = np.linalg.cholesky(x);
24-
for (int i = 0; i < 3; i++)
25-
{
26-
DateTime start = DateTime.Now;
27-
var x = np.random.uniform(size: new Shape(30000, 10000));
28-
var y = np.random.uniform(size: new Shape(30000, 10000));
29-
var d = 0.5f * np.sqrt(x) + np.sin(y) * np.log(x) - np.exp(y);
30-
//var d = np.dot(x, y);
31-
//npx.waitall();
32-
Console.WriteLine(d.shape);
33-
Console.WriteLine("Duration: " + (DateTime.Now - start).TotalMilliseconds / 1000);
34-
}
35-
}
36-
37-
private static void GenerateFOps()
38-
{
39-
var methods = typeof(F).GetMethods();
40-
StringBuilder fclass = new StringBuilder();
41-
foreach (var method in methods)
42-
{
43-
var parameters = method.GetParameters();
44-
string paramstr = "";
45-
string ndcall = $"nd_np_ops.{method.Name}(";
46-
string symcall = $"sym_np_ops.{method.Name}(";
47-
bool is_symbol = false;
48-
if (parameters.Length > 0)
49-
{
50-
foreach (var item in parameters)
51-
{
52-
if (item.ParameterType.Name == "is_symbol")
53-
{
54-
is_symbol = true;
55-
continue;
56-
}
57-
58-
if (item.ParameterType.Name == "NDArray")
59-
{
60-
paramstr += $"NDArrayOrSymbol {item.Name},";
61-
ndcall += $"{item.Name},";
62-
symcall += $"{item.Name},";
63-
}
64-
else if (item.ParameterType.Name == "NDArrayList")
65-
{
66-
paramstr += $"NDArrayOrSymbolList {item.Name},";
67-
ndcall += $"{item.Name}.NDArrays,";
68-
symcall += $"{item.Name}.Symbols,";
69-
}
70-
else if (item.ParameterType.Name == "Nullable`1")
71-
{
72-
paramstr += $"{item.ParameterType.GenericTypeArguments[0].Name}? {item.Name},";
73-
ndcall += $"{item.Name},";
74-
symcall += $"{item.Name},";
75-
}
76-
else if (item.ParameterType.Name == "Tuple`1")
77-
{
78-
paramstr += $"Tuple<{item.ParameterType.GenericTypeArguments[0].Name}> {item.Name},";
79-
ndcall += $"{item.Name},";
80-
symcall += $"{item.Name},";
81-
}
82-
else
83-
{
84-
paramstr += $"{item.ParameterType.Name} {item.Name},";
85-
ndcall += $"{item.Name},";
86-
symcall += $"{item.Name},";
87-
}
88-
}
89-
90-
paramstr = paramstr.Remove(paramstr.LastIndexOf(','));
91-
ndcall = ndcall.Remove(ndcall.LastIndexOf(',')) + ");";
92-
symcall = symcall.Remove(symcall.LastIndexOf(',')) + ");";
93-
}
94-
95-
string returnType = method.ReturnType.Name;
96-
if (returnType == "NDArray")
97-
returnType = "NDArrayOrSymbol";
98-
else if (returnType == "NDArrayList")
99-
returnType = "NDArrayOrSymbolList";
100-
101-
string methodBody = $"public static {returnType} {method.Name}({paramstr})";
102-
methodBody += "\n{";
103-
104-
string firstNdParam = "";
105-
var ndparam = parameters.FirstOrDefault(x => x.ParameterType.Name == "NDArrayOrSymbol");
106-
if (ndparam == null)
107-
{
108-
ndparam = parameters.FirstOrDefault(x => x.ParameterType.Name == "NDArrayOrSymbolList");
109-
}
110-
111-
if (is_symbol)
112-
{
113-
methodBody += $"if (!is_symbol)";
114-
methodBody += "\n{\n";
115-
methodBody += "return " + ndcall + "\n}\n";
116-
methodBody += "return " + symcall;
117-
methodBody += "\n}";
118-
119-
fclass.AppendLine(methodBody);
120-
}
121-
else if (ndparam != null)
122-
{
123-
firstNdParam = ndparam.ParameterType.Name == "NDArrayList" ? $"{ndparam.Name}[0]" : ndparam.Name;
124-
125-
methodBody += $"if ({firstNdParam}.IsNDArray)";
126-
methodBody += "\n{\n";
127-
methodBody += "return " + ndcall + "\n}\n";
128-
methodBody += "return " + symcall;
129-
methodBody += "\n}";
130-
131-
fclass.AppendLine(methodBody);
132-
}
133-
}
134-
135-
string all = fclass.ToString();
13620
}
13721
}
13822
}

examples/BasicExamples/XORGate.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ public static void Run()
7171

7272
metric.Update(label, outputs.ToArray());
7373
trainer.Step(batch.Data[0].shape[0]);
74+
npx.waitall();
7475
}
7576

7677
var (name, acc) = metric.Get();

src/MxNet/Gluon/Block/HybridBlock.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ private void InterAttrs(string infer_fn, string attr, NDArrayOrSymbolList argume
644644

645645
var collectedValues = CollectParams().Values();
646646
for (var i = 0; i < collectedValues.Length; i++)
647-
collectedValues[i]._shape = sdict[collectedValues[i].Name];
647+
collectedValues[i]._shape = sdict[collectedValues[i]._var_name];
648648
}
649649
else if (infer_fn == "infer_type")
650650
{
@@ -664,7 +664,7 @@ private void InterAttrs(string infer_fn, string attr, NDArrayOrSymbolList argume
664664

665665
var collectedValues = CollectParams().Values();
666666
for (var i = 0; i < collectedValues.Length; i++)
667-
collectedValues[i].DataType = sdict[collectedValues[i].Name];
667+
collectedValues[i].DataType = sdict[collectedValues[i]._var_name];
668668
}
669669
}
670670

src/MxNet/Gluon/Data/Dataloader/DataLoader.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public static (Func<int, int, Shape, DType, ndarray>, (int, int, Shape, DType))
156156
public static NDArrayList DefaultBatchifyFn(NDArrayList data)
157157
{
158158
var shape = data[0].shape.Data.ToList();
159-
shape[0] = data.Length;
159+
shape[0] = data[0].shape[0];
160160
var x = np.stack(data);
161161
x = x.reshape(shape.ToArray());
162162
return x;

src/MxNet/Gluon/Data/Vision/Datasets/FashionMNIST.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public override void GetData()
7878
stream.Seek(8, SeekOrigin.Begin);
7979
stream.Read(buffer, 0, buffer.Length);
8080

81-
_label = new ndarray(buffer.Select(x => (float)x).ToArray(), new Shape(60000));
81+
_label = np.array(buffer);
8282
}
8383
}
8484

@@ -92,8 +92,12 @@ public override void GetData()
9292
var buffer = new byte[stream.Length - 16];
9393
stream.Seek(16, SeekOrigin.Begin);
9494
stream.Read(buffer, 0, buffer.Length);
95-
var x = np.array(buffer);
96-
_data = new ndarray(buffer.Select(y => (float)y).ToArray(), new Shape(60000, 28, 28, 1)) / 255;
95+
ndarray x = null;
96+
if (_train)
97+
x = np.array(buffer).reshape(60000, 28, 28, 1);
98+
else
99+
x = np.array(buffer).reshape(10000, 28, 28, 1);
100+
_data = x / 255;
97101
}
98102
}
99103
}

src/MxNet/Gluon/Metrics/BinaryAccuracy.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public override void Update(ndarray labels, ndarray preds)
3535
var label = labels.Ravel();
3636
preds = preds.Ravel() > Threshold;
3737

38-
var num_correct = nd.Equal(preds, label).AsType(DType.Float32).Sum();
38+
var num_correct = np.equal(preds, label).astype(DType.Float32).sum().asscalar();
3939

4040
sum_metric += num_correct;
4141
global_sum_metric += num_correct;

src/MxNet/Gluon/Metrics/CrossEntropy.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public override void Update(ndarray labels, ndarray preds)
3636

3737
l = l.ravel();
3838
var p = preds;
39-
var prob = p[np.arange(l.shape[0]), l.Cast(np.Int64)];
39+
var prob = p[np.arange(l.shape[0]), l.astype(np.Int64)];
4040
var cross_entropy = np.sum(-np.log(prob + eps)).AsScalar<float>();
4141
sum_metric += sum_metric;
4242
global_sum_metric += sum_metric;

src/MxNet/Gluon/Metrics/NegativeLogLikelihood.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public override void Update(ndarray labels, ndarray preds)
3838
l = l.ravel();
3939
var p = preds;
4040
var num_examples = p.shape[0];
41-
var prob = p[np.arange(num_examples).Cast(np.Int64), l.Cast(np.Int64)];
41+
var prob = p[np.arange(num_examples).astype(np.Int64), l.astype(np.Int64)];
4242
var nll = (-np.log(prob + eps)).sum().AsScalar<float>();
4343
sum_metric += nll;
4444
global_sum_metric += nll;

0 commit comments

Comments
 (0)