summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--simpl/partial_tracking.pxd3
-rw-r--r--simpl/partial_tracking.pyx12
-rw-r--r--src/mq/mq.cpp49
-rw-r--r--src/mq/mq.h107
-rw-r--r--src/simpl/partial_tracking.cpp82
-rw-r--r--src/simpl/partial_tracking.h18
-rw-r--r--tests/test_partial_tracking.cpp103
7 files changed, 316 insertions, 58 deletions
diff --git a/simpl/partial_tracking.pxd b/simpl/partial_tracking.pxd
index ce7c385..3c49296 100644
--- a/simpl/partial_tracking.pxd
+++ b/simpl/partial_tracking.pxd
@@ -26,6 +26,9 @@ cdef extern from "../src/simpl/partial_tracking.h" namespace "simpl":
vector[c_Peak*] update_partials(c_Frame* frame)
vector[c_Frame*] find_partials(vector[c_Frame*] frames)
+ cdef cppclass c_MQPartialTracking "simpl::MQPartialTracking"(c_PartialTracking):
+ c_MQPartialTracking()
+
cdef cppclass c_SMSPartialTracking "simpl::SMSPartialTracking"(c_PartialTracking):
c_SMSPartialTracking()
bool realtime()
diff --git a/simpl/partial_tracking.pyx b/simpl/partial_tracking.pyx
index 87387a7..8dd6e79 100644
--- a/simpl/partial_tracking.pyx
+++ b/simpl/partial_tracking.pyx
@@ -59,6 +59,18 @@ cdef class PartialTracking:
return partial_frames
+cdef class MQPartialTracking(PartialTracking):
+ def __cinit__(self):
+ if self.thisptr:
+ del self.thisptr
+ self.thisptr = new c_MQPartialTracking()
+
+ def __dealloc__(self):
+ if self.thisptr:
+ del self.thisptr
+ self.thisptr = <c_PartialTracking*>0
+
+
cdef class SMSPartialTracking(PartialTracking):
def __cinit__(self):
if self.thisptr:
diff --git a/src/mq/mq.cpp b/src/mq/mq.cpp
index 8e2166d..694d670 100644
--- a/src/mq/mq.cpp
+++ b/src/mq/mq.cpp
@@ -23,11 +23,7 @@ void hamming_window(int window_size, sample* window) {
int simpl::init_mq(MQParameters* params) {
// allocate memory for window
- params->window = (sample*) malloc(sizeof(sample) * params->frame_size);
- int i;
- for(i = 0; i < params->frame_size; i++) {
- params->window[i] = 1.0;
- }
+ params->window = new sample[params->frame_size];
hamming_window(params->frame_size, params->window);
// allocate memory for FFT
@@ -48,9 +44,9 @@ void simpl::reset_mq(MQParameters* params) {
int simpl::destroy_mq(MQParameters* params) {
if(params) {
- if(params->window) free(params->window);
- if(params->fft_in) free(params->fft_in);
- if(params->fft_out) free(params->fft_out);
+ if(params->window) delete [] params->window;
+ if(params->fft_in) fftw_free(params->fft_in);
+ if(params->fft_out) fftw_free(params->fft_out);
fftw_destroy_plan(params->fft_plan);
params->window = NULL;
@@ -65,7 +61,7 @@ int simpl::destroy_mq(MQParameters* params) {
// Add new_peak to the doubly linked list of peaks, keeping peaks sorted
// with the largest amplitude peaks at the start of the list
-void add_peak(MQPeak* new_peak, MQPeakList* peak_list) {
+void simpl::mq_add_peak(MQPeak* new_peak, MQPeakList* peak_list) {
do {
if(peak_list->peak) {
if(peak_list->peak->amplitude > new_peak->amplitude) {
@@ -73,8 +69,7 @@ void add_peak(MQPeak* new_peak, MQPeakList* peak_list) {
peak_list = peak_list->next;
}
else {
- MQPeakList* new_node =
- (MQPeakList*)malloc(sizeof(MQPeakList));
+ MQPeakList* new_node = new MQPeakList();
new_node->peak = new_peak;
new_node->prev = peak_list;
new_node->next = NULL;
@@ -83,7 +78,7 @@ void add_peak(MQPeak* new_peak, MQPeakList* peak_list) {
}
}
else {
- MQPeakList* new_node = (MQPeakList*)malloc(sizeof(MQPeakList));
+ MQPeakList* new_node = new MQPeakList();
new_node->peak = peak_list->peak;
new_node->prev = peak_list;
new_node->next = peak_list->next;
@@ -101,18 +96,22 @@ void add_peak(MQPeak* new_peak, MQPeakList* peak_list) {
while(1);
}
-// delete the given PeakList
void simpl::delete_peak_list(MQPeakList* peak_list) {
- // destroy list of peaks
while(peak_list && peak_list->next) {
if(peak_list->peak) {
- free(peak_list->peak);
+ delete peak_list->peak;
+ peak_list->peak = NULL;
}
MQPeakList* temp = peak_list->next;
- free(peak_list);
+ delete peak_list;
peak_list = temp;
}
- free(peak_list);
+ if(peak_list) {
+ peak_list->next = NULL;
+ peak_list->prev = NULL;
+ delete peak_list;
+ }
+ peak_list = NULL;
}
sample get_magnitude(sample x, sample y) {
@@ -125,17 +124,16 @@ sample get_phase(sample x, sample y) {
MQPeakList* simpl::mq_find_peaks(int signal_size, sample* signal,
MQParameters* params) {
- int i;
int num_peaks = 0;
sample prev_amp, current_amp, next_amp;
- MQPeakList* peak_list = (MQPeakList*)malloc(sizeof(MQPeakList));
+ MQPeakList* peak_list = new MQPeakList();
peak_list->next = NULL;
peak_list->prev = NULL;
peak_list->peak = NULL;
// take fft of the signal
memcpy(params->fft_in, signal, sizeof(sample)*params->frame_size);
- for(i = 0; i < params->frame_size; i++) {
+ for(int i = 0; i < params->frame_size; i++) {
params->fft_in[i] *= params->window[i];
}
fftw_execute(params->fft_plan);
@@ -145,15 +143,14 @@ MQPeakList* simpl::mq_find_peaks(int signal_size, sample* signal,
current_amp = get_magnitude(params->fft_out[1][0], params->fft_out[1][1]);
// find all peaks in the amplitude spectrum
- for(i = 1; i < params->num_bins - 1; i++) {
+ for(int i = 1; i < params->num_bins - 1; i++) {
next_amp = get_magnitude(params->fft_out[i+1][0],
params->fft_out[i+1][1]);
if((current_amp > prev_amp) &&
(current_amp > next_amp) &&
(current_amp > params->peak_threshold)) {
- // create a new MQPeak
- MQPeak* p = (MQPeak*)malloc(sizeof(MQPeak));
+ MQPeak* p = new MQPeak();
p->amplitude = current_amp;
p->frequency = i * params->fundamental;
p->phase = get_phase(params->fft_out[i][0], params->fft_out[i][1]);
@@ -162,7 +159,7 @@ MQPeakList* simpl::mq_find_peaks(int signal_size, sample* signal,
p->prev = NULL;
// add it to the appropriate position in the list of Peaks
- add_peak(p, peak_list);
+ mq_add_peak(p, peak_list);
num_peaks++;
}
prev_amp = current_amp;
@@ -172,7 +169,7 @@ MQPeakList* simpl::mq_find_peaks(int signal_size, sample* signal,
// limit peaks to a maximum of max_peaks
if(num_peaks > params->max_peaks) {
MQPeakList* current = peak_list;
- for(i = 0; i < params->max_peaks-1; i++) {
+ for(int i = 0; i < params->max_peaks-1; i++) {
current = current->next;
}
@@ -303,7 +300,7 @@ MQPeakList* simpl::mq_sort_peaks_by_frequency(MQPeakList* peak_list,
// Find a candidate match for peak in frame if one exists. This is the closest
// (in frequency) match that is within the matching interval.
MQPeak* find_closest_match(MQPeak* p, MQPeakList* peak_list,
- MQParameters* params, int backwards) {
+ MQParameters* params, int backwards) {
MQPeakList* current = peak_list;
MQPeak* match = NULL;
sample best_distance = 44100.0;
diff --git a/src/mq/mq.h b/src/mq/mq.h
index d2f24cc..0c6a11c 100644
--- a/src/mq/mq.h
+++ b/src/mq/mq.h
@@ -7,44 +7,91 @@
#include <math.h>
#include <string.h>
+#include "base.h"
+
namespace simpl
{
-typedef double sample;
-
-typedef struct MQPeak {
- float amplitude;
- float frequency;
- float phase;
- int bin;
- struct MQPeak* next;
- struct MQPeak* prev;
-} MQPeak;
-
-typedef struct MQPeakList {
- struct MQPeakList* next;
- struct MQPeakList* prev;
- struct MQPeak* peak;
-} MQPeakList;
-
-typedef struct MQParameters {
- int frame_size;
- int max_peaks;
- int num_bins;
- sample peak_threshold;
- sample fundamental;
- sample matching_interval;
- sample* window;
- sample* fft_in;
- fftw_complex* fft_out;
- fftw_plan fft_plan;
- MQPeakList* prev_peaks;
-} MQParameters;
+// ---------------------------------------------------------------------------
+// MQPeak
+// ---------------------------------------------------------------------------
+class MQPeak {
+ public:
+ float amplitude;
+ float frequency;
+ float phase;
+ int bin;
+ MQPeak* next;
+ MQPeak* prev;
+
+ MQPeak() {
+ amplitude = 0.f;
+ frequency = 0.f;
+ phase = 0.f;
+ bin = 0;
+ next = NULL;
+ prev = NULL;
+ }
+};
+
+
+// ---------------------------------------------------------------------------
+// MQPeakList
+// ---------------------------------------------------------------------------
+class MQPeakList {
+ public:
+ MQPeakList* next;
+ MQPeakList* prev;
+ MQPeak* peak;
+
+ MQPeakList() {
+ next = NULL;
+ prev = NULL;
+ peak = NULL;
+ }
+};
+
+
+// ---------------------------------------------------------------------------
+// MQParameters
+// ---------------------------------------------------------------------------
+class MQParameters {
+ public:
+ int frame_size;
+ int max_peaks;
+ int num_bins;
+ sample peak_threshold;
+ sample fundamental;
+ sample matching_interval;
+ sample* window;
+ sample* fft_in;
+ fftw_complex* fft_out;
+ fftw_plan fft_plan;
+ MQPeakList* prev_peaks;
+
+ MQParameters() {
+ frame_size = 0;
+ max_peaks = 0;
+ num_bins = 0;
+ peak_threshold = 0.f;
+ fundamental = 0.f;
+ matching_interval = 0.f;
+ window = NULL;
+ fft_in = NULL;
+ fft_out = NULL;
+ prev_peaks = NULL;
+ }
+};
+
+// ---------------------------------------------------------------------------
+// MQ functions
+// ---------------------------------------------------------------------------
int init_mq(MQParameters* params);
void reset_mq(MQParameters* params);
int destroy_mq(MQParameters* params);
+void mq_add_peak(MQPeak* new_peak, MQPeakList* peak_list);
void delete_peak_list(MQPeakList* peak_list);
MQPeakList* mq_sort_peaks_by_frequency(MQPeakList* peak_list, int num_peaks);
diff --git a/src/simpl/partial_tracking.cpp b/src/simpl/partial_tracking.cpp
index 420bcb6..c9f7dab 100644
--- a/src/simpl/partial_tracking.cpp
+++ b/src/simpl/partial_tracking.cpp
@@ -79,6 +79,87 @@ Frames PartialTracking::find_partials(Frames frames) {
// ---------------------------------------------------------------------------
+// MQPartialTracking
+// ---------------------------------------------------------------------------
+
+MQPartialTracking::MQPartialTracking() {
+ _mq_params.max_peaks = _max_partials;
+ _mq_params.frame_size = 0;
+ _mq_params.num_bins = 0;
+ _mq_params.peak_threshold = 0.0;
+ _mq_params.matching_interval = 100.0;
+ _mq_params.fundamental = 0;
+ init_mq(&_mq_params);
+ _peak_list = NULL;
+ _prev_peak_list = NULL;
+}
+
+MQPartialTracking::~MQPartialTracking() {
+ destroy_mq(&_mq_params);
+ delete_peak_list(_peak_list);
+ _prev_peak_list = NULL;
+}
+
+void MQPartialTracking::reset() {
+ reset_mq(&_mq_params);
+ delete_peak_list(_peak_list);
+ _prev_peak_list = NULL;
+}
+
+void MQPartialTracking::max_partials(int new_max_partials) {
+ _max_partials = new_max_partials;
+ _mq_params.max_peaks = _max_partials;
+}
+
+Peaks MQPartialTracking::update_partials(Frame* frame) {
+ Peaks peaks;
+ int num_peaks = _max_partials;
+ if(num_peaks > frame->num_peaks()) {
+ num_peaks = frame->num_peaks();
+ }
+ frame->clear_partials();
+
+ _peak_list = new MQPeakList();
+ for(int i = 0; i < num_peaks; i++) {
+ MQPeak* p = new MQPeak();
+ p->amplitude = frame->peak(i)->amplitude;
+ p->frequency = frame->peak(i)->frequency;
+ p->phase = frame->peak(i)->phase;
+ p->bin = i;
+ p->next = NULL;
+ p->prev = NULL;
+ mq_add_peak(p, _peak_list);
+ }
+
+ MQPeakList* partials = mq_track_peaks(_peak_list, &_mq_params);
+ partials = mq_sort_peaks_by_frequency(partials, num_peaks);
+
+ int num_partials = 0;
+ while(partials && partials->peak && (num_partials < _max_partials)) {
+ Peak* p = new Peak();
+ p->amplitude = partials->peak->amplitude;
+ p->frequency = partials->peak->frequency;
+ p->phase = partials->peak->phase;
+ peaks.push_back(p);
+ frame->add_partial(p);
+
+ partials = partials->next;
+ num_partials++;
+ }
+
+ for(int i = num_partials; i < _max_partials; i++) {
+ Peak* p = new Peak();
+ peaks.push_back(p);
+ frame->add_partial(p);
+ }
+
+ delete_peak_list(_prev_peak_list);
+ _prev_peak_list = _peak_list;
+ return peaks;
+}
+
+
+// ---------------------------------------------------------------------------
// SMSPartialTracking
// ---------------------------------------------------------------------------
@@ -368,6 +449,7 @@ Peaks SndObjPartialTracking::update_partials(Frame* frame) {
if(num_peaks > frame->num_peaks()) {
num_peaks = frame->num_peaks();
}
+ frame->clear_partials();
for(int i = 0; i < num_peaks; i++) {
_peak_amplitude[i] = frame->peak(i)->amplitude;
diff --git a/src/simpl/partial_tracking.h b/src/simpl/partial_tracking.h
index 0812e64..1cf022d 100644
--- a/src/simpl/partial_tracking.h
+++ b/src/simpl/partial_tracking.h
@@ -3,6 +3,8 @@
#include "base.h"
+#include "mq.h"
+
extern "C" {
#include "sms.h"
}
@@ -60,6 +62,22 @@ class PartialTracking {
virtual Frames find_partials(Frames frames);
};
+// ---------------------------------------------------------------------------
+// MQPartialTracking
+// ---------------------------------------------------------------------------
+class MQPartialTracking : public PartialTracking {
+ private:
+ MQParameters _mq_params;
+ MQPeakList* _peak_list;
+ MQPeakList* _prev_peak_list;
+
+ public:
+ MQPartialTracking();
+ ~MQPartialTracking();
+ void reset();
+ void max_partials(int new_max_partials);
+ Peaks update_partials(Frame* frame);
+};
// ---------------------------------------------------------------------------
// SMSPartialTracking
diff --git a/tests/test_partial_tracking.cpp b/tests/test_partial_tracking.cpp
index f5c1fff..86cc878 100644
--- a/tests/test_partial_tracking.cpp
+++ b/tests/test_partial_tracking.cpp
@@ -15,6 +15,100 @@ namespace simpl
{
// ---------------------------------------------------------------------------
+// TestMQPartialTracking
+// ---------------------------------------------------------------------------
+class TestMQPartialTracking : public CPPUNIT_NS::TestCase {
+ CPPUNIT_TEST_SUITE(TestMQPartialTracking);
+ CPPUNIT_TEST(test_basic);
+ CPPUNIT_TEST(test_peaks);
+ CPPUNIT_TEST_SUITE_END();
+
+protected:
+ static const double PRECISION = 0.001;
+ MQPeakDetection* pd;
+ MQPartialTracking* pt;
+ SndfileHandle sf;
+ int num_samples;
+
+ void test_basic() {
+ pt->reset();
+ pd->hop_size(256);
+ pd->frame_size(2048);
+
+ sample* audio = new sample[(int)sf.frames()];
+ sf.read(audio, (int)sf.frames());
+
+ Frames frames = pd->find_peaks(
+ num_samples, &(audio[(int)sf.frames() / 2])
+ );
+ frames = pt->find_partials(frames);
+
+ for(int i = 0; i < frames.size(); i++) {
+ CPPUNIT_ASSERT(frames[i]->num_peaks() > 0);
+ CPPUNIT_ASSERT(frames[i]->num_partials() > 0);
+ }
+ }
+
+ void test_peaks() {
+ pt->reset();
+
+ Frames frames;
+ Peaks peaks;
+ int num_frames = 8;
+
+ for(int i = 0; i < num_frames; i++) {
+ Peak* p = new Peak();
+ p->amplitude = 0.2;
+ p->frequency = 220;
+
+ Peak* p2 = new Peak();
+ p2->amplitude = 0.2;
+ p2->frequency = 440;
+
+ Frame* f = new Frame();
+ f->add_peak(p);
+ f->add_peak(p2);
+
+ frames.push_back(f);
+ peaks.push_back(p);
+ peaks.push_back(p2);
+ }
+
+ pt->find_partials(frames);
+ for(int i = 0; i < num_frames; i++) {
+ CPPUNIT_ASSERT(frames[i]->num_peaks() > 0);
+ CPPUNIT_ASSERT(frames[i]->num_partials() > 0);
+ CPPUNIT_ASSERT(frames[i]->partial(0)->amplitude == 0.2);
+ CPPUNIT_ASSERT(frames[i]->partial(0)->frequency == 220);
+ CPPUNIT_ASSERT(frames[i]->partial(1)->amplitude == 0.2);
+ CPPUNIT_ASSERT(frames[i]->partial(1)->frequency == 440);
+ }
+
+ for(int i = 0; i < num_frames * 2; i++) {
+ delete peaks[i];
+ }
+
+ for(int i = 0; i < num_frames; i++) {
+ delete frames[i];
+ }
+ }
+
+public:
+ void setUp() {
+ pd = new MQPeakDetection();
+ pt = new MQPartialTracking();
+ sf = SndfileHandle("../tests/audio/flute.wav");
+ num_samples = 4096;
+ }
+
+ void tearDown() {
+ delete pd;
+ delete pt;
+ }
+};
+
+
+// ---------------------------------------------------------------------------
// TestSMSPartialTracking
// ---------------------------------------------------------------------------
class TestSMSPartialTracking : public CPPUNIT_NS::TestCase {
@@ -38,7 +132,9 @@ protected:
sample* audio = new sample[(int)sf.frames()];
sf.read(audio, (int)sf.frames());
- Frames frames = pd->find_peaks(num_samples, &(audio[(int)sf.frames() / 2]));
+ Frames frames = pd->find_peaks(
+ num_samples, &(audio[(int)sf.frames() / 2])
+ );
frames = pt->find_partials(frames);
for(int i = 0; i < frames.size(); i++) {
@@ -130,7 +226,9 @@ protected:
sample* audio = new sample[(int)sf.frames()];
sf.read(audio, (int)sf.frames());
- Frames frames = pd->find_peaks(num_samples, &(audio[(int)sf.frames() / 2]));
+ Frames frames = pd->find_peaks(
+ num_samples, &(audio[(int)sf.frames() / 2])
+ );
frames = pt->find_partials(frames);
for(int i = 0; i < frames.size(); i++) {
@@ -199,6 +297,7 @@ public:
} // end of namespace simpl
+CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestMQPartialTracking);
CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestSMSPartialTracking);
CPPUNIT_TEST_SUITE_REGISTRATION(simpl::TestLorisPartialTracking);