summary | refs | log | tree | commit | diff
path: root/tests/test_peak_detection.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_peak_detection.py')
-rw-r--r-- tests/test_peak_detection.py 127
1 files changed, 51 insertions, 76 deletions
diff --git a/tests/test_peak_detection.py b/tests/test_peak_detection.py
index 6d623b6..84505b7 100644
--- a/tests/test_peak_detection.py
+++ b/tests/test_peak_detection.py
@@ -1,7 +1,5 @@
import os
-import numpy as np
-from nose.tools import assert_almost_equals
-import pysms
+import json
import simpl
import simpl.peak_detection as peak_detection
@@ -18,6 +16,16 @@ num_samples = num_frames * hop_size
audio_path = os.path.join(
os.path.dirname(__file__), 'audio/flute.wav'
)
+test_data_path = os.path.join(
+ os.path.dirname(__file__), 'libsms_test_data.json'
+)
+
+
+def _load_libsms_test_data():
+    """Return the reference data parsed from the JSON file at test_data_path.
+
+    NOTE(review): judging by the file name, the JSON presumably holds values
+    pre-computed with libsms so the tests no longer need pysms at run time.
+    """
+    with open(test_data_path, 'r') as data_file:
+        return json.load(data_file)
class TestPeakDetection(object):
@@ -25,91 +33,37 @@ class TestPeakDetection(object):
def setup_class(cls):
cls.audio = simpl.read_wav(audio_path)[0]
- def test_peak_detection(self):
+ def test_basic(self):
pd = PeakDetection()
+ pd.max_peaks = max_peaks
pd.find_peaks(self.audio)
assert len(pd.frames) == len(self.audio) / hop_size
assert len(pd.frames[0].peaks) == 0
+ assert pd.frames[0].max_peaks == max_peaks
class TestSMSPeakDetection(object):
- def _pysms_analysis_params(self, sampling_rate):
- analysis_params = pysms.SMS_AnalParams()
- pysms.sms_initAnalParams(analysis_params)
- analysis_params.iSamplingRate = sampling_rate
- analysis_params.iFrameRate = sampling_rate / hop_size
- analysis_params.iWindowType = pysms.SMS_WIN_HAMMING
- analysis_params.fDefaultFundamental = 100
- analysis_params.fHighestFreq = 20000
- analysis_params.iFormat = pysms.SMS_FORMAT_HP
- analysis_params.nTracks = max_peaks
- analysis_params.peakParams.iMaxPeaks = max_peaks
- analysis_params.nGuides = max_peaks
- analysis_params.iMaxDelayFrames = 4
- analysis_params.analDelay = 0
- analysis_params.minGoodFrames = 1
- analysis_params.iCleanTracks = 0
- analysis_params.iStochasticType = pysms.SMS_STOC_NONE
- analysis_params.preEmphasis = 0
- return analysis_params
+    @classmethod
+    def setup_class(cls):
+        """Read the test audio and the libsms reference data once per class."""
+        samples, _ = simpl.read_wav(audio_path)
+        cls.audio = samples
+        cls.test_data = _load_libsms_test_data()
+
+    def test_basic(self):
+        """Run SMSPeakDetection over the whole file with a static frame size.
+
+        Checks that one frame is produced per hop and that the first frame
+        contains at least one peak.
+        """
+        pd = SMSPeakDetection()
+        pd.hop_size = hop_size
+        pd.static_frame_size = True
+        pd.find_peaks(self.audio)
+
+        # floor division keeps the expected count an integer on Python 3 as
+        # well; on Python 2 it is identical to the previous int / int
+        assert len(pd.frames) == len(self.audio) // hop_size
+        assert len(pd.frames[0].peaks) > 0
def test_size_next_read(self):
"""
- test_size_next_read
- Make sure PeakDetection is calculating the correct value for the
+ Make sure SMSPeakDetection is calculating the correct value for the
size of the next frame.
"""
audio, sampling_rate = simpl.read_wav(audio_path)
- pysms.sms_init()
- snd_header = pysms.SMS_SndHeader()
-
- # Try to open the input file to fill snd_header
- if(pysms.sms_openSF(audio_path, snd_header)):
- raise NameError(
- "error opening sound file: " + pysms.sms_errorString()
- )
-
- analysis_params = self._pysms_analysis_params(sampling_rate)
- analysis_params.iMaxDelayFrames = num_frames + 1
- if pysms.sms_initAnalysis(analysis_params, snd_header) != 0:
- raise Exception("Error allocating memory for analysis_params")
- analysis_params.nFrames = num_frames
- sms_header = pysms.SMS_Header()
- pysms.sms_fillHeader(sms_header, analysis_params, "pysms")
-
- sample_offset = 0
- pysms_size_new_data = 0
- current_frame = 0
- sms_next_read_sizes = []
-
- while current_frame < num_frames:
- sms_next_read_sizes.append(analysis_params.sizeNextRead)
- sample_offset += pysms_size_new_data
- pysms_size_new_data = analysis_params.sizeNextRead
-
- # convert frame to floats for libsms
- frame = audio[sample_offset:sample_offset + pysms_size_new_data]
- frame = np.array(frame, dtype=np.float32)
- if len(frame) < pysms_size_new_data:
- frame = np.hstack((
- frame, np.zeros(pysms_size_new_data - len(frame),
- dtype=np.float32)
- ))
-
- analysis_data = pysms.SMS_Data()
- pysms.sms_allocFrameH(sms_header, analysis_data)
- status = pysms.sms_analyze(frame, analysis_data, analysis_params)
- # as the no. of frames of delay is > num_frames, sms_analyze should
- # never get around to performing partial tracking, and so the
- # return value should be 0
- assert status == 0
- pysms.sms_freeFrame(analysis_data)
- current_frame += 1
-
- pysms.sms_freeAnalysis(analysis_params)
- pysms.sms_closeSF()
- pysms.sms_free()
pd = SMSPeakDetection()
pd.hop_size = hop_size
@@ -117,13 +71,34 @@ class TestSMSPeakDetection(object):
current_frame = 0
sample_offset = 0
+ next_read_sizes = self.test_data['size_next_read']
+
while current_frame < num_frames:
pd.frame_size = pd.next_frame_size()
- assert sms_next_read_sizes[current_frame] == pd.frame_size,\
- (sms_next_read_sizes[current_frame], pd.frame_size)
+ assert next_read_sizes[current_frame] == pd.frame_size,\
+ (next_read_sizes[current_frame], pd.frame_size)
frame = simpl.Frame()
frame.size = pd.frame_size
frame.audio = audio[sample_offset:sample_offset + pd.frame_size]
pd.find_peaks_in_frame(frame)
sample_offset += pd.frame_size
current_frame += 1
+
+    def test_peak_detection(self):
+        """Compare SMSPeakDetection output with the libsms reference data.
+
+        The number of detected frames must equal the number of reference
+        frames whose ``status`` is non-zero. Every detected frame must hold
+        no more than ``max_peaks`` peaks and at least one peak, and its
+        largest peak amplitude must be non-zero.
+        """
+        audio = simpl.read_wav(audio_path)[0]
+
+        pd = SMSPeakDetection()
+        pd.max_peaks = max_peaks
+        pd.hop_size = hop_size
+        frames = pd.find_peaks(audio[0:num_samples])
+
+        sms_frames = [
+            f for f in self.test_data['peak_detection'] if f['status'] != 0
+        ]
+
+        # report both counts in the failure message rather than printing
+        # them on every run (the print statement was Python 2 only)
+        assert len(sms_frames) == len(frames), \
+            'frames: %d (expected: %d)' % (len(frames), len(sms_frames))
+
+        for frame in frames:
+            assert frame.num_peaks <= max_peaks
+            # fail with a clear assertion instead of letting max() raise
+            # ValueError when a frame has no peaks
+            assert len(frame.peaks) > 0, 'frame contains no peaks'
+            assert max(p.amplitude for p in frame.peaks)