From f0576d53f10832adb8a491f85ec86d2219a621bf Mon Sep 17 00:00:00 2001
From: John Glover <glover.john@gmail.com>
Date: Fri, 3 Dec 2010 15:51:21 +0000
Subject: Fixed bug in harmDetection.c, added more SMS tests

---
 tests/sms.py | 364 +++++++++++++++++++++++++++++++++++------------------------
 1 file changed, 216 insertions(+), 148 deletions(-)

(limited to 'tests/sms.py')

diff --git a/tests/sms.py b/tests/sms.py
index d30686b..4073caf 100644
--- a/tests/sms.py
+++ b/tests/sms.py
@@ -20,9 +20,8 @@ import pysms
 import numpy as np
 from scipy.io.wavfile import read
 from nose.tools import assert_almost_equals
-import unittest
 
-class TestSimplSMS(unittest.TestCase):
+class TestSimplSMS(object):
     FLOAT_PRECISION = 3 # number of decimal places to check for accuracy
     input_file = 'audio/flute.wav'
     frame_size = 2048
@@ -30,7 +29,7 @@ class TestSimplSMS(unittest.TestCase):
     num_frames = 9 
     num_samples = frame_size + ((num_frames - 1) * hop_size)
     max_peaks = 10
-    max_partials = 3
+    max_partials = 10
 
     def get_audio(self):
         audio_data = read(self.input_file)
@@ -116,7 +115,7 @@ class TestSimplSMS(unittest.TestCase):
             # as the no. of frames of delay is > num_frames, sms_analyze should
             # never get around to performing partial tracking, and so the return
             # value should be 0
-            self.assertEquals(status, 0)
+            assert status == 0
             pysms.sms_freeFrame(analysis_data)
             current_frame += 1
 
@@ -132,13 +131,13 @@ class TestSimplSMS(unittest.TestCase):
 
         while current_frame < self.num_frames:
             pd.frame_size = pd.get_next_frame_size()
-            self.assertEquals(sms_next_read_sizes[current_frame], pd.frame_size)
+            assert sms_next_read_sizes[current_frame] == pd.frame_size
             pd.find_peaks_in_frame(audio[sample_offset:sample_offset + pd.frame_size])
             sample_offset += pd.frame_size
             current_frame += 1
 
     def test_sms_analyze(self):
-        """test_sms_analyzebt43lztar
+        """test_sms_analyze
         Make sure that the simplsms.sms_analyze function does the same thing
         as the sms_analyze function from libsms."""
         audio, sampling_rate = self.get_audio()
@@ -275,21 +274,21 @@ class TestSimplSMS(unittest.TestCase):
         simplsms.sms_free()
 
         # make sure both have the same number of partials
-        self.assertEquals(len(sms_partials), len(simplsms_partials))
+        assert len(sms_partials) == len(simplsms_partials)
 
         # make sure each partial is the same
         for i in range(len(sms_partials)):
-            self.assertEquals(sms_partials[i].get_length(), simplsms_partials[i].get_length())
+            assert sms_partials[i].get_length() == simplsms_partials[i].get_length()
             for peak_number in range(sms_partials[i].get_length()):
-                self.assertAlmostEquals(sms_partials[i].peaks[peak_number].amplitude,
-                                        simplsms_partials[i].peaks[peak_number].amplitude,
-                                        places = self.FLOAT_PRECISION)
-                self.assertAlmostEquals(sms_partials[i].peaks[peak_number].frequency,
-                                        simplsms_partials[i].peaks[peak_number].frequency,
-                                        places = self.FLOAT_PRECISION)
-                self.assertAlmostEquals(sms_partials[i].peaks[peak_number].phase,
-                                        simplsms_partials[i].peaks[peak_number].phase,
-                                        places = self.FLOAT_PRECISION)
+                assert_almost_equals(sms_partials[i].peaks[peak_number].amplitude,
+                                     simplsms_partials[i].peaks[peak_number].amplitude,
+                                     self.FLOAT_PRECISION)
+                assert_almost_equals(sms_partials[i].peaks[peak_number].frequency,
+                                     simplsms_partials[i].peaks[peak_number].frequency,
+                                     self.FLOAT_PRECISION)
+                assert_almost_equals(sms_partials[i].peaks[peak_number].phase,
+                                     simplsms_partials[i].peaks[peak_number].phase,
+                                     self.FLOAT_PRECISION)
 
     def test_multi_sms_peak_detection(self): 
         """test_multi_sms_peak_detection
@@ -321,7 +320,7 @@ class TestSimplSMS(unittest.TestCase):
             # as the no. of frames of delay is > num_frames, sms_analyze should
             # never get around to performing partial tracking, and so the return
             # value should be 0
-            self.assertEquals(status, 0)
+            assert status == 0
             num_peaks = analysis_data.nTracks
             frame_peaks = []
             simplsms_freqs = simpl.zeros(num_peaks)
@@ -371,7 +370,7 @@ class TestSimplSMS(unittest.TestCase):
             # as the no. of frames of delay is > num_frames, sms_analyze should
             # never get around to performing partial tracking, and so the return
             # value should be 0
-            self.assertEquals(status, 0)
+            assert status == 0
             num_peaks = analysis_data.nTracks
             frame_peaks = []
             simplsms_freqs = simpl.zeros(num_peaks)
@@ -484,7 +483,7 @@ class TestSimplSMS(unittest.TestCase):
             # as the no. of frames of delay is > num_frames, sms_analyze should
             # never get around to performing partial tracking, and so the return
             # value should be 0
-            self.assertEquals(status, 0)
+            assert status == 0
             num_peaks = analysis_data.nTracks
             frame_peaks = []
             simplsms_freqs = simpl.zeros(num_peaks)
@@ -524,25 +523,24 @@ class TestSimplSMS(unittest.TestCase):
             current_frame += 1
 
         # make sure we have the same number of frames
-        self.assertEquals(len(sms_peaks), len(simpl_peaks))
+        assert len(sms_peaks) == len(simpl_peaks)
 
         # compare data for each frame
         for frame_number in range(len(sms_peaks)):
             sms_frame = sms_peaks[frame_number]
             simpl_frame = simpl_peaks[frame_number]
             # make sure we have the same number of peaks in each frame
-            self.assertEquals(len(sms_frame), len(simpl_frame))
+            assert len(sms_frame) == len(simpl_frame)
             # check peak values
             for peak_number in range(len(sms_frame)):
-                #print frame_number, peak_number
                 sms_peak = sms_frame[peak_number]
                 simpl_peak = simpl_frame[peak_number]
-                self.assertAlmostEquals(sms_peak.amplitude, simpl_peak.amplitude,
-                                        places=self.FLOAT_PRECISION)
-                self.assertAlmostEquals(sms_peak.frequency, simpl_peak.frequency,
-                                        places=self.FLOAT_PRECISION)
-                self.assertAlmostEquals(sms_peak.phase, simpl_peak.phase,
-                                        places=self.FLOAT_PRECISION)  
+                assert_almost_equals(sms_peak.amplitude, simpl_peak.amplitude,
+                                     self.FLOAT_PRECISION)
+                assert_almost_equals(sms_peak.frequency, simpl_peak.frequency,
+                                     self.FLOAT_PRECISION)
+                assert_almost_equals(sms_peak.phase, simpl_peak.phase,
+                                     self.FLOAT_PRECISION)  
 
     def test_multi_pysms_analyze(self): 
         """test_multi_pysms_analyze
@@ -786,129 +784,203 @@ class TestSimplSMS(unittest.TestCase):
         pt.max_partials = self.max_peaks
         partials = pt.find_partials(peaks)
 
-        #import debug
-        #debug.print_partials(sms_partials)
-        #print
-        #debug.print_partials(partials)
-        #raise Exception("ok")
-
         # make sure both have the same number of partials
-        self.assertEquals(len(sms_partials), len(partials))
+        assert len(sms_partials) == len(partials)
 
         # make sure each partial is the same
         for i in range(len(sms_partials)):
-            self.assertEquals(sms_partials[i].get_length(), partials[i].get_length())
+            assert sms_partials[i].get_length() == partials[i].get_length()
             for peak_number in range(sms_partials[i].get_length()):
-                self.assertAlmostEquals(sms_partials[i].peaks[peak_number].amplitude,
-                                        partials[i].peaks[peak_number].amplitude,
-                                        places = self.FLOAT_PRECISION)
-                self.assertAlmostEquals(sms_partials[i].peaks[peak_number].frequency,
-                                        partials[i].peaks[peak_number].frequency,
-                                        places = self.FLOAT_PRECISION)
-                self.assertAlmostEquals(sms_partials[i].peaks[peak_number].phase,
-                                        partials[i].peaks[peak_number].phase,
-                                        places = self.FLOAT_PRECISION)
-
-   #def test_interpolate_frames(self):
-    #    """test_interpolate_frames
-    #    Make sure that pysms.sms_interpolateFrames returns the expected values
-    #    with interpolation factors of 0 and 1."""
-    #    pysms.sms_init()
-    #    sms_header = pysms.SMS_Header()
-    #    snd_header = pysms.SMS_SndHeader()
-    #    # Try to open the input file to fill snd_header
-    #    if(pysms.sms_openSF(input_file, snd_header)):
-    #        raise NameError("error opening sound file: " + pysms.sms_errorString())
-    #    analysis_params = pysms.SMS_AnalParams()
-    #    analysis_params.iSamplingRate = 44100
-    #    analysis_params.iFrameRate = sampling_rate / hop_size
-    #    sms_header.nStochasticCoeff = 128
-    #    analysis_params.fDefaultFundamental = 100
-    #    analysis_params.fHighestFreq = 20000
-    #    analysis_params.iMaxDelayFrames = 3
-    #    analysis_params.analDelay = 0
-    #    analysis_params.minGoodFrames = 1
-    #    analysis_params.iFormat = pysms.SMS_FORMAT_HP
-    #    analysis_params.nTracks = max_partials
-    #    analysis_params.nGuides = max_partials
-    #    analysis_params.iWindowType = pysms.SMS_WIN_HAMMING
-    #    pysms.sms_initAnalysis(analysis_params, snd_header)
-    #    analysis_params.nFrames = num_samples / hop_size
-    #    analysis_params.iSizeSound = num_samples
-    #    analysis_params.peakParams.iMaxPeaks = max_peaks
-    #    analysis_params.iStochasticType = pysms.SMS_STOC_NONE
-    #    pysms.sms_fillHeader(sms_header, analysis_params, "pysms")
-    #    interp_frame = pysms.SMS_Data()
-    #    pysms.sms_allocFrame(interp_frame, sms_header.nTracks, sms_header.nStochasticCoeff, 1, sms_header.iStochasticType, 0)
+                assert_almost_equals(sms_partials[i].peaks[peak_number].amplitude,
+                                     partials[i].peaks[peak_number].amplitude,
+                                     self.FLOAT_PRECISION)
+                assert_almost_equals(sms_partials[i].peaks[peak_number].frequency,
+                                     partials[i].peaks[peak_number].frequency,
+                                     self.FLOAT_PRECISION)
+                assert_almost_equals(sms_partials[i].peaks[peak_number].phase,
+                                     partials[i].peaks[peak_number].phase,
+                                     self.FLOAT_PRECISION)
 
-    #    sample_offset = 0
-    #    size_new_data = 0
-    #    current_frame = 0
-    #    sms_header.nFrames = num_frames
-    #    analysis_frames = []
-    #    do_analysis = True
+    def test_sms_interpolate_frames(self):
+        """test_sms_interpolate_frames
+        Make sure that sms_interpolateFrames returns the expected values
+        with interpolation factors of 0 and 1."""
+        audio, sampling_rate = self.get_audio()
+        pysms.sms_init()
+        snd_header = pysms.SMS_SndHeader()
+        # Try to open the input file to fill snd_header
+        if(pysms.sms_openSF(self.input_file, snd_header)):
+            raise NameError("error opening sound file: " + pysms.sms_errorString())
+        analysis_params = self.pysms_analysis_params(sampling_rate)
+        analysis_params.nFrames = self.num_frames
+        if pysms.sms_initAnalysis(analysis_params, snd_header) != 0:
+            raise Exception("Error allocating memory for analysis_params")
+        analysis_params.iSizeSound = self.num_samples
+        sms_header = pysms.SMS_Header()
+        pysms.sms_fillHeader(sms_header, analysis_params, "pysms")
 
-    #    while do_analysis and (current_frame < num_frames):
-    #        sample_offset += size_new_data
-    #        if((sample_offset + analysis_params.sizeNextRead) < num_samples):
-    #            size_new_data = analysis_params.sizeNextRead
-    #        else:
-    #            size_new_data = num_samples - sample_offset
-    #        frame = audio[sample_offset:sample_offset + size_new_data]
-    #        analysis_data = pysms.SMS_Data()
-    #        pysms.sms_allocFrameH(sms_header, analysis_data)
-    #        status = pysms.sms_analyze(frame, analysis_data, analysis_params)  
+        interp_frame = pysms.SMS_Data()
+        pysms.sms_allocFrameH(sms_header, interp_frame)
 
-    #        if status == 1:
-    #            analysis_frames.append(analysis_data)
-    #            # test interpolateFrames on the last two analysis frames
-    #            if current_frame == num_frames - 1:
-    #                left_frame = analysis_frames[-2]
-    #                right_frame = analysis_frames[-1]
-    #                pysms.sms_interpolateFrames(left_frame, right_frame, interp_frame, 0)
-    #                # make sure that interp_frame == left_frame
-    #                # interpolateFrames doesn't interpolate phases so ignore
-    #                left_amps = simpl.zeros(max_partials)
-    #                left_freqs = simpl.zeros(max_partials)
-    #                left_frame.getSinAmp(left_amps)
-    #                left_frame.getSinFreq(left_freqs)
-    #                right_amps = simpl.zeros(max_partials)
-    #                right_freqs = simpl.zeros(max_partials)
-    #                right_frame.getSinAmp(right_amps)
-    #                right_frame.getSinFreq(right_freqs)
-    #                interp_amps = simpl.zeros(max_partials)
-    #                interp_freqs = simpl.zeros(max_partials)
-    #                interp_frame.getSinAmp(interp_amps)
-    #                interp_frame.getSinFreq(interp_freqs)
-    #                for i in range(max_partials):
-    #                    self.assertAlmostEquals(left_amps[i], interp_amps[i],
-    #                                            places = FLOAT_PRECISION)
-    #                    if left_freqs[i] != 0:
-    #                        self.assertAlmostEquals(left_freqs[i], interp_freqs[i],
-    #                                                places = FLOAT_PRECISION)
-    #                    else:
-    #                        self.assertAlmostEquals(right_freqs[i], interp_freqs[i],
-    #                                                places = FLOAT_PRECISION)
-    #                pysms.sms_interpolateFrames(left_frame, right_frame, interp_frame, 1)
-    #                interp_amps = simpl.zeros(max_partials)
-    #                interp_freqs = simpl.zeros(max_partials)
-    #                interp_frame.getSinAmp(interp_amps)
-    #                interp_frame.getSinFreq(interp_freqs)
-    #                for i in range(max_partials):
-    #                    self.assertAlmostEquals(right_amps[i], interp_amps[i],
-    #                                            places = FLOAT_PRECISION)
-    #                    if right_freqs[i] != 0:
-    #                        self.assertAlmostEquals(right_freqs[i], interp_freqs[i],
-    #                                                places = FLOAT_PRECISION)
-    #                    else:
-    #                        self.assertAlmostEquals(left_freqs[i], interp_freqs[i],
-    #                                                places = FLOAT_PRECISION)
-    #        elif status == -1:
-    #            raise Exception("AnalysisStoppedEarly")
-    #        current_frame += 1
+        sample_offset = 0
+        size_new_data = 0
+        current_frame = 0
+        analysis_frames = []
+        do_analysis = True
 
-    #    pysms.sms_freeAnalysis(analysis_params)
-    #    pysms.sms_closeSF()
+        while do_analysis and (current_frame < self.num_frames):
+            sample_offset += size_new_data
+            size_new_data = analysis_params.sizeNextRead
+            frame = audio[sample_offset:sample_offset + size_new_data]
+            # convert frame to floats for libsms
+            frame = np.array(frame, dtype=np.float32)
+            analysis_data = pysms.SMS_Data()
+            pysms.sms_allocFrameH(sms_header, analysis_data)
+            status = pysms.sms_analyze(frame, analysis_data, analysis_params)  
+
+            if status == 1:
+                analysis_frames.append(analysis_data)
+                # test interpolateFrames on the last two analysis frames
+                if current_frame == self.num_frames - 1:
+                    left_frame = analysis_frames[-2]
+                    right_frame = analysis_frames[-1]
+                    pysms.sms_interpolateFrames(left_frame, right_frame, interp_frame, 0)
+                    # make sure that interp_frame == left_frame
+                    # interpolateFrames doesn't interpolate phases so ignore
+                    left_amps = np.zeros(self.max_partials, dtype=np.float32)
+                    left_freqs = np.zeros(self.max_partials, dtype=np.float32)
+                    left_frame.getSinAmp(left_amps)
+                    left_frame.getSinFreq(left_freqs)
+                    right_amps = np.zeros(self.max_partials, dtype=np.float32)
+                    right_freqs = np.zeros(self.max_partials, dtype=np.float32)
+                    right_frame.getSinAmp(right_amps)
+                    right_frame.getSinFreq(right_freqs)
+                    interp_amps = np.zeros(self.max_partials, dtype=np.float32)
+                    interp_freqs = np.zeros(self.max_partials, dtype=np.float32)
+                    interp_frame.getSinAmp(interp_amps)
+                    interp_frame.getSinFreq(interp_freqs)
+                    for i in range(self.max_partials):
+                        assert_almost_equals(left_amps[i], interp_amps[i],
+                                             self.FLOAT_PRECISION)
+                        if left_freqs[i] != 0:
+                            assert_almost_equals(left_freqs[i], interp_freqs[i],
+                                                 self.FLOAT_PRECISION)
+                        else:
+                            assert_almost_equals(right_freqs[i], interp_freqs[i],
+                                                 self.FLOAT_PRECISION)
+                    pysms.sms_interpolateFrames(left_frame, right_frame, interp_frame, 1)
+                    interp_amps = np.zeros(self.max_partials, dtype=np.float32)
+                    interp_freqs = np.zeros(self.max_partials, dtype=np.float32)
+                    interp_frame.getSinAmp(interp_amps)
+                    interp_frame.getSinFreq(interp_freqs)
+                    for i in range(self.max_partials):
+                        assert_almost_equals(right_amps[i], interp_amps[i],
+                                             self.FLOAT_PRECISION)
+                        if right_freqs[i] != 0:
+                            assert_almost_equals(right_freqs[i], interp_freqs[i],
+                                                 self.FLOAT_PRECISION)
+                        else:
+                            assert_almost_equals(left_freqs[i], interp_freqs[i],
+                                                 self.FLOAT_PRECISION)
+            elif status == -1:
+                raise Exception("AnalysisStoppedEarly")
+            else:
+                pysms.sms_freeFrame(analysis_data)
+            current_frame += 1
+
+        for frame in analysis_frames:
+            pysms.sms_freeFrame(frame)
+        pysms.sms_freeFrame(interp_frame)
+        pysms.sms_freeAnalysis(analysis_params)
+        pysms.sms_closeSF()
+        pysms.sms_free()
+
+    def test_simplsms_interpolate_frames(self):
+        """test_simplsms_interpolate_frames
+        Make sure that sms_interpolateFrames returns the expected values
+        with interpolation factors of 0 and 1."""
+        audio, sampling_rate = self.get_audio()
+        simplsms.sms_init()
+        analysis_params = self.simplsms_analysis_params(sampling_rate)
+        analysis_params.nFrames = self.num_frames
+        if simplsms.sms_initAnalysis(analysis_params) != 0:
+            raise Exception("Error allocating memory for analysis_params")
+        analysis_params.iSizeSound = self.num_samples
+        sms_header = simplsms.SMS_Header()
+        simplsms.sms_fillHeader(sms_header, analysis_params, "simplsms")
+
+        interp_frame = simplsms.SMS_Data()
+        simplsms.sms_allocFrameH(sms_header, interp_frame)
+
+        sample_offset = 0
+        size_new_data = 0
+        current_frame = 0
+        analysis_frames = []
+        do_analysis = True
+
+        while do_analysis and (current_frame < self.num_frames):
+            sample_offset += size_new_data
+            size_new_data = analysis_params.sizeNextRead
+            frame = audio[sample_offset:sample_offset + size_new_data]
+            analysis_data = simplsms.SMS_Data()
+            simplsms.sms_allocFrameH(sms_header, analysis_data)
+            status = simplsms.sms_analyze(frame, analysis_data, analysis_params)  
+
+            if status == 1:
+                analysis_frames.append(analysis_data)
+                # test interpolateFrames on the last two analysis frames
+                if current_frame == self.num_frames - 1:
+                    left_frame = analysis_frames[-2]
+                    right_frame = analysis_frames[-1]
+                    simplsms.sms_interpolateFrames(left_frame, right_frame, interp_frame, 0)
+                    # make sure that interp_frame == left_frame
+                    # interpolateFrames doesn't interpolate phases so ignore
+                    left_amps = simpl.zeros(self.max_partials)
+                    left_freqs = simpl.zeros(self.max_partials)
+                    left_frame.getSinAmp(left_amps)
+                    left_frame.getSinFreq(left_freqs)
+                    right_amps = simpl.zeros(self.max_partials)
+                    right_freqs = simpl.zeros(self.max_partials)
+                    right_frame.getSinAmp(right_amps)
+                    right_frame.getSinFreq(right_freqs)
+                    interp_amps = simpl.zeros(self.max_partials)
+                    interp_freqs = simpl.zeros(self.max_partials)
+                    interp_frame.getSinAmp(interp_amps)
+                    interp_frame.getSinFreq(interp_freqs)
+                    for i in range(self.max_partials):
+                        assert_almost_equals(left_amps[i], interp_amps[i],
+                                             self.FLOAT_PRECISION)
+                        if left_freqs[i] != 0:
+                            assert_almost_equals(left_freqs[i], interp_freqs[i],
+                                                 self.FLOAT_PRECISION)
+                        else:
+                            assert_almost_equals(right_freqs[i], interp_freqs[i],
+                                                 self.FLOAT_PRECISION)
+                    simplsms.sms_interpolateFrames(left_frame, right_frame, interp_frame, 1)
+                    interp_amps = simpl.zeros(self.max_partials)
+                    interp_freqs = simpl.zeros(self.max_partials)
+                    interp_frame.getSinAmp(interp_amps)
+                    interp_frame.getSinFreq(interp_freqs)
+                    for i in range(self.max_partials):
+                        assert_almost_equals(right_amps[i], interp_amps[i],
+                                             self.FLOAT_PRECISION)
+                        if right_freqs[i] != 0:
+                            assert_almost_equals(right_freqs[i], interp_freqs[i],
+                                                 self.FLOAT_PRECISION)
+                        else:
+                            assert_almost_equals(left_freqs[i], interp_freqs[i],
+                                                 self.FLOAT_PRECISION)
+            elif status == -1:
+                raise Exception("AnalysisStoppedEarly")
+            else:
+                simplsms.sms_freeFrame(analysis_data)
+            current_frame += 1
+
+        for frame in analysis_frames:
+            simplsms.sms_freeFrame(frame)
+        simplsms.sms_freeFrame(interp_frame)
+        simplsms.sms_freeAnalysis(analysis_params)
+        simplsms.sms_free()
 
     #def test_harmonic_synthesis(self):
     #    """test_harmonic_synthesis
@@ -1101,10 +1173,6 @@ if __name__ == "__main__":
     # useful for debugging, particularly with GDB
     import nose
     argv = [__file__, 
-            #__file__ + ":TestSimplSMS.test_multi_sms_peak_detection",
-            #__file__ + ":TestSimplSMS.test_multi_simpl_peak_detection",
-            #__file__ + ":TestSimplSMS.test_multi_pysms_analyze",
-            #__file__ + ":TestSimplSMS.test_multi_simpl_partial_tracking",
-            __file__ + ":TestSimplSMS.test_partial_tracking"]
+            __file__ + ":TestSimplSMS.test_simplsms_interpolate_frames"]
     nose.run(argv=argv)
 
-- 
cgit v1.2.3