diff options
Diffstat (limited to 'tests/test_peak_detection.py')
-rw-r--r-- | tests/test_peak_detection.py | 127 |
1 files changed, 51 insertions, 76 deletions
diff --git a/tests/test_peak_detection.py b/tests/test_peak_detection.py index 6d623b6..84505b7 100644 --- a/tests/test_peak_detection.py +++ b/tests/test_peak_detection.py @@ -1,7 +1,5 @@ import os -import numpy as np -from nose.tools import assert_almost_equals -import pysms +import json import simpl import simpl.peak_detection as peak_detection @@ -18,6 +16,16 @@ num_samples = num_frames * hop_size audio_path = os.path.join( os.path.dirname(__file__), 'audio/flute.wav' ) +test_data_path = os.path.join( + os.path.dirname(__file__), 'libsms_test_data.json' +) + + +def _load_libsms_test_data(): + test_data = None + with open(test_data_path, 'r') as f: + test_data = json.loads(f.read()) + return test_data class TestPeakDetection(object): @@ -25,91 +33,37 @@ class TestPeakDetection(object): def setup_class(cls): cls.audio = simpl.read_wav(audio_path)[0] - def test_peak_detection(self): + def test_basic(self): pd = PeakDetection() + pd.max_peaks = max_peaks pd.find_peaks(self.audio) assert len(pd.frames) == len(self.audio) / hop_size assert len(pd.frames[0].peaks) == 0 + assert pd.frames[0].max_peaks == max_peaks class TestSMSPeakDetection(object): - def _pysms_analysis_params(self, sampling_rate): - analysis_params = pysms.SMS_AnalParams() - pysms.sms_initAnalParams(analysis_params) - analysis_params.iSamplingRate = sampling_rate - analysis_params.iFrameRate = sampling_rate / hop_size - analysis_params.iWindowType = pysms.SMS_WIN_HAMMING - analysis_params.fDefaultFundamental = 100 - analysis_params.fHighestFreq = 20000 - analysis_params.iFormat = pysms.SMS_FORMAT_HP - analysis_params.nTracks = max_peaks - analysis_params.peakParams.iMaxPeaks = max_peaks - analysis_params.nGuides = max_peaks - analysis_params.iMaxDelayFrames = 4 - analysis_params.analDelay = 0 - analysis_params.minGoodFrames = 1 - analysis_params.iCleanTracks = 0 - analysis_params.iStochasticType = pysms.SMS_STOC_NONE - analysis_params.preEmphasis = 0 - return analysis_params + @classmethod + def setup_class(cls): + cls.audio = simpl.read_wav(audio_path)[0] + cls.test_data = _load_libsms_test_data() + + def test_basic(self): + pd = SMSPeakDetection() + pd.hop_size = hop_size + pd.static_frame_size = True + pd.find_peaks(self.audio) + + assert len(pd.frames) == len(self.audio) / hop_size + assert len(pd.frames[0].peaks) def test_size_next_read(self): """ - test_size_next_read - Make sure PeakDetection is calculating the correct value for the + Make sure SMSPeakDetection is calculating the correct value for the size of the next frame. """ audio, sampling_rate = simpl.read_wav(audio_path) - pysms.sms_init() - snd_header = pysms.SMS_SndHeader() - - # Try to open the input file to fill snd_header - if(pysms.sms_openSF(audio_path, snd_header)): - raise NameError( - "error opening sound file: " + pysms.sms_errorString() - ) - - analysis_params = self._pysms_analysis_params(sampling_rate) - analysis_params.iMaxDelayFrames = num_frames + 1 - if pysms.sms_initAnalysis(analysis_params, snd_header) != 0: - raise Exception("Error allocating memory for analysis_params") - analysis_params.nFrames = num_frames - sms_header = pysms.SMS_Header() - pysms.sms_fillHeader(sms_header, analysis_params, "pysms") - - sample_offset = 0 - pysms_size_new_data = 0 - current_frame = 0 - sms_next_read_sizes = [] - - while current_frame < num_frames: - sms_next_read_sizes.append(analysis_params.sizeNextRead) - sample_offset += pysms_size_new_data - pysms_size_new_data = analysis_params.sizeNextRead - - # convert frame to floats for libsms - frame = audio[sample_offset:sample_offset + pysms_size_new_data] - frame = np.array(frame, dtype=np.float32) - if len(frame) < pysms_size_new_data: - frame = np.hstack(( - frame, np.zeros(pysms_size_new_data - len(frame), - dtype=np.float32) - )) - - analysis_data = pysms.SMS_Data() - pysms.sms_allocFrameH(sms_header, analysis_data) - status = pysms.sms_analyze(frame, analysis_data, analysis_params) - # as the no. of frames of delay is > num_frames, sms_analyze should - # never get around to performing partial tracking, and so the - # return value should be 0 - assert status == 0 - pysms.sms_freeFrame(analysis_data) - current_frame += 1 - - pysms.sms_freeAnalysis(analysis_params) - pysms.sms_closeSF() - pysms.sms_free() pd = SMSPeakDetection() pd.hop_size = hop_size @@ -117,13 +71,34 @@ class TestSMSPeakDetection(object): current_frame = 0 sample_offset = 0 + next_read_sizes = self.test_data['size_next_read'] + while current_frame < num_frames: pd.frame_size = pd.next_frame_size() - assert sms_next_read_sizes[current_frame] == pd.frame_size,\ - (sms_next_read_sizes[current_frame], pd.frame_size) + assert next_read_sizes[current_frame] == pd.frame_size,\ + (next_read_sizes[current_frame], pd.frame_size) frame = simpl.Frame() frame.size = pd.frame_size frame.audio = audio[sample_offset:sample_offset + pd.frame_size] pd.find_peaks_in_frame(frame) sample_offset += pd.frame_size current_frame += 1 + + def test_peak_detection(self): + audio, sampling_rate = simpl.read_wav(audio_path) + + pd = SMSPeakDetection() + pd.max_peaks = max_peaks + pd.hop_size = hop_size + frames = pd.find_peaks(audio[0:num_samples]) + + sms_frames = self.test_data['peak_detection'] + sms_frames = [f for f in sms_frames if f['status'] != 0] + + print 'frames: %d (expected: %d)' % (len(frames), len(sms_frames)) + assert len(sms_frames) == len(frames) + + for frame in frames: + assert frame.num_peaks <= max_peaks + max_amp = max([p.amplitude for p in frame.peaks]) + assert max_amp |