diff options
-rw-r--r-- | simpl/partial_tracking.pxd | 3 | ||||
-rw-r--r-- | simpl/partial_tracking.pyx | 12 | ||||
-rw-r--r-- | src/mq/mq.cpp | 49 | ||||
-rw-r--r-- | src/mq/mq.h | 107 | ||||
-rw-r--r-- | src/simpl/partial_tracking.cpp | 82 | ||||
-rw-r--r-- | src/simpl/partial_tracking.h | 18 | ||||
-rw-r--r-- | tests/test_partial_tracking.cpp | 103 |
7 files changed, 316 insertions, 58 deletions
diff --git a/simpl/partial_tracking.pxd b/simpl/partial_tracking.pxd index ce7c385..3c49296 100644 --- a/simpl/partial_tracking.pxd +++ b/simpl/partial_tracking.pxd @@ -26,6 +26,9 @@ cdef extern from "../src/simpl/partial_tracking.h" namespace "simpl": vector[c_Peak*] update_partials(c_Frame* frame) vector[c_Frame*] find_partials(vector[c_Frame*] frames) + cdef cppclass c_MQPartialTracking "simpl::MQPartialTracking"(c_PartialTracking): + c_MQPartialTracking() + cdef cppclass c_SMSPartialTracking "simpl::SMSPartialTracking"(c_PartialTracking): c_SMSPartialTracking() bool realtime() diff --git a/simpl/partial_tracking.pyx b/simpl/partial_tracking.pyx index 87387a7..8dd6e79 100644 --- a/simpl/partial_tracking.pyx +++ b/simpl/partial_tracking.pyx @@ -59,6 +59,18 @@ cdef class PartialTracking: return partial_frames +cdef class MQPartialTracking(PartialTracking): + def __cinit__(self): + if self.thisptr: + del self.thisptr + self.thisptr = new c_MQPartialTracking() + + def __dealloc__(self): + if self.thisptr: + del self.thisptr + self.thisptr = <c_PartialTracking*>0 + + cdef class SMSPartialTracking(PartialTracking): def __cinit__(self): if self.thisptr: diff --git a/src/mq/mq.cpp b/src/mq/mq.cpp index 8e2166d..694d670 100644 --- a/src/mq/mq.cpp +++ b/src/mq/mq.cpp @@ -23,11 +23,7 @@ void hamming_window(int window_size, sample* window) { int simpl::init_mq(MQParameters* params) { // allocate memory for window - params->window = (sample*) malloc(sizeof(sample) * params->frame_size); - int i; - for(i = 0; i < params->frame_size; i++) { - params->window[i] = 1.0; - } + params->window = new sample[params->frame_size]; hamming_window(params->frame_size, params->window); // allocate memory for FFT @@ -48,9 +44,9 @@ void simpl::reset_mq(MQParameters* params) { int simpl::destroy_mq(MQParameters* params) { if(params) { - if(params->window) free(params->window); - if(params->fft_in) free(params->fft_in); - if(params->fft_out) free(params->fft_out); + if(params->window) delete [] params->window; + if(params->fft_in) fftw_free(params->fft_in); + if(params->fft_out) fftw_free(params->fft_out); fftw_destroy_plan(params->fft_plan); params->window = NULL; @@ -65,7 +61,7 @@ int simpl::destroy_mq(MQParameters* params) { // Add new_peak to the doubly linked list of peaks, keeping peaks sorted // with the largest amplitude peaks at the start of the list -void add_peak(MQPeak* new_peak, MQPeakList* peak_list) { +void simpl::mq_add_peak(MQPeak* new_peak, MQPeakList* peak_list) { do { if(peak_list->peak) { if(peak_list->peak->amplitude > new_peak->amplitude) { @@ -73,8 +69,7 @@ void add_peak(MQPeak* new_peak, MQPeakList* peak_list) { peak_list = peak_list->next; } else { - MQPeakList* new_node = - (MQPeakList*)malloc(sizeof(MQPeakList)); + MQPeakList* new_node = new MQPeakList(); new_node->peak = new_peak; new_node->prev = peak_list; new_node->next = NULL; @@ -83,7 +78,7 @@ void add_peak(MQPeak* new_peak, MQPeakList* peak_list) { } } else { - MQPeakList* new_node = (MQPeakList*)malloc(sizeof(MQPeakList)); + MQPeakList* new_node = new MQPeakList(); new_node->peak = peak_list->peak; new_node->prev = peak_list; new_node->next = peak_list->next; @@ -101,18 +96,22 @@ void add_peak(MQPeak* new_peak, MQPeakList* peak_list) { while(1); } -// delete the given PeakList void simpl::delete_peak_list(MQPeakList* peak_list) { - // destroy list of peaks while(peak_list && peak_list->next) { if(peak_list->peak) { - free(peak_list->peak); + delete peak_list->peak; + peak_list->peak = NULL; } MQPeakList* temp = peak_list->next; - free(peak_list); + delete peak_list; peak_list = temp; } - free(peak_list); + if(peak_list) { + peak_list->next = NULL; + peak_list->prev = NULL; + delete peak_list; + } + peak_list = NULL; } sample get_magnitude(sample x, sample y) { @@ -125,17 +124,16 @@ sample get_phase(sample x, sample y) { MQPeakList* simpl::mq_find_peaks(int signal_size, sample* signal, MQParameters* params) { - int i; int num_peaks = 0; sample prev_amp, current_amp, next_amp; - MQPeakList* peak_list = (MQPeakList*)malloc(sizeof(MQPeakList)); + MQPeakList* peak_list = new MQPeakList(); peak_list->next = NULL; peak_list->prev = NULL; peak_list->peak = NULL; // take fft of the signal memcpy(params->fft_in, signal, sizeof(sample)*params->frame_size); - for(i = 0; i < params->frame_size; i++) { + for(int i = 0; i < params->frame_size; i++) { params->fft_in[i] *= params->window[i]; } fftw_execute(params->fft_plan); @@ -145,15 +143,14 @@ MQPeakList* simpl::mq_find_peaks(int signal_size, sample* signal, current_amp = get_magnitude(params->fft_out[1][0], params->fft_out[1][1]); // find all peaks in the amplitude spectrum - for(i = 1; i < params->num_bins - 1; i++) { + for(int i = 1; i < params->num_bins - 1; i++) { next_amp = get_magnitude(params->fft_out[i+1][0], params->fft_out[i+1][1]); if((current_amp > prev_amp) && (current_amp > next_amp) && (current_amp > params->peak_threshold)) { - // create a new MQPeak - MQPeak* p = (MQPeak*)malloc(sizeof(MQPeak)); + MQPeak* p = new MQPeak(); p->amplitude = current_amp; p->frequency = i * params->fundamental; p->phase = get_phase(params->fft_out[i][0], params->fft_out[i][1]); @@ -162,7 +159,7 @@ MQPeakList* simpl::mq_find_peaks(int signal_size, sample* signal, p->prev = NULL; // add it to the appropriate position in the list of Peaks - add_peak(p, peak_list); + mq_add_peak(p, peak_list); num_peaks++; } prev_amp = current_amp; @@ -172,7 +169,7 @@ MQPeakList* simpl::mq_find_peaks(int signal_size, sample* signal, // limit peaks to a maximum of max_peaks if(num_peaks > params->max_peaks) { MQPeakList* current = peak_list; - for(i = 0; i < params->max_peaks-1; i++) { + for(int i = 0; i < params->max_peaks-1; i++) { current = current->next; } @@ -303,7 +300,7 @@ MQPeakList* simpl::mq_sort_peaks_by_frequency(MQPeakList* peak_list, // Find a candidate match for peak in frame if one exists. This is the closest // (in frequency) match that is within the matching interval. MQPeak* find_closest_match(MQPeak* p, MQPeakList* peak_list, - MQParameters* params, int backwards) { + MQParameters* params, int backwards) { MQPeakList* current = peak_list; MQPeak* match = NULL; sample best_distance = 44100.0; diff --git a/src/mq/mq.h b/src/mq/mq.h index d2f24cc..0c6a11c 100644 --- a/src/mq/mq.h +++ b/src/mq/mq.h @@ -7,44 +7,91 @@ #include <math.h> #include <string.h> +#include "base.h" + namespace simpl { -typedef double sample; - -typedef struct MQPeak { - float amplitude; - float frequency; - float phase; - int bin; - struct MQPeak* next; - struct MQPeak* prev; -} MQPeak; - -typedef struct MQPeakList { - struct MQPeakList* next; - struct MQPeakList* prev; - struct MQPeak* peak; -} MQPeakList; - -typedef struct MQParameters { - int frame_size; - int max_peaks; - int num_bins; - sample peak_threshold; - sample fundamental; - sample matching_interval; - sample* window; - sample* fft_in; - fftw_complex* fft_out; - fftw_plan fft_plan; - MQPeakList* prev_peaks; -} MQParameters; +// --------------------------------------------------------------------------- +// MQPeak +// --------------------------------------------------------------------------- +class MQPeak { + public: + float amplitude; + float frequency; + float phase; + int bin; + MQPeak* next; + MQPeak* prev; + + MQPeak() { + amplitude = 0.f; + frequency = 0.f; + phase = 0.f; + bin = 0; + next = NULL; + prev = NULL; + } +}; + + +// --------------------------------------------------------------------------- +// MQPeakList +// --------------------------------------------------------------------------- +class MQPeakList { + public: + MQPeakList* next; + MQPeakList* prev; + MQPeak* peak; + + MQPeakList() { + next = NULL; + prev = NULL; + peak = NULL; + } +}; + + +// --------------------------------------------------------------------------- +// MQParameters +// --------------------------------------------------------------------------- +class MQParameters { + public: + int frame_size; + int max_peaks; + int num_bins; + sample peak_threshold; + sample fundamental; + sample matching_interval; + sample* window; + sample* fft_in; + fftw_complex* fft_out; + fftw_plan fft_plan; + MQPeakList* prev_peaks; + + MQParameters() { + frame_size = 0; + max_peaks = 0; + num_bins = 0; + peak_threshold = 0.f; + fundamental = 0.f; + matching_interval = 0.f; + window = NULL; + fft_in = NULL; + fft_out = NULL; + prev_peaks = NULL; + } +}; + +// --------------------------------------------------------------------------- +// MQ functions +// --------------------------------------------------------------------------- int init_mq(MQParameters* params); void reset_mq(MQParameters* params); int destroy_mq(MQParameters* params); +void mq_add_peak(MQPeak* new_peak, MQPeakList* peak_list); void delete_peak_list(MQPeakList* peak_list); MQPeakList* mq_sort_peaks_by_frequency(MQPeakList* peak_list, int num_peaks); diff --git a/src/simpl/partial_tracking.cpp b/src/simpl/partial_tracking.cpp index 420bcb6..c9f7dab 100644 --- a/src/simpl/partial_tracking.cpp +++ b/src/simpl/partial_tracking.cpp @@ -79,6 +79,87 @@ Frames PartialTracking::find_partials(Frames frames) { // --------------------------------------------------------------------------- +// MQPartialTracking +// --------------------------------------------------------------------------- + +MQPartialTracking::MQPartialTracking() { + _mq_params.max_peaks = _max_partials; + _mq_params.frame_size = 0; + _mq_params.num_bins = 0; + _mq_params.peak_threshold = 0.0; + _mq_params.matching_interval = 100.0; + _mq_params.fundamental = 0; + init_mq(&_mq_params); + _peak_list = NULL; + _prev_peak_list = NULL; +} + +MQPartialTracking::~MQPartialTracking() { + destroy_mq(&_mq_params); + delete_peak_list(_peak_list); + _prev_peak_list = NULL; +} + +void MQPartialTracking::reset() { + reset_mq(&_mq_params); + delete_peak_list(_peak_list); + _prev_peak_list = NULL; +} + +void MQPartialTracking::max_partials(int new_max_partials) { + _max_partials = new_max_partials; + _mq_params.max_peaks = _max_partials; +} + +Peaks MQPartialTracking::update_partials(Frame* frame) { + Peaks peaks; + int num_peaks = _max_partials; + if(num_peaks > frame->num_peaks()) { + num_peaks = frame->num_peaks(); + } + frame->clear_partials(); + + _peak_list = new MQPeakList(); + for(int i = 0; i < num_peaks; i++) { + MQPeak* p = new MQPeak(); + p->amplitude = frame->peak(i)->amplitude; + p->frequency = frame->peak(i)->frequency; + p->phase = frame->peak(i)->phase; + p->bin = i; + p->next = NULL; + p->prev = NULL; + mq_add_peak(p, _peak_list); + } + + MQPeakList* partials = mq_track_peaks(_peak_list, &_mq_params); + partials = mq_sort_peaks_by_frequency(partials, num_peaks); + + int num_partials = 0; + while(partials && partials->peak && (num_partials < _max_partials)) { + Peak* p = new Peak(); + p->amplitude = partials->peak->amplitude; + p->frequency = partials->peak->frequency; + p->phase = partials->peak->phase; + peaks.push_back(p); + frame->add_partial(p); + + partials = partials->next; + num_partials++; + } + + for(int i = num_partials; i < _max_partials; i++) { + Peak* p = new Peak(); + peaks.push_back(p); + frame->add_partial(p); + } + + delete_peak_list(_prev_peak_list); + _prev_peak_list = _peak_list; + return peaks; +} + + +// --------------------------------------------------------------------------- // SMSPartialTracking // --------------------------------------------------------------------------- @@ -368,6 +449,7 @@ Peaks SndObjPartialTracking::update_partials(Frame* frame) { if(num_peaks > frame->num_peaks()) { num_peaks = frame->num_peaks(); } + frame->clear_partials(); for(int i = 0; i < num_peaks; i++) { _peak_amplitude[i] = frame->peak(i)->amplitude; diff --git a/src/simpl/partial_tracking.h b/src/simpl/partial_tracking.h index 0812e64..1cf022d 100644 --- a/src/simpl/partial_tracking.h +++ b/src/simpl/partial_tracking.h @@ -3,6 +3,8 @@ #include "base.h" +#include "mq.h" + extern "C" { #include "sms.h" } @@ -60,6 +62,22 @@ class PartialTracking { virtual Frames find_partials(Frames frames); }; +// --------------------------------------------------------------------------- +// MQPartialTracking +// --------------------------------------------------------------------------- +class MQPartialTracking : public PartialTracking { + private: + MQParameters _mq_params; + MQPeakList* _peak_list; + MQPeakList* _prev_peak_list; + + public: + MQPartialTracking(); + ~MQPartialTracking(); + void reset(); + void max_partials(int new_max_partials); + Peaks update_partials(Frame* frame); +}; // --------------------------------------------------------------------------- // SMSPartialTracking diff --git a/tests/test_partial_tracking.cpp b/tests/test_partial_tracking.cpp index f5c1fff..86cc878 100644 --- a/tests/test_partial_tracking.cpp +++ b/tests/test_partial_tracking.cpp @@ -15,6 +15,100 @@ namespace simpl { // --------------------------------------------------------------------------- +// TestMQPartialTracking +// --------------------------------------------------------------------------- +class TestMQPartialTracking : public CPPUNIT_NS::TestCase { + CPPUNIT_TEST_SUITE(TestMQPartialTracking); + CPPUNIT_TEST(test_basic); + CPPUNIT_TEST(test_peaks); + CPPUNIT_TEST_SUITE_END(); + +protected: + static const double PRECISION = 0.001; + MQPeakDetection* pd; + MQPartialTracking* pt; + SndfileHandle sf; + int num_samples; + + void test_basic() { + pt->reset(); + pd->hop_size(256); + pd->frame_size(2048); + + sample* audio = new sample[(int)sf.frames()]; + sf.read(audio, (int)sf.frames()); + + Frames frames = pd->find_peaks( + num_samples, &(audio[(int)sf.frames() / 2]) + ); + frames = pt->find_partials(frames); + + for(int i = 0; i < frames.size(); i++) { + CPPUNIT_ASSERT(frames[i]->num_peaks() > 0); + CPPUNIT_ASSERT(frames[i]->num_partials() > 0); + } + } + + void test_peaks() { + pt->reset(); + + Frames frames; + Peaks peaks; + int num_frames = 8; + + for(int i = 0; i < num_frames; i++) { + Peak* p = new Peak(); + p->amplitude = 0.2; + p->frequency = 220; + + Peak* p2 = new Peak(); + p2->amplitude = 0.2; + p2->frequency = 440; + + Frame* f = new Frame(); + f->add_peak(p); + f->add_peak(p2); + + frames.push_back(f); + peaks.push_back(p); + peaks.push_back(p2); + } + + pt->find_partials(frames); + for(int i = 0; i < num_frames; i++) { + CPPUNIT_ASSERT(frames[i]->num_peaks() > 0); + CPPUNIT_ASSERT(frames[i]->num_partials() > 0); + CPPUNIT_ASSERT(frames[i]->partial(0)->amplitude == 0.2); + CPPUNIT_ASSERT(frames[i]->partial(0)->frequency == 220); + CPPUNIT_ASSERT(frames[i]->partial(1)->amplitude == 0.2); + CPPUNIT_ASSERT(frames[i]->partial(1)->frequency == 440); + } + + for(int i = 0; i < num_frames * 2; i++) { + delete peaks[i]; + } + + for(int i = 0; i < num_frames; i++) { + delete frames[i]; + } + } + +public: + void setUp() { + pd = new MQPeakDetection(); + pt = new MQPartialTracking(); + sf = SndfileHandle("../tests/audio/flute.wav"); + num_samples = 4096; + } + + void tearDown() { + delete pd; + delete pt; + } +}; + + +// --------------------------------------------------------------------------- // TestSMSPartialTracking // --------------------------------------------------------------------------- class TestSMSPartialTracking : public CPPUNIT_NS::TestCase { @@ -38,7 +132,9 @@ protected: sample* audio = new sample[(int)sf.frames()]; sf.read(audio, (int)sf.frames()); - Frames frames = pd->find_peaks(num_samples, &(audio[(int)sf.frames() / 2])); + Frames frames = pd->find_peaks( + num_samples, &(audio[(int)sf.frames() / 2]) + ); frames = pt->find_partials(frames); for(int i = 0; i < frames.size(); i++) { @@ -130,7 +226,9 @@ protected: sample* audio = new sample[(int)sf.frames()]; sf.read(audio, (int)sf.frames()); - Frames frames = pd->find_peaks(num_samples, &(audio[(int)sf.frames() / 2])); + Frames frames = pd->find_peaks( + num_samples, &(audio[(int)sf.frames() / 2]) + ); frames = pt->find_partials(frames); for(int i = 0; i < frames.size(); i++) { @@ -199,6 +297,7 @@ public: } // end of namespace simpl +CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestMQPartialTracking); CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestSMSPartialTracking); CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestLorisPartialTracking); |