Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c33801a

Browse files
PixelShuffle, Concatenate and SyncBatchNorm implemented
1 parent ef182a5 commit c33801a

File tree

13 files changed

+307
-109
lines changed

13 files changed

+307
-109
lines changed

examples/BasicExamples/Program.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ class Program
1212
static void Main(string[] args)
1313
{
1414
//Console.WriteLine("Runnin XOR Example......");
15-
XORGate.Run();
15+
//XORGate.Run();
1616
//CrashCourse_NN.Run();
1717
//LogisticRegressionExplained.Run();
18-
//var methods = mx.GetAllRegisteredCApiOperators();
18+
var methods = mx.GetAllRegisteredOperators();
1919
var y = np.full(new Shape(3, 3), 0.6);
2020
var x = np.random.power(y, new Shape(3, 3));
2121

src/MxNet/F.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,10 @@ public static NDArrayOrSymbol batch_norm(NDArrayOrSymbol x, NDArrayOrSymbol gamm
208208
}
209209
return Sym.Numpy.npx.batch_norm(x, gamma, beta, running_mean, running_var, eps, momentum, fix_gamma, use_global_stats, output_mean_var, axis, cudnn_off, min_calib_range, max_calib_range);
210210
}
211-
public static NDArrayOrSymbol convolution(NDArrayOrSymbol data, NDArrayOrSymbol weight, NDArrayOrSymbol bias, int[] kernel, int[] stride, int[] dilate, int[] pad, int num_filter, int num_group, int workspace, bool no_bias, String cudnn_tune, bool cudnn_off, String layout)
211+
public static NDArrayOrSymbol convolution(NDArrayOrSymbol data, NDArrayOrSymbol weight, NDArrayOrSymbol bias = null,
212+
int[] kernel = null, int[] stride = null, int[] dilate = null, int[] pad = null, int num_filter = 1,
213+
int num_group = 1, int workspace = 1024, bool no_bias = false, string cudnn_tune = null,
214+
bool cudnn_off = false, string layout = null)
212215
{
213216
if (data.IsNDArray)
214217
{
@@ -1542,7 +1545,7 @@ public static NDArrayOrSymbol trace(NDArrayOrSymbol a, int offset, int axis1, in
15421545
}
15431546
return sym_np_ops.trace(a, offset, axis1, axis2, @out);
15441547
}
1545-
public static NDArrayOrSymbol transpose(NDArrayOrSymbol a, int[] axes)
1548+
public static NDArrayOrSymbol transpose(NDArrayOrSymbol a, params int[] axes)
15461549
{
15471550
if (a.IsNDArray)
15481551
{

src/MxNet/Gluon/Metrics/EvalMetric.cs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License.
1616
using MxNet.Numpy;
1717
using System;
1818
using System.Collections.Generic;
19+
using System.Linq;
20+
using System.Reflection;
1921

2022
namespace MxNet.Gluon.Metrics
2123
{
@@ -73,9 +75,13 @@ public virtual void Update(NDArrayList labels, NDArrayList preds)
7375
for (var i = 0; i < labels.Length; i++) Update(labels[i], preds[i]);
7476
}
7577

76-
public void UpdateDict(NDArrayDict label, NDArrayDict pred)
78+
public void UpdateDict(NDArrayDict labels, NDArrayDict preds)
7779
{
78-
throw new NotImplementedException();
80+
if (labels.Count != preds.Count) throw new ArgumentException("Labels and Predictions are unequal length");
81+
for (int i = 0; i < labels.Count; i++)
82+
{
83+
Update(labels[labels.Keys[i]], preds[preds.Keys[i]]);
84+
}
7985
}
8086

8187
public virtual void Reset()
@@ -127,7 +133,20 @@ public Dictionary<string, float> GetGlobalNameValue()
127133

128134
public static implicit operator EvalMetric(string name)
129135
{
130-
throw new NotImplementedException();
136+
var assembly = Assembly.GetAssembly(Type.GetType($"MxNet.Gluon.Metrics.EvalMetric"));
137+
var types = assembly.GetTypes().Where(t => String.Equals(t.Namespace, "MxNet.Gluon.Metrics", StringComparison.Ordinal)).ToList();
138+
139+
foreach (var item in types)
140+
{
141+
var obj = Activator.CreateInstance(item);
142+
var evalName = item.GetProperty("Name").GetValue(obj);
143+
if(evalName != null && evalName.ToString().ToLower() == name.ToLower())
144+
{
145+
return (EvalMetric)obj;
146+
}
147+
}
148+
149+
throw new Exception($"Metric with name '{name}' not found.");
131150
}
132151
}
133152
}

src/MxNet/Gluon/NN/BaseLayers/Concatenate.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,21 @@ namespace MxNet.Gluon.NN
66
{
77
public class Concatenate : Sequential
88
{
9-
public int Axis { get; set; }
9+
public int axis { get; set; }
1010
public Concatenate(int axis = -1)
1111
{
12-
Axis = axis;
12+
this.axis = axis;
1313
}
1414

1515
public override NDArrayOrSymbol Forward(NDArrayOrSymbol input, params NDArrayOrSymbol[] args)
1616
{
17-
throw new NotImplementedException();
17+
var @out = new NDArrayOrSymbolList();
18+
foreach (var block in this._childrens.Values)
19+
{
20+
@out.Add(block.Call(input));
21+
}
22+
23+
return F.concatenate(@out, axis: this.axis);
1824
}
1925
}
2026
}

src/MxNet/Gluon/NN/BaseLayers/HybridConcatenate.cs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,32 @@ namespace MxNet.Gluon.NN
66
{
77
public class HybridConcatenate : HybridSequential
88
{
9-
public int Axis { get; set; }
9+
public int axis { get; set; }
1010
public HybridConcatenate(int axis = -1)
1111
{
12-
Axis = axis;
12+
this.axis = axis;
1313
}
1414

1515
public override NDArrayOrSymbol Forward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
1616
{
17-
throw new NotImplementedException();
17+
var @out = new NDArrayOrSymbolList();
18+
foreach (var block in this._childrens.Values)
19+
{
20+
@out.Add(block.Call(x));
21+
}
22+
23+
return F.concatenate(@out, axis: this.axis);
1824
}
1925

2026
public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
2127
{
22-
throw new NotImplementedException();
28+
var @out = new NDArrayOrSymbolList();
29+
foreach (var block in this._childrens.Values)
30+
{
31+
@out.Add(block.Call(x));
32+
}
33+
34+
return F.concatenate(@out, axis: this.axis);
2335
}
2436
}
2537
}

src/MxNet/Gluon/NN/BaseLayers/SyncBatchNorm.cs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,39 @@ namespace MxNet.Gluon.NN
2020
{
2121
public class SyncBatchNorm : _BatchNorm
2222
{
23+
public int? num_devices;
24+
2325
public SyncBatchNorm(int in_channels = 0, int? num_devices = null, float momentum = 0.9f, float epsilon = 1e-5f,
2426
bool center = true, bool scale = true, bool use_global_stats = false, string beta_initializer = "zeros",
2527
string gamma_initializer = "ones", string running_mean_initializer = "zeros",
26-
string running_variance_initializer = "ones",
27-
string prefix = "", ParameterDict @params = null)
28+
string running_variance_initializer = "ones")
2829
: base(1, momentum, epsilon, center, scale, false, use_global_stats, beta_initializer, gamma_initializer,
2930
running_mean_initializer, running_variance_initializer, in_channels)
3031
{
31-
throw new NotImplementedException();
32+
this.num_devices = num_devices;
3233
}
3334

3435
internal int GetNumDevices()
3536
{
36-
throw new NotImplementedException();
37+
Logger.Warning("Caution using SyncBatchNorm: if not using all the GPUs, please mannually set num_devices");
38+
var num_devices = MxUtil.GetGPUCount();
39+
num_devices = num_devices > 0 ? num_devices : 1;
40+
return num_devices;
3741
}
3842

3943
public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
4044
{
41-
throw new NotImplementedException();
45+
var gamma = args.Length > 0 ? args[0] : null;
46+
var beta = args.Length > 1 ? args[1] : null;
47+
var running_mean = args.Length > 2 ? args[2] : null;
48+
var running_var = args.Length > 3 ? args[3] : null;
49+
50+
if (x.IsNDArray)
51+
return nd.Contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var, "", Epsilon, Momentum, FixGamma,
52+
Use_Global_Stats, false, num_devices.HasValue ? num_devices.Value : 1);
53+
54+
return sym.Contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var, "", Epsilon, Momentum, FixGamma,
55+
Use_Global_Stats, false, num_devices.HasValue ? num_devices.Value : 1, "fwd");
4256
}
4357
}
4458
}
Lines changed: 138 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,113 @@
11
using MxNet.Initializers;
2+
using MxNet.Sym.Numpy;
23
using System;
34
using System.Collections.Generic;
5+
using System.Diagnostics;
6+
using System.Linq;
47
using System.Text;
58

69
namespace MxNet.Gluon.NN
710
{
811
public class DeformableConvolution : HybridBlock
912
{
13+
public int _channels;
14+
15+
public int _in_channels;
16+
17+
public Activation act;
18+
19+
public Parameter deformable_conv_bias;
20+
21+
public Parameter deformable_conv_weight;
22+
23+
public Parameter offset_bias;
24+
25+
public Parameter offset_weight;
26+
27+
public int offset_channels;
28+
29+
public int[] kernel;
30+
31+
public int[] strides;
32+
33+
public int[] padding;
34+
35+
public int[] dilation;
36+
37+
public int num_filter;
38+
39+
public int num_group;
40+
41+
public bool no_bias;
42+
43+
public string layout;
44+
45+
public int num_deformable_group;
46+
47+
public int[] adj;
48+
1049
public DeformableConvolution(int channels, (int, int)? kernel_size = null, (int, int)? strides = null, (int, int)? padding = null, (int, int)? dilation = null,
1150
int groups = 1, int num_deformable_group = 1, string layout = "NCHW", bool use_bias = true, int in_channels = 0, ActivationType? activation = null,
12-
Initializer weight_initializer = null, string bias_initializer = "zeros", bool offset_use_bias = true, int[] adj = null,
13-
string op_name = "DeformableConvolution")
51+
Initializer weight_initializer = null, string bias_initializer = "zeros", string offset_weight_initializer = "zeros",
52+
string offset_bias_initializer = "zeros", bool offset_use_bias = true, int[] adj = null, string op_name = "DeformableConvolution")
1453
{
15-
throw new NotImplementedException();
54+
this._channels = channels;
55+
this._in_channels = in_channels;
56+
Debug.Assert(new string[] { "NCHW", "NHWC" }.Contains(layout), "Only supports 'NCHW' and 'NHWC' layout for now");
57+
var offset_channels = 2 * kernel_size.Value.Item1 * kernel_size.Value.Item2 * num_deformable_group;
58+
59+
this.kernel = kernel_size.HasValue ? new int[] { kernel_size.Value.Item1, kernel_size.Value.Item2 } : new int[] { 1, 1 };
60+
this.strides = strides.HasValue ? new int[] { strides.Value.Item1, strides.Value.Item2 } : new int[] { 1, 1 };
61+
this.padding = padding.HasValue ? new int[] { padding.Value.Item1, padding.Value.Item2 } : new int[] { 0, 0 };
62+
this.dilation = dilation.HasValue ? new int[] { dilation.Value.Item1, dilation.Value.Item2 } : new int[] { 0, 0 };
63+
this.num_filter = offset_channels;
64+
this.num_group = groups;
65+
this.no_bias = !offset_use_bias;
66+
this.layout = layout;
67+
this.num_deformable_group = num_deformable_group;
68+
this.adj = adj;
69+
var dshape = new int[kernel.Length + 2];
70+
dshape[layout.IndexOf('N')] = 1;
71+
dshape[layout.IndexOf('C')] = in_channels;
72+
73+
74+
var offsetshapes = _infer_weight_shape("convolution", new Shape(dshape));
75+
this.offset_weight = new Parameter("offset_weight", shape: offsetshapes[1], init: Initializer.Get(offset_weight_initializer), allow_deferred_init: true);
76+
if (offset_use_bias)
77+
{
78+
this.offset_bias = new Parameter("offset_bias", shape: offsetshapes[2], init: Initializer.Get(offset_bias_initializer), allow_deferred_init: true);
79+
}
80+
else
81+
{
82+
this.offset_bias = null;
83+
}
84+
var deformable_conv_weight_shape = new int[kernel.Length + 2];
85+
deformable_conv_weight_shape[0] = channels;
86+
deformable_conv_weight_shape[2] = kernel[0];
87+
deformable_conv_weight_shape[3] = kernel[1];
88+
this.deformable_conv_weight = new Parameter("deformable_conv_weight", shape: new Shape(deformable_conv_weight_shape), init: weight_initializer, allow_deferred_init: true);
89+
if (use_bias)
90+
{
91+
this.deformable_conv_bias = new Parameter("deformable_conv_bias", shape: new Shape(channels), init: bias_initializer, allow_deferred_init: true);
92+
}
93+
else
94+
{
95+
this.deformable_conv_bias = null;
96+
}
97+
98+
if (activation.HasValue)
99+
{
100+
this.act = new Activation(activation.Value);
101+
}
102+
else
103+
{
104+
this.act = null;
105+
}
106+
107+
this["deformable_conv_bias"] = deformable_conv_bias;
108+
this["deformable_conv_weight"] = deformable_conv_weight;
109+
this["offset_bias"] = offset_bias;
110+
this["offset_weight"] = offset_weight;
16111
}
17112

18113
public override string Alias()
@@ -22,7 +117,46 @@ public override string Alias()
22117

23118
public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
24119
{
25-
return base.HybridForward(x, args);
120+
//object act;
121+
//object offset;
122+
//if (offset_bias == null)
123+
//{
124+
// offset = F.convolution(x, offset_weight, kernel: kernel, stride: strides, cudnn_off: true, this._kwargs_offset);
125+
//}
126+
//else
127+
//{
128+
// offset = F.convolution(x, offset_weight, offset_bias, cudnn_off: true, this._kwargs_offset);
129+
//}
130+
//if (deformable_conv_bias == null)
131+
//{
132+
// act = F.npx.deformable_convolution(data: x, offset: offset, weight: deformable_conv_weight, name: "fwd", this._kwargs_deformable_conv);
133+
//}
134+
//else
135+
//{
136+
// act = F.npx.deformable_convolution(data: x, offset: offset, weight: deformable_conv_weight, bias: deformable_conv_bias, name: "fwd", this._kwargs_deformable_conv);
137+
//}
138+
139+
//if (this.act)
140+
//{
141+
// using (var np_array(true))
142+
// {
143+
// act = this.act(act);
144+
// }
145+
//}
146+
147+
//return is_np_array() ? act : act.as_nd_ndarray();
148+
149+
throw new NotImplementedException();
150+
}
151+
152+
internal Shape[] _infer_weight_shape(string op_name, Shape data_shape)
153+
{
154+
var conv = sym.Convolution(_Symbol.Var("data", shape: data_shape), null, kernel: new Shape(kernel),
155+
num_filter: num_filter,
156+
stride: new Shape(strides), dilate: new Shape(dilation), pad: new Shape(padding), no_bias: no_bias,
157+
num_group: num_group, bias: null);
158+
159+
return conv.InferShapePartial(new Dictionary<string, Shape>()).Item1;
26160
}
27161
}
28162
}

src/MxNet/Gluon/NN/ConvLayers/PixelShuffle1D.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,19 @@ namespace MxNet.Gluon.NN
1919
{
2020
public class PixelShuffle1D : HybridBlock
2121
{
22+
public int factor;
23+
2224
public PixelShuffle1D(int factor)
2325
{
24-
throw new NotImplementedException();
26+
this.factor = factor;
2527
}
2628

2729
public override NDArrayOrSymbol HybridForward(NDArrayOrSymbol x, params NDArrayOrSymbol[] args)
2830
{
29-
throw new NotImplementedException();
31+
x = F.reshape(x, new Shape(-2, -6, -1, factor, -2));
32+
x = F.transpose(x, 0, 1, 3, 2);
33+
x = F.reshape(x, new Shape(-2, -2, -5));
34+
return x;
3035
}
3136
}
3237
}

0 commit comments

Comments
 (0)