diff options
-rw-r--r-- | src/mq/twm.cpp | 119 | ||||
-rw-r--r-- | src/mq/twm.h | 19 | ||||
-rw-r--r-- | src/simpl/peak_detection.h | 1 | ||||
-rw-r--r-- | tests/test_peak_detection.cpp | 22 | ||||
-rw-r--r-- | tests/test_peak_detection.h | 13 | ||||
-rw-r--r-- | tests/tests.cpp | 1 |
6 files changed, 175 insertions, 0 deletions
diff --git a/src/mq/twm.cpp b/src/mq/twm.cpp new file mode 100644 index 0000000..b8a260d --- /dev/null +++ b/src/mq/twm.cpp @@ -0,0 +1,119 @@ +#include "math.h" +#include "twm.h" + +using namespace simpl; + + +int simpl::best_match(sample freq, std::vector<sample> candidates) { + sample best_diff = 22050.0; + sample diff = 0.0; + int best = 0; + + for(int i = 0; i < candidates.size(); i++) { + diff = fabs(freq - candidates[i]); + if(diff < best_diff) { + best_diff = diff; + best = i; + } + } + + return best; +} + +sample simpl::twm(Peaks peaks, sample f_min, sample f_max, sample f_step) { + sample p = 0.5; + sample q = 1.4; + sample r = 0.5; + sample rho = 0.33; + int N = 30; + std::map<sample, sample> err; + + if(peaks.size() == 0) { + return 0.0; + } + + sample max_amp = 0.0; + for(int i = 0; i < peaks.size(); i++) { + if(peaks[i]->amplitude > max_amp) { + max_amp = peaks[i]->amplitude; + } + } + + if(max_amp == 0) { + return 0.0; + } + + // remove all peaks with amplitude of less than 10% of max + // note: this is not in the TWM paper, found that it improved + // accuracy however + for(int i = 0; i < peaks.size(); i++) { + if(peaks[i]->amplitude < (max_amp * 0.1)) { + peaks.erase(peaks.begin() + i); + } + } + + // get the max frequency of the remaining peaks + sample max_freq = 0.0; + for(int i = 0; i < peaks.size(); i++) { + if(peaks[i]->frequency > max_freq) { + max_freq = peaks[i]->frequency; + } + } + + std::vector<sample> peak_freqs; + for(int i = 0; i < peaks.size(); i++) { + peak_freqs.push_back(peaks[i]->frequency); + } + + sample f_current = f_min; + while(f_current < f_max) { + sample err_pm = 0.0; + sample err_mp = 0.0; + std::vector<sample> harmonics; + + for(sample f = f_current; f <= f_max; f += f_current) { + harmonics.push_back(f); + if(harmonics.size() >= N) { + break; + } + } + + // calculate mismatch between predicted and actual peaks + for(int i = 0; i < harmonics.size(); i++) { + sample h = harmonics[i]; + int k = best_match(h, peak_freqs); + sample f = peaks[k]->frequency; + sample a = peaks[k]->amplitude; + err_pm += (fabs(h - f) * pow(h, -p)) + + (((a / max_amp) * (q * fabs(h - f)) * (pow(h, -p) - r))); + } + + // calculate the mismatch between actual and predicted peaks + for(int i = 0; i < peaks.size(); i++) { + sample f = peaks[i]->frequency; + sample a = peaks[i]->amplitude; + int k = best_match(f, harmonics); + sample h = harmonics[k]; + err_mp += (fabs(f - h) * pow(f, -p)) + + ((a / max_amp) * (q * fabs(f - h)) * (pow(f, -p) - r)); + } + + // calculate the total error for f_current as a fundamental frequency + err[f_current] = (err_pm / harmonics.size()) + + (rho * err_mp / peaks.size()); + + f_current += f_step; + } + + // return the value with the minimum total error + sample best_freq = 0; + sample min_error = 22050; + for(std::map<sample, sample>::iterator i = err.begin(); i != err.end(); i++) { + if(fabs((*i).second) < min_error) { + min_error = fabs((*i).second); + best_freq = (*i).first; + } + } + + return best_freq; +} diff --git a/src/mq/twm.h b/src/mq/twm.h new file mode 100644 index 0000000..e266861 --- /dev/null +++ b/src/mq/twm.h @@ -0,0 +1,19 @@ +#ifndef SIMPL_TWM_H +#define SIMPL_TWM_H + +#include <map> +#include <vector> + +#include "base.h" + +namespace simpl +{ + +int best_match(sample freq, std::vector<sample> candidates); + +sample twm(Peaks peaks, sample f_min=20.0, + sample f_max=3000.0, sample f_step=10.0); + +} + +#endif diff --git a/src/simpl/peak_detection.h b/src/simpl/peak_detection.h index 22af561..02d8988 100644 --- a/src/simpl/peak_detection.h +++ b/src/simpl/peak_detection.h @@ -4,6 +4,7 @@ #include "base.h" #include "mq.h" +#include "twm.h" extern "C" { #include "sms.h" diff --git a/tests/test_peak_detection.cpp b/tests/test_peak_detection.cpp index 3cb9c1a..e0eba9e 100644 --- a/tests/test_peak_detection.cpp +++ b/tests/test_peak_detection.cpp @@ -72,6 +72,28 @@ void TestMQPeakDetection::test_find_peaks_change_hop_frame_size() { // --------------------------------------------------------------------------- +// TestTWM +// --------------------------------------------------------------------------- +void TestTWM::test_basic() { + int num_peaks = 100; + int base_freq = 110; + Peaks peaks; + + for(int i = 0; i < num_peaks; i++) { + Peak* p = new Peak(); + p->amplitude = 0.4; + p->frequency = base_freq * (i + 1); + peaks.push_back(p); + } + + CPPUNIT_ASSERT_DOUBLES_EQUAL(base_freq, twm(peaks), PRECISION); + + for(int i = 0; i < num_peaks; i++) { + delete peaks[i]; + } +} + +// --------------------------------------------------------------------------- // TestLorisPeakDetection // --------------------------------------------------------------------------- void TestLorisPeakDetection::setUp() { diff --git a/tests/test_peak_detection.h b/tests/test_peak_detection.h index 5b36f19..2b574cb 100644 --- a/tests/test_peak_detection.h +++ b/tests/test_peak_detection.h @@ -37,6 +37,19 @@ protected: // --------------------------------------------------------------------------- +// TestTWM +// --------------------------------------------------------------------------- +class TestTWM : public CPPUNIT_NS::TestCase { + CPPUNIT_TEST_SUITE(TestTWM); + CPPUNIT_TEST(test_basic); + CPPUNIT_TEST_SUITE_END(); + +protected: + void test_basic(); +}; + + +// --------------------------------------------------------------------------- // TestLorisPeakDetection // --------------------------------------------------------------------------- class TestLorisPeakDetection : public CPPUNIT_NS::TestCase { diff --git a/tests/tests.cpp b/tests/tests.cpp index 133c873..d185923 100644 --- a/tests/tests.cpp +++ b/tests/tests.cpp @@ -10,6 +10,7 @@ CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestPeak); CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestFrame); CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestMQPeakDetection); +CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestTWM); CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestLorisPeakDetection); CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestMQPartialTracking); CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestSMSPartialTracking); |