summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--simpl/base.pxd1
-rw-r--r--simpl/base.pyx11
-rw-r--r--simpl/partial_tracking.pyx12
-rw-r--r--tests/test_partial_tracking.py11
4 files changed, 13 insertions, 22 deletions
diff --git a/simpl/base.pxd b/simpl/base.pxd
index e364e6e..4f45332 100644
--- a/simpl/base.pxd
+++ b/simpl/base.pxd
@@ -20,6 +20,7 @@ cdef class Frame:
cdef int created
cdef set_frame(self, c_Frame* f)
cdef list _peaks
+ cdef list _partials
cdef extern from "<string>" namespace "std":
diff --git a/simpl/base.pyx b/simpl/base.pyx
index e186fee..c86c2e5 100644
--- a/simpl/base.pyx
+++ b/simpl/base.pyx
@@ -37,6 +37,7 @@ cdef class Peak:
cdef class Frame:
def __cinit__(self, size=None, create_new=True, alloc_memory=False):
self._peaks = []
+ self._partials = []
if create_new:
if size:
@@ -78,10 +79,6 @@ cdef class Frame:
self.thisptr.clear()
# partials
- property num_partials:
- def __get__(self): return self.thisptr.num_partials()
- def __set__(self, int i): raise Exception("NotImplemented")
-
property max_partials:
def __get__(self): return self.thisptr.max_partials()
def __set__(self, int i): self.thisptr.max_partials(i)
@@ -105,9 +102,11 @@ cdef class Frame:
property partials:
def __get__(self):
- return [self.partial(i) for i in range(self.thisptr.num_partials())]
+ if not self._partials:
+ self._partials = [self.partial(i) for i in range(self.thisptr.num_partials())]
+ return self._partials
def __set__(self, peaks):
- self.add_partials(peaks)
+ self._partials = peaks
# audio buffers
property size:
diff --git a/simpl/partial_tracking.pyx b/simpl/partial_tracking.pyx
index 4019937..69053a3 100644
--- a/simpl/partial_tracking.pyx
+++ b/simpl/partial_tracking.pyx
@@ -45,18 +45,16 @@ cdef class PartialTracking:
peak = Peak(False)
peak.set_peak(c_peaks[i])
peaks.append(peak)
+ frame.partials = peaks
return peaks
def find_partials(self, frames):
partial_frames = []
- cdef vector[c_Frame*] c_frames
for frame in frames:
- c_frames.push_back((<Frame>frame).thisptr)
- cdef vector[c_Frame*] output_frames = self.thisptr.find_partials(c_frames)
- for i in range(output_frames.size()):
- f = Frame(output_frames[i].size(), False)
- f.set_frame(output_frames[i])
- partial_frames.append(f)
+ if frame.max_partials != self.thisptr.max_partials():
+ frame.max_partials = self.thisptr.max_partials()
+ self.update_partials(frame)
+ partial_frames.append(frame)
return partial_frames
diff --git a/tests/test_partial_tracking.py b/tests/test_partial_tracking.py
index 32f2653..cab6b2b 100644
--- a/tests/test_partial_tracking.py
+++ b/tests/test_partial_tracking.py
@@ -46,10 +46,6 @@ class TestPartialTracking(object):
pt = PartialTracking()
frames = pt.find_partials(frames)
- print 'frames: %d (expected: %d)' %\
- (len(frames), len(self.audio) / hop_size)
- assert len(frames) == len(self.audio) / hop_size
-
assert len(frames[0].partials) == 0
assert frames[0].max_partials == 100
@@ -64,6 +60,7 @@ class TestSMSPartialTracking(object):
def test_basic(self):
pd = SMSPeakDetection()
pd.hop_size = hop_size
+ pd.frame_size = hop_size
pd.max_peaks = max_peaks
pd.static_frame_size = True
frames = pd.find_peaks(self.audio)
@@ -72,11 +69,7 @@ class TestSMSPartialTracking(object):
pt.max_partials = max_partials
frames = pt.find_partials(frames)
- print 'frames: %d (expected: %d)' %\
- (len(frames), len(self.audio) / hop_size)
- assert len(frames) == len(self.audio) / hop_size
-
- assert frames[0].num_partials == max_partials
+ assert len(frames[0].partials) == max_partials
assert frames[0].max_partials == max_partials
def test_partial_tracking(self):