Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cdf661a

Browse files
committed
Initial DotSim commit
Dot Simulator environment, with dot tracing example and plotting tool. Threw in cue-reward in case it's helpful.
1 parent c9f3e9d commit cdf661a

9 files changed

Lines changed: 1242 additions & 24 deletions

File tree

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import numpy as np
2+
import glob
3+
import sys
4+
5+
import matplotlib.pyplot as plt
6+
7+
# Define grid dimensions globally
8+
ROWS = 28
9+
COLS = 28
10+
11+
12+
def plotGrids(gridData):
13+
if(gridData.shape[0] % ROWS != 0 or gridData.shape[1] != COLS):
14+
raise('Incompatible grid dimensionality: check data and assumed dimensions.')
15+
16+
grids = gridData.shape[0]//ROWS
17+
18+
print('Reshaping into', grids, 'grids of shape (', ROWS, ',', COLS, ')')
19+
gridData = gridData.reshape((grids, ROWS, COLS))
20+
21+
plotAnotherRange = True
22+
23+
while(plotAnotherRange):
24+
start = -1
25+
end = 1
26+
print('Select the range of iterations to generate grid plots from.')
27+
print('0 means plot all iterations.')
28+
while((start < 0 or grids-1 < start) or (end < 1 or grids < end)):
29+
start = int(input('Start: '))
30+
31+
# If start is set to zero, plot everything.
32+
if(start == 0):
33+
continue
34+
35+
end = int(input('End: '))
36+
37+
if start == 0:
38+
print('\nPlotting whole shebang!')
39+
else:
40+
print('\nPlotting range from iteration', start, 'to', end)
41+
42+
# Plotting time!
43+
plt.figure()
44+
plt.ion()
45+
plt.imshow(gridData[start], cmap='hot', interpolation='nearest')
46+
plt.colorbar()
47+
plt.pause(.001) # Pause so that that GUI can do its thing.
48+
for g in gridData[start+1:end]:
49+
plt.imshow(g, cmap='hot', interpolation='nearest')
50+
plt.pause(.001) # Pause so that that GUI can do its thing.
51+
52+
plotAnotherRange = str.lower(input('Plot another range? (y/n): ')) == 'y'
53+
54+
55+
def plotRewards(rewData, fname):
56+
cumRewards = np.cumsum(rewData)
57+
tsteps = np.array(range(len(cumRewards)))
58+
59+
# Plotting time!
60+
plt.figure()
61+
plt.plot(tsteps, cumRewards)
62+
plt.xlabel('Timesteps')
63+
plt.ylabel('Cumulative Reward')
64+
plt.title("Cumulative Reward by Iteration")
65+
plt.savefig(fname[0:-4] + '.png', dpi=200)
66+
plt.pause(.001) # Pause so that that GUI can do its thing.
67+
68+
69+
def plotPerformance(perfData, fname):
70+
71+
# Set bins to a tenth of the episodes, rounded up.
72+
binIdx = np.array(range(len(perfData)))//10
73+
bins = np.bincount(binIdx, perfData).astype('uint32')
74+
75+
# Plotting time!
76+
plt.figure()
77+
plt.bar(np.unique(binIdx), bins, color='seagreen')
78+
plt.xlabel('Episode Bins')
79+
plt.ylabel('Number of Intercepts')
80+
plt.title("Interception Performance Across Episodes")
81+
plt.savefig(fname[0:-4] + '.png', dpi=200)
82+
plt.pause(.001) # Pause so that that GUI can do its thing.
83+
84+
85+
def main():
86+
"""
87+
File types:
88+
89+
0) grid - the 2D matrix observation
90+
1) reward - list of rewards per iteration
91+
2) performance - list of performance values
92+
"""
93+
fileType = 0 # default to grid
94+
95+
# By default, we'll search the examples directory, but tweak as needed.
96+
files = glob.glob('../../examples/*/out/*csv')
97+
98+
if len(files) == 0:
99+
print('Could not find any csv files. Exiting...')
100+
sys.exit()
101+
102+
plotAnotherFile = True
103+
104+
while plotAnotherFile:
105+
print('Select the file to generate grid plots from.')
106+
for i,f in enumerate(files):
107+
print(str(i), '-', f)
108+
109+
# Select the intended file.
110+
sel = -1
111+
while sel < 0 or len(files) < sel:
112+
sel = int(input('\nFile selection: '))
113+
114+
fileToPlot = files[sel]
115+
116+
# Check file type
117+
if(0 < fileToPlot.find('grid')):
118+
print('\nFound \'grid\' in name: assuming a grid file type.')
119+
fileType = 0
120+
elif(0 < fileToPlot.find('rew')):
121+
print('\nFound \'rew\' in name: assuming a reward file type.')
122+
fileType = 1
123+
124+
elif(0 < fileToPlot.find('perf')):
125+
print('\nFound \'perf\' in name: assuming a performance file type.')
126+
fileType = 2
127+
else:
128+
print('\nUnknown file type. Which type are we plotting?')
129+
print('\n0) grid\n1) reward\n2) performance')
130+
fileType = -1
131+
while fileType < 0 or 2 < fileType:
132+
fileType = int(input('\nFile type: '))
133+
134+
print('\nPlotting: ', fileToPlot)
135+
data = np.genfromtxt(fileToPlot, delimiter=',')
136+
137+
# Plot by file type
138+
if(fileType == 0):
139+
plotGrids(data)
140+
elif(fileType == 1):
141+
plotRewards(data, fileToPlot)
142+
elif(fileType == 2):
143+
plotPerformance(data, fileToPlot)
144+
else:
145+
print('ERROR: Unknown file type')
146+
147+
plotAnotherFile = str.lower(input('Plot another file? (y/n): ')) == 'y'
148+
149+
150+
151+
if __name__ == '__main__':
152+
main()
153+

bindsnet/encoding/encodings.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional
22

33
import torch
4+
import numpy as np
45

56

67
def single(
@@ -15,21 +16,21 @@ def single(
1516
"""
1617
Generates timing based single-spike encoding. Spike occurs earlier if the
1718
intensity of the input feature is higher. Features whose value is lower than
18-
threshold is remain silent.
19+
threshold remain silent.
1920
2021
:param datum: Tensor of shape ``[n_1, ..., n_k]``.
2122
:param time: Length of the input and output.
2223
:param dt: Simulation time step.
23-
:param sparsity: Sparsity of the input representation. 0 for no spikes and 1 for all
24-
spikes.
24+
:param sparsity: Sparsity of the input representation. 0 for no spikes
25+
and 1 for all spikes.
2526
:return: Tensor of shape ``[time, n_1, ..., n_k]``.
2627
"""
2728
time = int(time / dt)
2829
shape = list(datum.shape)
29-
datum = torch.tensor(datum)
30-
quantile = torch.quantile(datum, 1 - sparsity)
31-
s = torch.zeros([time, *shape], device=device)
32-
s[0] = torch.where(datum > quantile, torch.ones(shape), torch.zeros(shape))
30+
datum = np.copy(datum)
31+
quantile = np.quantile(datum, 1 - sparsity)
32+
s = np.zeros([time, *shape], device=device)
33+
s[0] = np.where(datum > quantile, np.ones(shape), np.zeros(shape))
3334
return torch.Tensor(s).byte()
3435

3536

@@ -140,7 +141,7 @@ def poisson(
140141

141142
# Create Poisson distribution and sample inter-spike intervals
142143
# (incrementing by 1 to avoid zero intervals).
143-
dist = torch.distributions.Poisson(rate=rate, validate_args=False)
144+
dist = torch.distributions.Poisson(rate=rate)
144145
intervals = dist.sample(sample_shape=torch.Size([time + 1]))
145146
intervals[:, datum != 0] += (intervals[:, datum != 0] == 0).float()
146147

bindsnet/environment/README.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Important
2+
3+
## Dot Simulator
4+
5+
### Overview
6+
7+
This simulator lets us generate dots and make them move in a configurable 2D space, providing a visual to a neural network for training in experiments.
8+
9+
Specifically, this generates a grid for each timestep, where a specified number of points have values of 1 with fading tails ("decay"), designating the current positions and movements of their corresponding dots. All other points are set to 0. From timestep to timestep, the dots either remain where they are or move one space.
10+
11+
The 2D observation of the current state is provided every step, as well as the reward, completion flag, and sucessful interception flag. It may be helpful to scale the grid values when encoding them as spike trains.
12+
13+
The intended objective is to train a network to use its "network dot" to trace or intercept a moving "target" dot. But this simulator is designed to easily adapt to multiple kinds of experiments.
14+
15+
16+
### Dot Movement
17+
18+
By default, there is a single "target" dot that moves in a random direction every timestep (or it can stay still, which can be disabled), and as it moves, it leaves a tunable "decay" in the form of a fading tail. The simulator supports four directions of movement by default (up/down/left/right) by default, as well as remaining still, but the diag parameters allows diagonal movement for more complexity. The rate of the target's randomized movement can also be modified (ie. random direction every timestep or only change direction so often).
19+
20+
The simulator supports multiple bounds-handling schemes. By default, dots will simply not move past the edges. Alternatively, the bound_hand parameter can be set to 'bounce', for a geometric reflection off the edges, or 'trans' which will have a mirrored result: a geometric translation to the opposite side of the grid.
21+
22+
To add further complexity, additional targets can be added as desired via the dots parameter, and the herrs parameter can be set to generate multiple "red herrings" as distraction dots. The speed of the dots' movements can also be set; it is 1 by default.
23+
24+
<p align="middle">
25+
<img src="https://github.com/Hananel-Hazan/bindsnet/blob/master/docs/BindsNET%20benchmark.png" alt="BindsNET%20Benchmark" width="503" height="403">
26+
</p>
27+
>The grid visuals provided by the render function will double the value of the network dot; this is a visual aid only, invisible to the network.
28+
29+
30+
### Reward Functions
31+
32+
This simulator supports multiple reward functions (aka. fitness functions):
33+
- Euclidean (fit_func='euc'): the default option, this function computes the Euclidean (aka. Pythagorean) distance between the network dot and the target dot.
34+
- Displacement (fit_func='disp'): this option computes the x,y displacement of the network dot with respect to the target dot, returning an x,y tuple. Currently, BindsNET only supports single reward values. To use this one, either be creative or update the network code...
35+
- Range Rings (fit_func='rng'): this option uses the Euclidean distance and groups it into range rings. The radial distance of the range rings can be set by the ring_size parameter.
36+
- Directional (fit_func='dir'): the directional option checks to see if the network's decision moved its dot closer, laterally, or further away from the target dot's prior position (ie. before applying movement this timestep) and returns a +1, 0, or -1 accordingly.
37+
38+
Additionally, upon a successful intercept, the network will receive +10 if the bullseye parameter is active, and its dot will be teleported to another random location if the teleport parameter is active.
39+
40+
>In the event multiple target dots are generated, the fitness functions only compute rewards with respect to the first target dot.
41+
42+
43+
### Additional Features
44+
45+
The environment can take a seed for random number generation in python, numpy, and Pytorch; otherwise, it will generate and save a new seed based on the current system time.
46+
47+
As this simulator was developed in Anaconda Spyder on Windows, it can be run from Windows or Linux. Since environments handle plotting differently, and experiments can sometimes be terminated prematurely, this environment supports the recording of grid observations in text files and post-op plotting. Live rendering can also be disabled via the mute parameter, and a text-based alternative using pandas dataframe formatting can be enabled via the pandas parameter.
48+
49+
Filenames and file paths can be specified for recording grid observations. By default, the filenames will be "grid" followed by "s#_$.csv" where # is the random seed used and $ is the current file number. addFileSuffix(suffix) adds the provided suffix (typically used for "train" or "test") to the filename, and changeFileSuffix(sFrom, sTo) will find sFrom in the filename and replace it with sTo.
50+
51+
To ensure that files do not become too large to either be saved or be practically useful, cycleOutFiles(newInt) can be used to cycle the current save file, incrementing the file number suffix, or resetting it if newInt is set to a positive number.
52+
53+
Post-op plotting is supported by dotTrace_plotter.py in the analysis directory. By default, this tool searches the examples directory for csvs in "out" directories, but that path can be easily changed. It supports plotting ranges of grid observations, reward plots, and performance plots. See below for an example of recording reward and performance data for plotting purposes.
54+
55+
56+
### Example
57+
See dot_tracing.py for an example in using the Dot Simulator for training an SNN in BindsNET.
58+
59+
dot_tracing trains a basic RNN network on the dot simulator and demonstrates how to record reward and performance data (if desired) and plot spiking activity via monitors.
60+
61+

0 commit comments

Comments
 (0)