27
27
#include < MorpheusOracle_TypeTraits.hpp>
28
28
29
29
#include < chrono>
30
+ #include < limits>
30
31
31
32
namespace Morpheus {
32
33
namespace Oracle {
@@ -47,29 +48,37 @@ void tune_multiply(
47
48
48
49
vector x (mat.ncols (), value_type (2 )), y (mat.nrows (), value_type (0 ));
49
50
50
- auto mat_mirror = Morpheus::create_mirror<typename Matrix::execution_space>(mat);
51
+ auto mat_mirror =
52
+ Morpheus::create_mirror<typename Matrix::execution_space>(mat);
51
53
Morpheus::copy (mat, mat_mirror);
52
54
53
55
auto mat_mirror_h = Morpheus::create_mirror_container (mat_mirror);
54
56
Morpheus::copy (mat_mirror, mat_mirror_h);
55
57
56
58
size_t current_format = Morpheus::Oracle::RunFirstTuner::INVALID_FORMAT_STATE;
59
+ Morpheus::conversion_error_e status = Morpheus::CONV_SUCCESS;
57
60
while (!tuner.finished ()) {
58
61
if (current_format != tuner.format_count ()) {
59
62
// Convert only when we start a new format_count
60
- Morpheus::convert<Kokkos::Serial>(mat_mirror_h, tuner.format_count ());
61
- mat_mirror.activate (mat_mirror_h.active_index ());
62
- mat_mirror.resize (mat_mirror_h);
63
- Morpheus::copy (mat_mirror_h, mat_mirror);
64
- current_format = tuner.format_count ();
63
+ status = Morpheus::convert<Morpheus::Serial>(mat_mirror_h,
64
+ tuner.format_count ());
65
+ if (status == Morpheus::CONV_SUCCESS) {
66
+ mat_mirror.activate (mat_mirror_h.active_index ());
67
+ mat_mirror.resize (mat_mirror_h);
68
+ Morpheus::copy (mat_mirror_h, mat_mirror);
69
+ current_format = tuner.format_count ();
70
+ }
65
71
}
72
+ double runtime;
73
+ if (status == Morpheus::CONV_SUCCESS) {
74
+ auto start = std::chrono::steady_clock::now ();
75
+ Morpheus::multiply<ExecSpace>(mat_mirror, x, y, true );
76
+ auto end = std::chrono::steady_clock::now ();
66
77
67
- auto start = std::chrono::steady_clock::now ();
68
- Morpheus::multiply<ExecSpace>(mat_mirror, x, y, true );
69
- auto end = std::chrono::steady_clock::now ();
70
-
71
- double runtime = std::chrono::duration_cast<ns>(end - start).count () * 1e-9 ;
72
-
78
+ runtime = std::chrono::duration_cast<ns>(end - start).count () * 1e-9 ;
79
+ } else {
80
+ runtime = std::numeric_limits<double >::max ();
81
+ }
73
82
tuner.register_run (runtime);
74
83
++tuner;
75
84
}
0 commit comments