summaryrefslogtreecommitdiff
path: root/tests/test_partial_tracking.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_partial_tracking.py')
-rw-r--r--tests/test_partial_tracking.py36
1 files changed, 20 insertions, 16 deletions
diff --git a/tests/test_partial_tracking.py b/tests/test_partial_tracking.py
index fb87da9..32f2653 100644
--- a/tests/test_partial_tracking.py
+++ b/tests/test_partial_tracking.py
@@ -1,6 +1,5 @@
import os
import json
-import numpy as np
from nose.tools import assert_almost_equals
import simpl
import simpl.peak_detection as peak_detection
@@ -11,7 +10,7 @@ SMSPeakDetection = peak_detection.SMSPeakDetection
PartialTracking = partial_tracking.PartialTracking
SMSPartialTracking = partial_tracking.SMSPartialTracking
-float_precision = 5
+float_precision = 2
frame_size = 512
hop_size = 512
max_peaks = 10
@@ -77,7 +76,7 @@ class TestSMSPartialTracking(object):
(len(frames), len(self.audio) / hop_size)
assert len(frames) == len(self.audio) / hop_size
- assert len(frames[0].partials) == 0
+ assert frames[0].num_partials == max_partials
assert frames[0].max_partials == max_partials
def test_partial_tracking(self):
@@ -89,16 +88,21 @@ class TestSMSPartialTracking(object):
pt.max_partials = max_partials
frames = pt.find_partials(peaks)
- # make sure each partial is the same
- # for i in range(len(sms_frames)):
- # assert len(sms_frames[i].partials) == len(simpl_frames[i].partials)
- # for p in range(len(sms_frames[i].partials)):
- # assert_almost_equals(sms_frames[i].partials[p].amplitude,
- # simpl_frames[i].partials[p].amplitude,
- # float_precision)
- # assert_almost_equals(sms_frames[i].partials[p].frequency,
- # simpl_frames[i].partials[p].frequency,
- # float_precision)
- # assert_almost_equals(sms_frames[i].partials[p].phase,
- # simpl_frames[i].partials[p].phase,
- # float_precision)
+ sms_frames = self.test_data['partial_tracking']
+ sms_frames = sms_frames[0:len(sms_frames) - 3]
+
+ assert len(sms_frames) == len(frames)
+
+ for i in range(len(frames)):
+ assert len(frames[i].partials) == len(sms_frames[i]['partials'])
+
+ for p in range(len(frames[i].partials)):
+ assert_almost_equals(frames[i].partials[p].amplitude,
+ sms_frames[i]['partials'][p]['amplitude'],
+ float_precision)
+ assert_almost_equals(frames[i].partials[p].frequency,
+ sms_frames[i]['partials'][p]['frequency'],
+ float_precision)
+ assert_almost_equals(frames[i].partials[p].phase,
+ sms_frames[i]['partials'][p]['phase'],
+ float_precision)