Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a146cc9

Browse files
committed
Added a script to download a set of 100 matrices to be used during training
1 parent c042bc8 commit a146cc9

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
download.py
3+
4+
EPCC, The University of Edinburgh
5+
6+
(c) 2023 The University of Edinburgh
7+
8+
Contributing Authors:
9+
Christodoulos Stylianou ([email protected])
10+
11+
Licensed under the Apache License, Version 2.0 (the "License");
12+
you may not use this file except in compliance with the License.
13+
You may obtain a copy of the License at
14+
15+
http://www.apache.org/licenses/LICENSE-2.0
16+
17+
Unless required by applicable law or agreed to in writing, software
18+
distributed under the License is distributed on an "AS IS" BASIS,
19+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20+
See the License for the specific language governing permissions and
21+
limitations under the License.
22+
"""
23+
24+
"""
25+
The following script is used to download 100 real and square matrices
26+
with a maximum of 100k Non-zeros from SuiteSparse collection.
27+
"""
28+
29+
import ssgetpy
30+
import random
31+
import os
32+
33+
OUTDIR = os.path.dirname(os.path.abspath(__file__)) + "/" + "matrices"
34+
MAX_MATICES = 10000
35+
SET_MATRICES = 100
36+
MAX_NNZ = 100000
37+
38+
# Set seed to randomize list in the same way
39+
random.seed(0)
40+
41+
matrices = ssgetpy.matrix.MatrixList()
42+
dtypes = ["Real"]
43+
44+
for dtype in dtypes:
45+
result = ssgetpy.search(
46+
dtype=dtype.lower(), nzbounds=(0, MAX_NNZ), limit=MAX_MATICES
47+
)
48+
print(dtype + " matrices: " + str(len(result)))
49+
50+
# Shuffle the list
51+
random.shuffle(result)
52+
53+
ctr = 0
54+
for matrix in result:
55+
if len(matrices) == SET_MATRICES:
56+
break
57+
58+
if (matrix.rows != matrix.cols) or (matrix.id in [230, 231]):
59+
# Matrices 230 & 231 are Skew-Symmetric
60+
result.remove(matrix)
61+
ctr += 1
62+
else:
63+
matrices.append(matrix)
64+
65+
print("\tRemoved " + str(ctr) + " non-square " + dtype + " matrices")
66+
67+
print("Total Matrices in the set: " + str(len(matrices)))
68+
69+
matrices.download(destpath=OUTDIR, extract=True)

0 commit comments

Comments
 (0)