/*!
 * Computes ROC information for a set of scored, labelled results.
 *
 * @tparam Result must expose a member `truth`, usable through std::mem_fn as a
 *    boolean predicate (positive vs. negative label), and a member `score`
 *    that supports `>` comparison, addition and division by 2.
 * @param results taken by value because it is sorted in place (by descending
 *    score) before the curve is built.
 * @param allowedFalsePositiveRate target false-positive rate in [0, 1]. It is
 *    asserted in debug builds and clamped to [0, 1] in every build.
 * @return a RocInfo `{ auc, threshold }`. `auc` is the area under the ROC
 *    curve (false-positive rate on x, true-positive rate on y, trapezoidal
 *    rule). `threshold` is a score chosen so that results scoring above it
 *    keep the cumulative false-positive rate within the allowed rate, where
 *    possible.
 * @pre `results` holds at least one positive and one negative item
 *    (asserted), because the rates divide by those counts.
 *
 * NOTE(review): this listing is an extraction with gaps. The embedded source
 *    line numbers skip values (e.g. 83 -> 89, 92 -> 99, 106 -> 111), and the
 *    return-type line, the braces, the counter increments and the `auc`
 *    declaration/accumulation are not visible here. Confirm against the
 *    original file before relying on this text.
 */
50template <
typename Result>
52GetRocInfo(std::vector<Result> results,
double allowedFalsePositiveRate = 0.)
// Predicates over a result's `truth` label (positive / negative).
54 const auto truth = std::mem_fn(&Result::truth);
55 const auto falsity = std::not_fn(truth);
// Both classes must be present, otherwise the TPR/FPR divisions below are by zero.
58 assert(any_of(results.begin(), results.end(), truth));
59 assert(any_of(results.begin(), results.end(), falsity));
// Debug builds reject an out-of-range rate; the clamp keeps builds where
// assert is compiled out (NDEBUG) within [0, 1].
61 assert(allowedFalsePositiveRate >= 0. && allowedFalsePositiveRate <= 1.);
62 allowedFalsePositiveRate = std::clamp(allowedFalsePositiveRate, 0., 1.);
// Order by descending score: walking the vector front to back lowers the
// decision threshold one result at a time.
// NOTE(review): std::sort is not stable, and equal scores are not grouped
// below. With tied scores the AUC can depend on the order the ties end up
// in. Confirm whether ties can occur in practice.
65 std::sort(results.begin(), results.end(), [](
const auto& a,
const auto& b) {
66 return a.score > b.score;
// Class sizes, used as denominators of the true/false-positive rates.
69 const auto size = results.size();
70 const auto numPositives = count_if(results.begin(), results.end(), truth);
71 const auto numNegatives =
size - numPositives;
// One (TPR, FPR) point per result, in descending-score order.
77 std::vector<double> truePositiveRates;
78 truePositiveRates.reserve(
size);
79 std::vector<double> falsePositiveRates;
80 falsePositiveRates.reserve(
size);
81 size_t numTruePositives = 0;
82 size_t numFalsePositives = 0;
83 for (
const auto& result : results)
// NOTE(review): the statements that update numTruePositives /
// numFalsePositives from `result` are not present in this listing (source
// lines 84-88 are missing). As shown, both counters stay 0.
89 truePositiveRates.push_back(
90 static_cast<double>(numTruePositives) / numPositives);
91 falsePositiveRates.push_back(
92 static_cast<double>(numFalsePositives) / numNegatives);
// Trapezoidal integration of TPR over FPR across size + 1 segments. The
// curve is anchored at (0, 0) on the left (i == 0) and at (1, 1) on the
// right (i == size).
// NOTE(review): the declaration of `auc` and the `auc += trapezoid`
// accumulation are not visible in this listing.
99 for (
size_t i = 0; i <=
size; ++i)
101 const auto leftFpr = i == 0 ? 0. : falsePositiveRates[i - 1];
102 const auto rightFpr = i ==
size ? 1. : falsePositiveRates[i];
103 const auto leftTpr = i == 0 ? 0. : truePositiveRates[i - 1];
104 const auto rightTpr = i ==
size ? 1. : truePositiveRates[i];
105 const auto trapezoid = (rightTpr + leftTpr) * (rightFpr - leftFpr) / 2.;
// A negative area would mean the FPR decreased between consecutive points.
106 assert(trapezoid >= 0);
// Find the first point whose cumulative FPR exceeds the allowed rate.
// std::upper_bound requires falsePositiveRates to be nondecreasing.
111 const auto it = std::upper_bound(
112 falsePositiveRates.begin(), falsePositiveRates.end(),
113 allowedFalsePositiveRate);
115 if (it == falsePositiveRates.end())
// Every cumulative FPR is within the allowed rate: use the lowest score
// (the last element after the descending sort).
117 return { auc, results.back().score };
118 else if (it == falsePositiveRates.begin())
// Even the highest-scoring result alone pushes the FPR past the allowed
// rate: use the highest score.
120 return { auc, results.front().score };
// Otherwise split the difference between the last score still within the
// allowed rate (index - 1) and the first score past it (index).
// NOTE(review): if Result::score is an integral type, `/ 2` truncates.
124 const auto index = it - falsePositiveRates.begin();
125 const auto threshold = (results[index - 1].score + results[index].score) / 2;
127 return { auc, threshold };
134 std::ofstream& ofs,
const std::vector<T>& v,
const char*
name)
136 ofs <<
name <<
" = [";
137 std::for_each(v.begin(), v.end(), [&](T x) { ofs << x <<
","; });
166 float checksum = 0.f;
168 std::array<float, bufferSize> buffer;
171 const auto numSamples =
172 std::min<long long>(bufferSize, source.
GetNumSamples() - start);
175 source.
ReadFloats(buffer.data(), start, numSamples);
177 std::accumulate(buffer.begin(), buffer.begin() + numSamples, 0.f);
virtual void ReadFloats(float *buffer, long long where, size_t numFrames) const =0
virtual long long GetNumSamples() const =0
void PrintPythonVector(std::ofstream &ofs, const std::vector< T > &v, const char *name)
OctaveError GetOctaveError(double expected, double actual)
Gets the tempo detection octave error, as defined in section 5. of Schreiber, H., Urbano,...
float GetChecksum(const MirAudioReader &source)
void ProgressBar(int width, int percent)
RocInfo GetRocInfo(std::vector< Result > results, double allowedFalsePositiveRate=0.)
static constexpr auto runLocally
const double areaUnderCurve