Audacity 3.2.0
MirTestUtils.h
Go to the documentation of this file.
1/* SPDX-License-Identifier: GPL-2.0-or-later */
2/*!********************************************************************
3
4 Audacity: A Digital Audio Editor
5
6 MirTestUtils.h
7
8 Matthieu Hodgkinson
9
10**********************************************************************/
11#pragma once
12
13#include "MirTypes.h"
14
15#include <algorithm>
16#include <array>
17#include <cassert>
18#include <fstream>
19#include <functional>
20#include <numeric>
21#include <string>
22#include <vector>
23
24namespace MIR
25{
26
// Some tests, such as benchmarking and visualization, are not meant to be run
// on CI. This variable is used to disable them.
// `inline` so that every translation unit including this header refers to the
// same entity, rather than each getting its own internal-linkage copy.
inline constexpr bool runLocally = false;
30
/*!
 * Result of `GetRocInfo`.
 */
struct RocInfo
{
   //! Area under the ROC curve; the curve lies within the unit square.
   const double areaUnderCurve;
   //! Score threshold selected for the requested false positive rate.
   const double threshold;
};

/*!
 * @brief Computes the area under the Receiver Operating Characteristic (ROC)
 * curve of a set of scored, labelled samples, together with a score threshold
 * matching an allowed false positive rate.
 *
 * The samples are sorted by descending score. Walking down that list, each
 * index yields one (false positive rate, true positive rate) breakpoint; the
 * area under the resulting curve is integrated with the trapezoidal rule.
 *
 * @tparam Result type with public data members `truth` (usable as a boolean)
 * and `score` (comparable with `>`, and averaged to build the threshold).
 * @param results samples to evaluate ; taken by value because they get sorted.
 * @param allowedFalsePositiveRate greatest false positive rate tolerated when
 * choosing the threshold. Clamped to [0, 1].
 * @return the area under the curve and the threshold:
 *  - the least score if every breakpoint satisfies the constraint,
 *  - the greatest score if no breakpoint satisfies it,
 *  - otherwise the score halfway between the last breakpoint that satisfies
 *    the constraint and the first one that does not.
 *
 * @pre at least one of `results` is really positive (`truth` is true) and at
 * least one is really negative.
 * @pre `0. <= allowedFalsePositiveRate && allowedFalsePositiveRate <= 1.`
 */
template <typename Result>
RocInfo
GetRocInfo(std::vector<Result> results, double allowedFalsePositiveRate = 0.)
{
   const auto truth = std::mem_fn(&Result::truth);
   // Only used in an assertion, hence unused in NDEBUG builds.
   [[maybe_unused]] const auto falsity = std::not_fn(truth);

   // There is at least one positive and one negative sample.
   // (Algorithms are std-qualified : relying on ADL to find them only works
   // when the vector iterators happen to be class types declared in std.)
   assert(std::any_of(results.begin(), results.end(), truth));
   assert(std::any_of(results.begin(), results.end(), falsity));

   assert(allowedFalsePositiveRate >= 0. && allowedFalsePositiveRate <= 1.);
   allowedFalsePositiveRate = std::clamp(allowedFalsePositiveRate, 0., 1.);

   // Sort the results by score, descending.
   std::sort(results.begin(), results.end(), [](const auto& a, const auto& b) {
      return a.score > b.score;
   });

   // Keep all counts unsigned : `std::count_if` returns a signed difference
   // type, which would otherwise get mixed with `size()`.
   const std::size_t size = results.size();
   const auto numPositives = static_cast<std::size_t>(
      std::count_if(results.begin(), results.end(), truth));
   const std::size_t numNegatives = size - numPositives;

   // Find true and false positive rates for various score thresholds.
   // True positive and false positive counts are nondecreasing with i,
   // therefore if false positive rate has increased at some i, true positive
   // rate has not decreased.
   std::vector<double> truePositiveRates;
   truePositiveRates.reserve(size);
   std::vector<double> falsePositiveRates;
   falsePositiveRates.reserve(size);
   std::size_t numTruePositives = 0;
   std::size_t numFalsePositives = 0;
   for (const auto& result : results)
   {
      if (result.truth)
         ++numTruePositives;
      else
         ++numFalsePositives;
      truePositiveRates.push_back(
         static_cast<double>(numTruePositives) / numPositives);
      falsePositiveRates.push_back(
         static_cast<double>(numFalsePositives) / numNegatives);
   }

   // Now find the area under the non-decreasing curve with FPR as x-axis,
   // TPR as y, and i as a parameter. (This curve is within a square with unit
   // side.) The curve is extended with (0, 0) on the left and (1, 1) on the
   // right, hence the `size + 1` segments.
   double auc = 0.;
   for (std::size_t i = 0; i <= size; ++i)
   {
      const auto leftFpr = i == 0 ? 0. : falsePositiveRates[i - 1];
      const auto rightFpr = i == size ? 1. : falsePositiveRates[i];
      const auto leftTpr = i == 0 ? 0. : truePositiveRates[i - 1];
      const auto rightTpr = i == size ? 1. : truePositiveRates[i];
      const auto trapezoid = (rightTpr + leftTpr) * (rightFpr - leftFpr) / 2.;
      assert(trapezoid >= 0); // See comments above
      auc += trapezoid;
   }

   // Find the parameter at which the x coordinate exceeds the allowed FPR.
   const auto it = std::upper_bound(
      falsePositiveRates.begin(), falsePositiveRates.end(),
      allowedFalsePositiveRate);

   if (it == falsePositiveRates.end())
      // All breakpoints satisfy the constraint. Return the least score.
      return { auc, results.back().score };
   else if (it == falsePositiveRates.begin())
      // No breakpoint satisfies the constraint. Return the greatest score.
      return { auc, results.front().score };

   // For threshold, use the score halfway between the last breakpoint that
   // satisfies the constraint and the first breakpoint that doesn't.
   const auto index = it - falsePositiveRates.begin();
   const auto threshold =
      (results[index - 1].score + results[index].score) / 2;

   return { auc, threshold };
}
129
/*!
 * Declared here, defined elsewhere.
 * NOTE(review): presumably renders a textual progress bar of `width`
 * characters filled according to `percent` - confirm against the definition,
 * including the accepted range of `percent`.
 */
void ProgressBar(int width, int percent);
131
/*!
 * @brief Writes `v` to `ofs` as a one-line Python list assignment, e.g.
 * `name = [1,2,3,]` followed by a newline. Every element, including the last
 * one, is followed by a comma ; an empty vector yields `name = []`.
 *
 * @tparam T element type ; must be streamable with `operator<<`.
 * @param ofs stream written to. Its state is not checked here.
 * @param v values to print, in order.
 * @param name left-hand side of the assignment.
 */
template <typename T>
void PrintPythonVector(
   std::ofstream& ofs, const std::vector<T>& v, const char* name)
{
   ofs << name << " = [";
   // Iterate by const reference : elements need not be copied to be printed.
   for (const auto& x : v)
      ofs << x << ",";
   ofs << "]\n";
}
140
/*!
 * Return type of `GetOctaveError`.
 * NOTE(review): the exact meaning of the members is set by the definition of
 * `GetOctaveError`, which is not in this header - presumably `factor` is the
 * tempo multiple that best explains the detection and `remainder` the error
 * left once that multiple is accounted for ; confirm.
 */
struct OctaveError
{
   // Default member initializers keep this an aggregate (C++14 and later)
   // while making sure a default-constructed value is never indeterminate.
   double factor = 0.;
   double remainder = 0.;
};

/*!
 * @brief Gets the tempo detection octave error, as defined in section 5. of
 * Schreiber, H., Urbano, J. and Müller, M.
 * NOTE(review): presumably "Music Tempo Estimation: Are We Done Yet?",
 * Transactions of the International Society for Music Information Retrieval,
 * 2020 - confirm the full reference.
 *
 * @param expected the reference value.
 * @param actual the detected value.
 */
OctaveError GetOctaveError(double expected, double actual);
160
161// Reproducible benchmarking must use the same input. We use this to make sure
162// that it does.
163template <int bufferSize = 1024> float GetChecksum(const MirAudioReader& source)
164{
165 // Sum samples to checksum.
166 float checksum = 0.f;
167 long long start = 0;
168 std::array<float, bufferSize> buffer;
169 while (true)
170 {
171 const auto numSamples =
172 std::min<long long>(bufferSize, source.GetNumSamples() - start);
173 if (numSamples == 0)
174 break;
175 source.ReadFloats(buffer.data(), start, numSamples);
176 checksum +=
177 std::accumulate(buffer.begin(), buffer.begin() + numSamples, 0.f);
178 start += numSamples;
179 }
180 return checksum;
181}
182
183} // namespace MIR
wxString name
Definition: TagsEditor.cpp:166
virtual void ReadFloats(float *buffer, long long where, size_t numFrames) const =0
virtual long long GetNumSamples() const =0
void PrintPythonVector(std::ofstream &ofs, const std::vector< T > &v, const char *name)
Definition: MirTestUtils.h:133
OctaveError GetOctaveError(double expected, double actual)
Gets the tempo detection octave error, as defined in section 5. of Schreiber, H., Urbano,...
float GetChecksum(const MirAudioReader &source)
Definition: MirTestUtils.h:163
void ProgressBar(int width, int percent)
RocInfo GetRocInfo(std::vector< Result > results, double allowedFalsePositiveRate=0.)
Definition: MirTestUtils.h:52
static constexpr auto runLocally
Definition: MirTestUtils.h:29
const double threshold
Definition: MirTestUtils.h:34
const double areaUnderCurve
Definition: MirTestUtils.h:33