6#include <catch2/catch.hpp>
13#define USE_FILESYSTEM (__has_include(<filesystem>) && _WIN32)
24 std::string(CMAKE_CURRENT_SOURCE_DIR) +
"/benchmarking-dataset";
28 std::vector<std::string> files;
30 namespace fs = std::filesystem;
32 for (
const auto& subEntry : fs::recursive_directory_iterator(
entry))
34 subEntry.is_regular_file() && subEntry.path().extension() ==
".mp3")
35 files.push_back(subEntry.path().string());
40 const auto command =
"find -H " +
datasetRoot +
" -type f -name '*.mp3' -print";
41 FILE* pipe = popen(command.c_str(),
"r");
43 throw std::runtime_error(
"popen() failed!");
44 constexpr auto bufferSize = 512;
45 char buffer[bufferSize];
46 while (fgets(buffer, bufferSize, pipe) !=
nullptr)
48 std::string file(buffer);
49 file.erase(file.find_last_not_of(
"\n") + 1);
50 files.push_back(file);
52 const auto returnCode = pclose(pipe);
54 throw std::runtime_error(
"pclose() failed!");
56 std::sort(files.begin(), files.end());
60std::string
Pretty(
const std::string& filename)
63 const auto datasetRootLength =
datasetRoot.length();
64 auto tmp = filename.substr(datasetRootLength + 1);
66 tmp = tmp.substr(0, tmp.length() - 4);
68 std::replace(tmp.begin(), tmp.end(),
'\\',
'/');
86 using Samples = std::vector<Sample>;
90 double areaUnderCurve;
96 const Samples samples;
97 const double allowedFalsePositiveRate;
98 const Expected expected;
101 const std::vector<TestCase> testCases {
105 TestCase { Samples { {
true, 100. }, {
false, 200. } }, 0.,
106 Expected { 0., 200. } },
110 TestCase { Samples { {
true, 100. }, {
false, 200. } }, 1.,
111 Expected { 0., 100. } },
114 TestCase { Samples { {
false, 100. }, {
false, 150. }, {
true, 200. } },
115 0., Expected { 1., 175. } },
119 Samples { {
false, 1. }, {
true, 2. }, {
false, 3. }, {
true, 4. } },
120 0.5, Expected { .75, 1.5 } },
123 for (
const auto& testCase : testCases)
126 GetRocInfo(testCase.samples, testCase.allowedFalsePositiveRate);
127 REQUIRE(roc.areaUnderCurve == testCase.expected.areaUnderCurve);
128 REQUIRE(roc.threshold == testCase.expected.threshold);
134 constexpr auto bufferSize = 5;
136 REQUIRE(checksum == 0.);
139auto ToString(
const std::optional<TimeSignature>& ts)
145 return std::string(
"2/2");
148 return std::string(
"4/4");
150 return std::string(
"3/4");
152 return std::string(
"6/8");
154 return std::string(
"none");
157 return std::string(
"none");
189 constexpr int progressBarWidth = 50;
191 std::stringstream sampleValueCsv;
193 <<
"truth,score,tatumRate,bpm,ts,octaveFactor,octaveError,lag,filename\n";
195 float checksum = 0.f;
200 std::optional<OctaveError> octaveError;
202 std::vector<Sample> samples;
203 const auto numFiles = audioFiles.size();
205 std::chrono::milliseconds computationTime { 0 };
207 audioFiles.begin(), audioFiles.begin() + numFiles,
208 std::back_inserter(samples), [&](
const std::string& wavFile) {
209 const WavMirAudioReader audio { wavFile };
212 std::function<void(
double)> progressCb;
213 const auto now = std::chrono::steady_clock::now();
216 std::chrono::duration_cast<std::chrono::milliseconds>(
217 std::chrono::steady_clock::now() - now);
218 ProgressBar(progressBarWidth, 100 * count++ / numFiles);
220 const auto truth = expected.has_value();
221 const std::optional<OctaveError> error =
222 truth && debugOutput.
bpm > 0 ?
225 sampleValueCsv << (truth ?
"true" :
"false") <<
","
226 << debugOutput.
score <<
","
229 <<
"," << debugOutput.
bpm <<
","
231 << (error.has_value() ? error->factor : 0.) <<
","
232 << (error.has_value() ? error->remainder : 0.) <<
","
234 <<
Pretty(wavFile) <<
"\n";
235 return Sample { truth, debugOutput.
score, error };
239 std::ofstream timeMeasurementFile {
"./timeMeasurement.txt" };
240 timeMeasurementFile << computationTime.count() <<
"ms\n";
247 const auto strictThreshold =
250 .allowedFalsePositiveRate)
254 const auto octaveErrors = std::accumulate(
255 samples.begin(), samples.end(), std::vector<double> {},
256 [&](std::vector<double> octaveErrors,
const Sample& sample)
258 if (sample.octaveError.has_value())
259 octaveErrors.push_back(sample.octaveError->remainder);
264 octaveErrors.begin(), octaveErrors.end(), 0.,
265 [&](
double sum,
double octaveError)
266 { return sum + octaveError * octaveError; }) /
267 octaveErrors.size());
269 constexpr auto previousAuc = 0.9312244897959182;
270 const auto classifierQualityHasChanged =
271 std::abs(rocInfo.areaUnderCurve - previousAuc) >= 0.01;
274 if (classifierQualityHasChanged)
276 std::ofstream summaryFile {
277 std::string(CMAKE_CURRENT_SOURCE_DIR) +
278 "/TatumQuantizationFitBenchmarkingOutput/summary.txt"
280 summaryFile << std::setprecision(
281 std::numeric_limits<double>::digits10 + 1)
282 <<
"AUC: " << rocInfo.areaUnderCurve <<
"\n"
283 <<
"Strict Threshold (Minutes-and-Seconds): "
284 << strictThreshold <<
"\n"
285 <<
"Lenient Threshold (Beats-and-Measures): "
286 << rocInfo.threshold <<
"\n"
287 <<
"Octave error RMS: " << octaveErrorStd <<
"\n"
288 <<
"Audio file checksum: " << checksum <<
"\n";
290 std::ofstream sampleValueCsvFile {
291 std::string(CMAKE_CURRENT_SOURCE_DIR) +
292 "/TatumQuantizationFitBenchmarkingOutput/sampleValues.csv"
294 sampleValueCsvFile << sampleValueCsv.rdbuf();
301 REQUIRE(!classifierQualityHasChanged);
static ProjectFileIORegistry::AttributeWriterEntry entry
std::string Pretty(const std::string &filename)
std::vector< std::string > GetBenchmarkingAudioFiles()
OctaveError GetOctaveError(double expected, double actual)
Gets the tempo detection octave error, as defined in section 5. of Schreiber, H., Urbano,...
float GetChecksum(const MirAudioReader &source)
std::optional< MusicalMeter > GetMusicalMeterFromSignal(const MirAudioReader &audio, FalsePositiveTolerance tolerance, const std::function< void(double)> &progressCallback, QuantizationFitDebugOutput *debugOutput)
void ProgressBar(int width, int percent)
TEST_CASE("GetBpmFromFilename")
static const std::unordered_map< FalsePositiveTolerance, LoopClassifierSettings > loopClassifierSettings
RocInfo GetRocInfo(std::vector< Result > results, double allowedFalsePositiveRate=0.)
std::optional< double > GetBpmFromFilename(const std::string &filename)
auto ToString(const std::optional< TimeSignature > &ts)
static constexpr auto runLocally
__finl float_x4 __vecc sqrt(const float_x4 &a)
OnsetQuantization tatumQuantization
std::optional< TimeSignature > timeSignature