Audacity 3.2.0
TatumQuantizationFitBenchmarking.cpp
Go to the documentation of this file.
1#include "MirFakes.h"
2#include "MirTestUtils.h"
4#include "WavMirAudioReader.h"
5
6#include <catch2/catch.hpp>
7#include <chrono>
8#include <fstream>
9#include <iomanip>
10#include <iostream>
11#include <sstream>
12
13#define USE_FILESYSTEM (__has_include(<filesystem>) && _WIN32)
14
15#if USE_FILESYSTEM
16# include <filesystem>
17#endif
18
19namespace MIR
20{
21namespace
22{
23const auto datasetRoot =
24 std::string(CMAKE_CURRENT_SOURCE_DIR) + "/benchmarking-dataset";
25
26std::vector<std::string> GetBenchmarkingAudioFiles()
27{
28 std::vector<std::string> files;
29#if USE_FILESYSTEM
30 namespace fs = std::filesystem;
31 for (const auto& entry : fs::directory_iterator(datasetRoot))
32 for (const auto& subEntry : fs::recursive_directory_iterator(entry))
33 if (
34 subEntry.is_regular_file() && subEntry.path().extension() == ".mp3")
35 files.push_back(subEntry.path().string());
36#else
37 // Recursively find all files in the dataset directory with .mp3 extension,
38 // not using std::filesystem:
39 // https://stackoverflow.com/questions/612097/how-can-i-get-the-list-of-files-in-a-directory-using-c-or-c
40 const auto command = "find -H " + datasetRoot + " -type f -name '*.mp3' -print";
41 FILE* pipe = popen(command.c_str(), "r");
42 if (!pipe)
43 throw std::runtime_error("popen() failed!");
44 constexpr auto bufferSize = 512;
45 char buffer[bufferSize];
46 while (fgets(buffer, bufferSize, pipe) != nullptr)
47 {
48 std::string file(buffer);
49 file.erase(file.find_last_not_of("\n") + 1);
50 files.push_back(file);
51 }
52 const auto returnCode = pclose(pipe);
53 if (returnCode != 0)
54 throw std::runtime_error("pclose() failed!");
55#endif
56 std::sort(files.begin(), files.end());
57 return files;
58}
59
60std::string Pretty(const std::string& filename)
61{
62 // Remove the dataset root from the filename ...
63 const auto datasetRootLength = datasetRoot.length();
64 auto tmp = filename.substr(datasetRootLength + 1);
65 // ... and now the .mp3 extension:
66 tmp = tmp.substr(0, tmp.length() - 4);
67 // Replace backslashes with forward slashes:
68 std::replace(tmp.begin(), tmp.end(), '\\', '/');
69 return tmp;
70}
71} // namespace
72
73TEST_CASE("GetRocInfo")
74{
75 // We use the AUC as a measure of the classifier's performance. With a
76 // suitable data set, this helps us detect regressions, and guide fine-tuning
77 // of the algorithm. This test should help understand how it works and also
78 // make sure that we've implemented that metric correctly :)
79
80 struct Sample
81 {
82 bool truth;
83 double score;
84 };
85
86 using Samples = std::vector<Sample>;
87
88 struct Expected
89 {
90 double areaUnderCurve;
91 double threshold;
92 };
93
94 struct TestCase
95 {
96 const Samples samples;
97 const double allowedFalsePositiveRate;
98 const Expected expected;
99 };
100
101 const std::vector<TestCase> testCases {
102 // Classifier is upside down. We don't tolerate false positives. The
103 // returned threshold is then 100 will satisfy this, but the TPR will also
104 // be 0 ...
105 TestCase { Samples { { true, 100. }, { false, 200. } }, 0.,
106 Expected { 0., 200. } },
107
108 // Classifier is still upside down. We'll get true positives only if we
109 // accept an FPR of 1.
110 TestCase { Samples { { true, 100. }, { false, 200. } }, 1.,
111 Expected { 0., 100. } },
112
113 // Now we have a classifier that works. We don't accept false positives.
114 TestCase { Samples { { false, 100. }, { false, 150. }, { true, 200. } },
115 0., Expected { 1., 175. } },
116
117 // A random classifier, which should have an AUC of 0.75.
118 TestCase {
119 Samples { { false, 1. }, { true, 2. }, { false, 3. }, { true, 4. } },
120 0.5, Expected { .75, 1.5 } },
121 };
122
123 for (const auto& testCase : testCases)
124 {
125 const auto roc =
126 GetRocInfo(testCase.samples, testCase.allowedFalsePositiveRate);
127 REQUIRE(roc.areaUnderCurve == testCase.expected.areaUnderCurve);
128 REQUIRE(roc.threshold == testCase.expected.threshold);
129 }
130}
131
132TEST_CASE("GetChecksum")
133{
134 constexpr auto bufferSize = 5;
135 const auto checksum = GetChecksum<bufferSize>(SquareWaveMirAudioReader {});
136 REQUIRE(checksum == 0.);
137}
138
// Converts an optional time signature into the text written to the `ts`
// column of the benchmarking CSV. It returns one of "2/2", "4/4", "3/4" or
// "6/8" for the signatures handled by the switch. It returns "none" both for
// any other value (the `default` branch) and for an empty optional.
// NOTE(review): the `case` labels of this switch (original lines 144, 146,
// 149 and 151) are missing from this copy, which looks like text extracted
// from a documentation page. Restore them from the original source before
// compiling.
139auto ToString(const std::optional<TimeSignature>& ts)
140{
141 if (ts.has_value())
142 switch (*ts)
143 {
145 return std::string("2/2");
147
148 return std::string("4/4");
150 return std::string("3/4");
152 return std::string("6/8");
153 default:
154 return std::string("none");
155 }
156 else
157 return std::string("none");
158}
159
// Benchmarks GetMusicalMeterFromSignal on every .mp3 file returned by
// GetBenchmarkingAudioFiles(). For each file it records:
// - whether a BPM can be parsed from the filename (the ground truth),
// - the classifier score,
// - the octave error, when both a ground-truth BPM and a positive detected
//   BPM exist.
// It then computes the ROC AUC, the strict and lenient thresholds, and the
// octave-error RMS, and writes the total analysis time to
// ./timeMeasurement.txt. If the AUC differs from `previousAuc` by 0.01 or
// more, it rewrites summary.txt and sampleValues.csv under
// TatumQuantizationFitBenchmarkingOutput/ and the test fails.
// Does nothing unless `runLocally` is true.
160TEST_CASE("TatumQuantizationFitBenchmarking")
161{
162 // For this test to run, you will need to set `runLocally` to `true`, and
163 // you'll also need the benchmarking sound files. To get these, just open
164 // `download-benchmarking-dataset.html` in a browser. This will download a
165 // zip file that you'll need to extract and place in a `benchmarking-dataset`
166 // directory under this directory.
167
168 // Running this test will update
169 // `TatumQuantizationFitBenchmarkingOutput/summary.txt`. The summary contains
170 //
171 // 1. the AUC metric for regression-testing,
172 // 2. the strict- and lenient-mode thresholds,
173 // 3. the octave-error RMS (Schreiber, H., et al. (2020)), and
174 // 4. the checksum of the audio files used.
175 //
176 // The AUC can only be used for comparison if the checksum doesn't change. At the
177 // time of writing, the benchmarking can only conveniently be run on the
178 // author's machine (Windows), because the files used are not
179 // redistributable. Setting up a redistributable dataset that can be used by
180 // other developers is currently being worked on.
181
182 // We only observe the results for the most lenient classifier. The other,
183 // stricter classifier will yield the same results, only with fewer false
184 // positives.
185 if (!runLocally)
186 return;
187
188 constexpr auto tolerance = FalsePositiveTolerance::Lenient;
189 constexpr int progressBarWidth = 50;
190 const auto audioFiles = GetBenchmarkingAudioFiles();
191 std::stringstream sampleValueCsv;
192 sampleValueCsv
193 << "truth,score,tatumRate,bpm,ts,octaveFactor,octaveError,lag,filename\n";
194
// Sum of the per-file checksums. It is written to the summary so that AUCs
// are only compared between runs on the same audio.
195 float checksum = 0.f;
// One entry per audio file:
// - `truth`: whether a BPM could be parsed from the filename;
// - `score`: the classifier score from the debug output;
// - `octaveError`: set only when there is a ground-truth BPM and the
//   detected BPM is positive.
196 struct Sample
197 {
198 bool truth;
199 double score;
200 std::optional<OctaveError> octaveError;
201 };
202 std::vector<Sample> samples;
203 const auto numFiles = audioFiles.size();
204 auto count = 0;
205 std::chrono::milliseconds computationTime { 0 };
// Analyze each file in turn. Only the GetMusicalMeterFromSignal call counts
// towards `computationTime`. Each iteration also appends one CSV row to
// `sampleValueCsv`.
206 std::transform(
207 audioFiles.begin(), audioFiles.begin() + numFiles,
208 std::back_inserter(samples), [&](const std::string& wavFile) {
209 const WavMirAudioReader audio { wavFile };
210 checksum += GetChecksum(audio);
211 QuantizationFitDebugOutput debugOutput;
// Default-constructed (empty) progress callback.
212 std::function<void(double)> progressCb;
213 const auto now = std::chrono::steady_clock::now();
214 GetMusicalMeterFromSignal(audio, tolerance, progressCb, &debugOutput);
215 computationTime +=
216 std::chrono::duration_cast<std::chrono::milliseconds>(
217 std::chrono::steady_clock::now() - now);
218 ProgressBar(progressBarWidth, 100 * count++ / numFiles);
219 const auto expected = GetBpmFromFilename(wavFile);
220 const auto truth = expected.has_value();
221 const std::optional<OctaveError> error =
222 truth && debugOutput.bpm > 0 ?
223 std::make_optional(GetOctaveError(*expected, debugOutput.bpm)) :
224 std::nullopt;
// Columns follow the CSV header written above. Missing octave errors are
// written as 0.
225 sampleValueCsv << (truth ? "true" : "false") << ","
226 << debugOutput.score << ","
227 << 60. * debugOutput.tatumQuantization.numDivisions /
228 debugOutput.audioFileDuration
229 << "," << debugOutput.bpm << ","
230 << ToString(debugOutput.timeSignature) << ","
231 << (error.has_value() ? error->factor : 0.) << ","
232 << (error.has_value() ? error->remainder : 0.) << ","
233 << debugOutput.tatumQuantization.lag << ","
234 << Pretty(wavFile) << "\n";
235 return Sample { truth, debugOutput.score, error };
236 });
237
// Total analysis time. The path is relative, so the file lands in the
// current working directory.
238 {
239 std::ofstream timeMeasurementFile { "./timeMeasurement.txt" };
240 timeMeasurementFile << computationTime.count() << "ms\n";
241 }
242
243 // AUC of ROC curve. Tells how good our loop/not-loop classifier is.
244 const auto rocInfo = GetRocInfo(
245 samples, loopClassifierSettings.at(tolerance).allowedFalsePositiveRate);
246
// NOTE(review): original lines 248-249 of the next statement are missing from
// this copy. Presumably it calls GetRocInfo on `samples` with the strict
// tolerance's allowedFalsePositiveRate. Confirm against the original file.
247 const auto strictThreshold =
250 .allowedFalsePositiveRate)
251 .threshold;
252
253 // Get RMS of octave errors. Tells how good the BPM estimation is.
254 const auto octaveErrors = std::accumulate(
255 samples.begin(), samples.end(), std::vector<double> {},
256 [&](std::vector<double> octaveErrors, const Sample& sample)
257 {
258 if (sample.octaveError.has_value())
259 octaveErrors.push_back(sample.octaveError->remainder);
260 return octaveErrors;
261 });
// Despite its name, `octaveErrorStd` is the root mean square of the
// remainders, not a deviation around their mean.
// NOTE(review): if no sample has an octave error, this evaluates
// sqrt(0. / 0) and yields NaN.
262 const auto octaveErrorStd = std::sqrt(
263 std::accumulate(
264 octaveErrors.begin(), octaveErrors.end(), 0.,
265 [&](double sum, double octaveError)
266 { return sum + octaveError * octaveError; }) /
267 octaveErrors.size());
268
// AUC of the last accepted run. A deviation of 0.01 or more counts as a
// change.
269 constexpr auto previousAuc = 0.9312244897959182;
270 const auto classifierQualityHasChanged =
271 std::abs(rocInfo.areaUnderCurve - previousAuc) >= 0.01;
272
273 // Only update the summary and the CSV if the AUC has changed by 0.01 or more.
274 if (classifierQualityHasChanged)
275 {
276 std::ofstream summaryFile {
277 std::string(CMAKE_CURRENT_SOURCE_DIR) +
278 "/TatumQuantizationFitBenchmarkingOutput/summary.txt"
279 };
280 summaryFile << std::setprecision(
281 std::numeric_limits<double>::digits10 + 1)
282 << "AUC: " << rocInfo.areaUnderCurve << "\n"
283 << "Strict Threshold (Minutes-and-Seconds): "
284 << strictThreshold << "\n"
285 << "Lenient Threshold (Beats-and-Measures): "
286 << rocInfo.threshold << "\n"
287 << "Octave error RMS: " << octaveErrorStd << "\n"
288 << "Audio file checksum: " << checksum << "\n";
289 // Write sampleValueCsv to a file.
290 std::ofstream sampleValueCsvFile {
291 std::string(CMAKE_CURRENT_SOURCE_DIR) +
292 "/TatumQuantizationFitBenchmarkingOutput/sampleValues.csv"
293 };
294 sampleValueCsvFile << sampleValueCsv.rdbuf();
295 }
296
297 // If the AUC changed, then some non-refactoring code change happened. If
298 // `rocInfo.areaUnderCurve > previousAuc`, then there's probably no argument
299 // about the change. Otherwise, the change is either an
300 // inadvertent bug or, if deliberate, should be well justified.
301 REQUIRE(!classifierQualityHasChanged);
302}
303} // namespace MIR
static ProjectFileIORegistry::AttributeWriterEntry entry
MockedAudio audio
OctaveError GetOctaveError(double expected, double actual)
Gets the tempo detection octave error, as defined in section 5. of Schreiber, H., Urbano,...
float GetChecksum(const MirAudioReader &source)
Definition: MirTestUtils.h:163
std::optional< MusicalMeter > GetMusicalMeterFromSignal(const MirAudioReader &audio, FalsePositiveTolerance tolerance, const std::function< void(double)> &progressCallback, QuantizationFitDebugOutput *debugOutput)
void ProgressBar(int width, int percent)
TEST_CASE("GetBpmFromFilename")
static const std::unordered_map< FalsePositiveTolerance, LoopClassifierSettings > loopClassifierSettings
RocInfo GetRocInfo(std::vector< Result > results, double allowedFalsePositiveRate=0.)
Definition: MirTestUtils.h:52
std::optional< double > GetBpmFromFilename(const std::string &filename)
auto ToString(const std::optional< TimeSignature > &ts)
static constexpr auto runLocally
Definition: MirTestUtils.h:29
__finl float_x4 __vecc sqrt(const float_x4 &a)
OnsetQuantization tatumQuantization
Definition: MirTypes.h:138
std::optional< TimeSignature > timeSignature
Definition: MirTypes.h:140
const double threshold
Definition: MirTestUtils.h:34