forked from PlasmaControl/DESC
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_input_output.py
More file actions
361 lines (308 loc) · 11.9 KB
/
Copy pathtest_input_output.py
File metadata and controls
361 lines (308 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
"""Tests for reading/writing intput/output, both ascii and binary."""
import os
import pathlib
import shutil
import h5py
import numpy as np
import pytest
from desc.basis import FourierZernikeBasis
from desc.equilibrium import Equilibrium
from desc.grid import LinearGrid
from desc.io import InputReader, hdf5Reader, hdf5Writer, load
from desc.io.ascii_io import read_ascii, write_ascii
from desc.transform import Transform
from desc.utils import equals
@pytest.mark.unit
def test_vmec_input(tmpdir_factory):
    """Test converting VMEC to DESC input file.

    A VMEC input is copied into a temporary directory and parsed, the parsed
    inputs are written back out as a DESC input file, and that file is parsed
    again. Both parses must compare equal entry by entry once the
    ``output_path`` key has been removed from every entry.
    """
    workdir = tmpdir_factory.mktemp("desc_inputs")
    vmec_copy = workdir.join("input.DSHAPE")
    shutil.copyfile("./tests/inputs/input.DSHAPE", vmec_copy)

    vmec_reader = InputReader(cl_args=[str(vmec_copy)])
    vmec_inputs = vmec_reader.inputs

    desc_file = workdir.join("desc_from_vmec")
    vmec_reader.write_desc_input(desc_file, vmec_reader.inputs)
    desc_reader = InputReader(cl_args=[str(desc_file)])
    desc_inputs = desc_reader.inputs

    # output_path is excluded from the comparison
    # (presumably because it depends on which file was read -- TODO confirm).
    for desc_entry, vmec_entry in zip(desc_inputs, vmec_inputs):
        desc_entry.pop("output_path")
        vmec_entry.pop("output_path")

    matches = [
        equals(vmec_entry, desc_entry)
        for vmec_entry, desc_entry in zip(vmec_inputs, desc_inputs)
    ]
    assert all(matches)
@pytest.mark.unit
def test_near_axis_input_files():
    """Test that DESC and VMEC input files generated by pyQSC give the same inputs.

    Only the last entry of each reader's ``inputs`` list is compared, and only
    for the keys listed below, to within rtol=1e-6 / atol=1e-8.
    """
    vmec_file = ".//tests//inputs//input.QSC_r2_5.5_vmec"
    desc_file = ".//tests//inputs//input.QSC_r2_5.5_desc"
    from_vmec = InputReader(vmec_file).inputs[-1]
    from_desc = InputReader(desc_file).inputs[-1]

    compared_keys = ("sym", "NFP", "Psi", "pressure", "current", "surface", "axis")
    tolerances = {"rtol": 1e-6, "atol": 1e-8}
    for key in compared_keys:
        np.testing.assert_allclose(from_desc[key], from_vmec[key], **tolerances)
@pytest.mark.unit
def test_vmec_input_surface_threshold():
    """Test that ``threshold`` shrinks the surface parsed from a VMEC input.

    The "surface" array of the last parsed entry must have fewer rows when
    parsed with ``threshold=1e-6`` than with the default, while both arrays
    keep 5 columns.
    """
    vmec_file = ".//tests//inputs//input.QSC_r2_5.5_vmec"
    untrimmed = InputReader.parse_vmec_inputs(vmec_file)[-1]["surface"]
    trimmed = InputReader.parse_vmec_inputs(vmec_file, threshold=1e-6)[-1]["surface"]

    rows_untrimmed = untrimmed.shape[0]
    rows_trimmed = trimmed.shape[0]
    assert rows_untrimmed > rows_trimmed

    cols_untrimmed = untrimmed.shape[1]
    cols_trimmed = trimmed.shape[1]
    assert cols_untrimmed == cols_trimmed and cols_trimmed == 5
class TestInputReader:
    """Tests for the InputReader class."""

    # Command-line argument lists shared by the tests below. They are never
    # mutated in place; tests that need extra flags build a new list with ``+``.
    argv0 = []
    argv1 = ["nonexistant_input_file"]
    argv2 = ["./tests/inputs/MIN_INPUT"]

    @pytest.mark.unit
    def test_no_input_file(self):
        """Test an error is raised when no input file is given."""
        with pytest.raises(NameError):
            InputReader(cl_args=self.argv0)

    @pytest.mark.unit
    def test_nonexistant_input_file(self):
        """Test error is raised when nonexistent path is given."""
        with pytest.raises(FileNotFoundError):
            InputReader(cl_args=self.argv1)

    @pytest.mark.unit
    def test_min_input(self):
        """Test that minimal input is parsed correctly."""
        ir = InputReader(cl_args=self.argv2)
        assert ir.args.input_file[0] == self.argv2[0], "Input file name does not match"
        assert ir.input_path == str(
            pathlib.Path("./" + self.argv2[0]).resolve()
        ), "Path to input file is incorrect."

        # Test defaults
        assert ir.args.plot == 0, "plot is not default 0"
        assert ir.args.quiet is False, "quiet is not default False"
        assert ir.args.verbose == 1, "verbose is not default 1"
        assert ir.args.numpy is False, "numpy is not default False"
        assert (
            os.environ["DESC_BACKEND"] == "jax"
        ), "numpy environment variable incorrect with default argument"
        assert ir.args.version is False, "version is not default False"
        assert (
            len(ir.inputs[0]) == 28
        ), "number of inputs does not match number expected in MIN_INPUT"
        # test equality of arguments

    @pytest.mark.unit
    def test_np_environ(self):
        """Test setting numpy backend via environment variable."""
        argv = self.argv2 + ["--numpy"]
        # This test expects InputReader to change DESC_BACKEND in the
        # process-wide environment. Snapshot the previous value and restore it
        # afterwards so the non-default "numpy" setting does not leak into
        # whatever runs next in the same process.
        previous_backend = os.environ.get("DESC_BACKEND")
        try:
            InputReader(cl_args=argv)
            assert (
                os.environ["DESC_BACKEND"] == "numpy"
            ), "numpy environment variable incorrect on use"
        finally:
            if previous_backend is None:
                os.environ.pop("DESC_BACKEND", None)
            else:
                os.environ["DESC_BACKEND"] = previous_backend

    @pytest.mark.unit
    def test_quiet_verbose(self):
        """Test setting of quiet and verbose options."""
        # (extra flags, expected inputs[0]["verbose"], failure message)
        expectations = (
            ([], 1, "value of inputs['verbose'] incorrect on no arguments"),
            (["-v"], 2, "value of inputs['verbose'] incorrect on verbose argument"),
            (
                ["-vv"],
                3,
                "value of inputs['verbose'] incorrect on double verbose argument",
            ),
            (["-q"], 0, "value of inputs['verbose'] incorrect on quiet argument"),
        )
        for extra_flags, expected_level, message in expectations:
            ir = InputReader(self.argv2 + extra_flags)
            assert ir.inputs[0]["verbose"] == expected_level, message

    @pytest.mark.unit
    def test_vmec_to_desc_input(self):
        """Test that we correctly convert a VMEC input file to DESC input file."""
        # FIXME: maybe just store a file we know is converted correctly,
        # and checksum compare a live conversion to it
        # NOTE: placeholder only -- this test currently asserts nothing.
        pass
class MockObject:
    """Minimal stand-in object used by the saving/loading tests."""

    def __init__(self):
        # Attribute names handed to the reader/writer under test. A fresh list
        # is built per instance because tests in this file mutate it in place.
        self._io_attrs_ = list("abc")
@pytest.mark.unit
def test_writer_given_filename(writer_test_file):
    """Test writing to a given file by filename.

    When constructed from a filename, ``check_type`` must reject the writer's
    ``target`` and accept its ``base``; ``_close_base_`` must be True until
    ``close()`` is called and False afterwards.
    """
    h5_writer = hdf5Writer(writer_test_file, "w")

    target_ok = h5_writer.check_type(h5_writer.target)
    assert target_ok is False
    base_ok = h5_writer.check_type(h5_writer.base)
    assert base_ok is True

    assert h5_writer._close_base_ is True
    h5_writer.close()
    assert h5_writer._close_base_ is False
@pytest.mark.unit
def test_writer_given_file(writer_test_file):
    """Test writing to given file instance.

    When constructed from an already-open ``h5py.File``, ``check_type`` must
    accept both the writer's ``target`` and ``base``, and ``_close_base_``
    must be False.
    """
    # The context manager closes the handle even if an assertion fails, so a
    # failing test cannot leave the HDF5 file open.
    with h5py.File(writer_test_file, "w") as f:
        writer = hdf5Writer(f, "w")
        assert writer.check_type(writer.target) is True
        assert writer.check_type(writer.base) is True
        assert writer._close_base_ is False
@pytest.mark.unit
def test_writer_close_on_delete(writer_test_file):
    """Test that files are closed when writer is deleted.

    While a first writer is alive, opening a second writer on the same file
    must raise OSError; once the first writer has been deleted, opening the
    file again must succeed.
    """
    first_writer = hdf5Writer(writer_test_file, "w")
    with pytest.raises(OSError):
        hdf5Writer(writer_test_file, "w")

    # Deleting the only reference is what is expected to release the file.
    del first_writer
    second_writer = hdf5Writer(writer_test_file, "w")
    del second_writer
@pytest.mark.unit
def test_writer_write_dict(writer_test_file):
    """Test writing dictionary to hdf5 file.

    The dictionary is written with ``hdf5Writer``, checked key by key through
    a plain ``h5py.File``, and finally read back with ``hdf5Reader``. Passing
    a non-writable ``where`` to ``write_dict`` must raise SyntaxError.
    """
    thedict = {"1": 1, "2": 2, "3": 3}
    writer = hdf5Writer(writer_test_file, "w")
    writer.write_dict(thedict)
    with pytest.raises(SyntaxError):
        writer.write_dict(thedict, where="not a writable type")
    writer.close()

    # The context manager closes the raw handle even if an assertion fails.
    with h5py.File(writer_test_file, "r") as f:
        for key, value in thedict.items():
            assert key in f
            assert f[key][()] == value

    reader = hdf5Reader(writer_test_file)
    dict1 = reader.read_dict()
    assert dict1 == thedict
    reader.close()
@pytest.mark.unit
def test_writer_write_list(writer_test_file):
    """Test writing list to hdf5 file.

    A mixed list of strings and integers is written with ``hdf5Writer`` and
    must compare equal after being read back with ``hdf5Reader``. Passing a
    non-writable ``where`` to ``write_list`` must raise SyntaxError.
    """
    expected = ["1", 1, "2", 2, "3", 3]

    h5_writer = hdf5Writer(writer_test_file, "w")
    h5_writer.write_list(expected)
    with pytest.raises(SyntaxError):
        h5_writer.write_list(expected, where="not a writable type")
    h5_writer.close()

    h5_reader = hdf5Reader(writer_test_file)
    round_tripped = h5_reader.read_list()
    assert round_tripped == expected
    h5_reader.close()
@pytest.mark.unit
def test_writer_write_obj(writer_test_file):
    """Test writing objects to hdf5 file.

    Writing an object whose ``_io_attrs_`` are not yet set must emit a
    RuntimeWarning. Once the attributes exist, the object is written both at
    the file root and into a subgroup, and each attribute name must appear as
    a key in both places.
    """
    mo = MockObject()
    writer = hdf5Writer(writer_test_file, "w")
    # writer should throw runtime warning if any save_attrs are undefined
    with pytest.warns(RuntimeWarning):
        writer.write_obj(mo)
    writer.close()

    writer = hdf5Writer(writer_test_file, "w")
    for name in mo._io_attrs_:
        setattr(mo, name, name)
    writer.write_obj(mo)
    groupname = "initial"
    writer.write_obj(mo, where=writer.sub(groupname))
    writer.close()

    # The context manager closes the raw handle even if an assertion fails.
    with h5py.File(writer_test_file, "r") as f:
        for key in mo._io_attrs_:
            assert key in f
        assert groupname in f
        initial = f[groupname]
        for key in mo._io_attrs_:
            assert key in initial
@pytest.mark.unit
def test_reader_given_filename(reader_test_file):
    """Test opening a reader with a given filename.

    When constructed from a filename, ``check_type`` must reject the reader's
    ``target`` and accept its ``base``; ``_close_base_`` must be True until
    ``close()`` is called and False afterwards.
    """
    h5_reader = hdf5Reader(reader_test_file)

    target_ok = h5_reader.check_type(h5_reader.target)
    assert target_ok is False
    base_ok = h5_reader.check_type(h5_reader.base)
    assert base_ok is True

    assert h5_reader._close_base_ is True
    h5_reader.close()
    assert h5_reader._close_base_ is False
@pytest.mark.unit
def test_reader_given_file(reader_test_file):
    """Test opening a reader from a given file instance.

    When constructed from an already-open ``h5py.File``, ``check_type`` must
    accept both the reader's ``target`` and ``base``, and ``_close_base_``
    must be False.
    """
    # The context manager closes the handle even if an assertion fails, so a
    # failing test cannot leave the HDF5 file open.
    with h5py.File(reader_test_file, "r") as f:
        reader = hdf5Reader(f)
        assert reader.check_type(reader.target) is True
        assert reader.check_type(reader.base) is True
        assert reader._close_base_ is False
@pytest.mark.unit
def test_reader_read_obj(reader_test_file):
    """Test reading an object from hdf5 file.

    Reads the object's ``_io_attrs_`` from the file root and from the
    "subgroup" group, and checks that requesting an extra attribute name
    ("4") emits a RuntimeWarning.
    """
    mo = MockObject()
    reader = hdf5Reader(reader_test_file)
    try:
        reader.read_obj(mo)

        # Temporarily request one extra attribute name to provoke the warning.
        # ``append`` is used instead of ``+= "4"``, which would extend the list
        # character by character for any longer string.
        mo._io_attrs_.append("4")
        with pytest.warns(RuntimeWarning):
            reader.read_obj(mo)
        del mo._io_attrs_[-1]

        submo = MockObject()
        reader.read_obj(submo, where=reader.sub("subgroup"))
    finally:
        # Close the reader explicitly rather than leaving the file open until
        # the object happens to be collected.
        reader.close()

    for key in mo._io_attrs_:
        assert hasattr(mo, key)
        assert hasattr(submo, key)
@pytest.mark.unit
@pytest.mark.solve
def test_pickle_io(DSHAPE_current, tmpdir_factory):
    """Test saving and loading equilibrium in pickle format.

    The object loaded from the fixture's hdf5 output is saved as a pickle and
    loaded again; the reloaded object must compare equal to the one saved.
    """
    out_dir = tmpdir_factory.mktemp("desc_inputs")
    pickle_file = out_dir.join("solovev_test.pkl")

    h5_source = str(DSHAPE_current["desc_h5_path"])
    saved = load(load_from=h5_source)
    saved.save(pickle_file, file_format="pickle")

    reloaded = load(pickle_file, file_format="pickle")
    assert equals(saved, reloaded)
@pytest.mark.unit
@pytest.mark.solve
def test_ascii_io(DSHAPE_current, tmpdir_factory):
    """Test saving and loading equilibrium in ASCII format.

    The last equilibrium of the fixture's hdf5 output is given a power-series
    iota profile, written as ASCII, and read back (which is expected to emit
    a UserWarning). The R, Z and L spectral coefficients must match.
    """
    out_dir = tmpdir_factory.mktemp("desc_inputs")
    ascii_file = out_dir.join("solovev_test.txt")

    h5_source = str(DSHAPE_current["desc_h5_path"])
    eq_written = load(load_from=h5_source)[-1]

    # Replace iota with a symmetric power series built from the iota profile.
    iota_profile = eq_written.get_profile("iota", grid=LinearGrid(30, 16, 0))
    eq_written.iota = iota_profile.to_powerseries(sym=True)

    write_ascii(ascii_file, eq_written)
    with pytest.warns(UserWarning):
        eq_read = read_ascii(ascii_file)

    for coeff_name in ("R_lmn", "Z_lmn", "L_lmn"):
        written = getattr(eq_written, coeff_name)
        read_back = getattr(eq_read, coeff_name)
        assert np.allclose(written, read_back)
@pytest.mark.unit
def test_copy():
    """Test thing.copy() method of IOAble objects.

    A shallow copy must share its ``basis`` with the original, while a deep
    copy must hold a distinct but equal ``basis``. In both cases the first
    "direct1" matrix entry must match the original's.
    """
    method = "direct1"
    tolerances = {"rtol": 1e-10, "atol": 1e-10}

    original = Transform(
        LinearGrid(2, 2, 2), FourierZernikeBasis(2, 2, 2), method=method
    ) if False else None
    zernike_basis = FourierZernikeBasis(2, 2, 2)
    linear_grid = LinearGrid(2, 2, 2)
    original = Transform(linear_grid, zernike_basis, method=method)

    shallow = original.copy(deepcopy=False)
    assert original.basis is shallow.basis
    np.testing.assert_allclose(
        original.matrices[method][0][0][0],
        shallow.matrices[method][0][0][0],
        **tolerances,
    )

    deep = original.copy(deepcopy=True)
    assert original.basis is not deep.basis
    assert original.basis.eq(deep.basis)
    np.testing.assert_allclose(
        original.matrices[method][0][0][0],
        deep.matrices[method][0][0][0],
        **tolerances,
    )
@pytest.mark.unit
def test_save_none(tmpdir_factory):
    """Test that None attributes are saved/loaded correctly.

    An equilibrium whose ``_iota`` is set to None is saved to hdf5 and loaded
    again; the loaded object's ``iota`` must be None.
    """
    tmpdir = tmpdir_factory.mktemp("none_test")
    # Use join() so the file lands inside the temporary directory. Adding a
    # string to a py.path.local appends to its basename instead, which would
    # create a sibling of the directory rather than a file within it.
    h5_path = tmpdir.join("none_test.h5")

    eq = Equilibrium()
    eq._iota = None
    eq.save(h5_path)

    eq1 = load(h5_path)
    assert eq1.iota is None
@pytest.mark.unit
def test_load_eq_without_current():
    """Test that loading an eq from DESC < 0.6.0 works correctly.

    Loading the stored output must emit a RuntimeWarning, and the last loaded
    equilibrium must have ``current`` set to None.
    """
    legacy_file = ".//tests//inputs//DSHAPE_output_saved_without_current.h5"
    with pytest.warns(RuntimeWarning):
        loaded = load(legacy_file)
        last_eq = loaded[-1]
    assert last_eq.current is None