diff --git a/plugins/examples/stft/src/lib.rs b/plugins/examples/stft/src/lib.rs index 1dcf9ecd..95f4e5c7 100644 --- a/plugins/examples/stft/src/lib.rs +++ b/plugins/examples/stft/src/lib.rs @@ -1,6 +1,8 @@ use nih_plug::prelude::*; use std::pin::Pin; +const WINDOW_SIZE: usize = 2048; + struct Stft { params: Pin>, @@ -15,7 +17,7 @@ impl Default for Stft { Self { params: Box::pin(StftParams::default()), - stft: util::StftHelper::new(2, 512), + stft: util::StftHelper::new(2, WINDOW_SIZE), } } } @@ -56,7 +58,7 @@ impl Plugin for Stft { ) -> bool { // Normally we'd also initialize the STFT helper for the correct channel count here, but we // only do stereo so that's not necessary - self.stft.set_block_size(512); + self.stft.set_block_size(WINDOW_SIZE); context.set_latency_samples(self.stft.latency_samples()); true diff --git a/src/util/stft.rs b/src/util/stft.rs index 0072a405..f1021167 100644 --- a/src/util/stft.rs +++ b/src/util/stft.rs @@ -1,7 +1,5 @@ //! Utilities for buffering audio, likely used as part of a short-term Fourier transform. -use std::mem; - use crate::buffer::Buffer; /// Process the input buffer in equal sized blocks, running a callback on each block to transform @@ -16,16 +14,17 @@ use crate::buffer::Buffer; /// TODO: We may need something like this purely for analysis, e.g. for showing spectrums in a GUI. /// Figure out the cleanest way to adapt this for the non-processing use case. pub struct StftHelper { - // These ring buffers store both the input samples and the already processed output. Whenever we - // wrap around,we'll write the already calculated outputs to the main buffer passed to the - // process function and process a new block. - main_ring_buffers: Vec>, + // These ring buffers store the input samples and the already processed output produced by + // adding overlapping windows. Whenever we reach a new overlapping window, we'll write the + // already calculated outputs to the main buffer passed to the process function and then process + // a new block. + main_input_ring_buffers: Vec>, + main_output_ring_buffers: Vec>, sidechain_ring_buffers: [Vec>; NUM_SIDECHAIN_INPUTS], - // To make this more convenient, we'll provide slices into the above buffers to the block - // process callback - main_block_buffer: Buffer<'static>, - sidechain_block_buffers: [Buffer<'static>; NUM_SIDECHAIN_INPUTS], + /// Results from the ring buffers are copied to this scratch buffer before being passed to the + /// plugin. Needed to handle overlap. + scratch_buffer: Vec, /// The current position in our ring buffers. Whenever this wraps around to 0, we'll process /// a block. @@ -36,36 +35,25 @@ impl StftHelper { /// Initialize the [`StftHelper`] for [`Buffer`]s with the specified number of channels and the /// given maximum block size. Call [`set_block_size()`][`Self::set_block_size()`] afterwards if /// you do not need the full capacity upfront. + /// + /// # Panics + /// + /// Panics if `num_channels == 0 || max_block_size == 0`. pub fn new(num_channels: usize, max_block_size: usize) -> Self { - nih_debug_assert_ne!(num_channels, 0); - nih_debug_assert_ne!(max_block_size, 0); + assert_ne!(num_channels, 0); + assert_ne!(max_block_size, 0); - let mut helper = Self { - main_ring_buffers: vec![vec![0.0; max_block_size]; num_channels], + Self { + main_input_ring_buffers: vec![vec![0.0; max_block_size]; num_channels], + main_output_ring_buffers: vec![vec![0.0; max_block_size]; num_channels], // Kinda hacky way to initialize an array of non-copy types sidechain_ring_buffers: [(); NUM_SIDECHAIN_INPUTS] .map(|_| vec![vec![0.0; max_block_size]; num_channels]), - main_block_buffer: Buffer::default(), - sidechain_block_buffers: [(); NUM_SIDECHAIN_INPUTS].map(|_| Buffer::default()), + scratch_buffer: vec![0.0; max_block_size], current_pos: 0, - }; - - // Preallocate the output slices. We'll point them to the ring buffers at the start of the - // process call. - unsafe { - helper.main_block_buffer.with_raw_vec(|main_block_slices| { - main_block_slices.resize_with(num_channels, || &mut []) - }); - for sidechain_block_buffer in &mut helper.sidechain_block_buffers { - sidechain_block_buffer.with_raw_vec(|main_block_slices| { - main_block_slices.resize_with(num_channels, || &mut []) - }); - } - }; - - helper + } } /// Change the current block size. This will clear the buffers, causing the next block to output @@ -75,12 +63,18 @@ impl StftHelper { /// /// WIll panic if `block_size > max_block_size`. pub fn set_block_size(&mut self, block_size: usize) { - assert!(block_size <= self.main_ring_buffers[0].capacity()); + assert!(block_size <= self.main_input_ring_buffers[0].capacity()); - for main_ring_buffer in &mut self.main_ring_buffers { + for main_ring_buffer in &mut self.main_input_ring_buffers { main_ring_buffer.resize(block_size, 0.0); main_ring_buffer.fill(0.0); } + for main_ring_buffer in &mut self.main_output_ring_buffers { + main_ring_buffer.resize(block_size, 0.0); + main_ring_buffer.fill(0.0); + } + self.scratch_buffer.resize(block_size, 0.0); + self.scratch_buffer.fill(0.0); for sidechain_ring_buffers in &mut self.sidechain_ring_buffers { for sidechain_ring_buffer in sidechain_ring_buffers { sidechain_ring_buffer.resize(block_size, 0.0); @@ -93,74 +87,56 @@ impl StftHelper { /// The amount of latency introduced when processing audio throug hthis [`StftHelper`]. pub fn latency_samples(&self) -> u32 { - self.main_ring_buffers[0].len() as u32 + self.main_input_ring_buffers[0].len() as u32 } - /// Process the audio in `main_buffer` and in any sidechain buffers in small blocks. Whenever a - /// new block is available, `process_cb()` gets called with a new audio block of the specified - /// side. The results written to the buffer will then be written back to `main_buffer` exactly - /// one block later, which means that this function will introduce one block of latency. This - /// can be compensated by calling - /// [`ProcessContext::set_latency()`][`crate::prelude::ProcessContext::set_latency()`] in your - /// plugin's initialization function. + /// Process the audio in `main_buffer` and in any sidechain buffers in small overlapping blocks + /// with a window function applied, adding up the results for the main buffer so they can be + /// written back to the host. Whenever a new block is available, `process_cb()` gets called with + /// a new audio block of the specified size with the windowing function already applied. The + /// summed reults will then be written back to `main_buffer` exactly one block later, which + /// means that this function will introduce one block of latency. This can be compensated by + /// calling [`ProcessContext::set_latency()`][`crate::prelude::ProcessContext::set_latency()`] + /// in your plugin's initialization function. + /// + /// For efficiency's sake this function will reuse the same vector for all calls to + /// `process_cb`. This means you can only access a single channel's worth of windowed data at a + /// time. The arguments to that function are `process_cb(channel_idx, sidechain_buffer_idx, + /// data)`, where `sidechain_buffer_idx` will be `None` for the main buffer. If there are any + /// sidechain buffers, then they will be processed before the main buffer. /// /// # Panics /// /// Panics if `main_buffer` or the buffers in `sidechain_buffers` do not have the same number of - /// channels as this [`StftHelper`]. + /// channels as this [`StftHelper`], if the sidechain buffers do not contain the same number of + /// samples as the main buffer, or if the window function does not match the block size. /// /// TODO: Maybe introduce a trait here so this can be used with things that aren't whole buffers /// TODO: And also introduce that aforementioned read-only process function (`analyze()?`) - pub fn process( + pub fn process_overlap_add( &mut self, main_buffer: &mut Buffer, sidechain_buffers: [&Buffer; NUM_SIDECHAIN_INPUTS], + window_function: &[f32], + overlap_times: usize, mut process_cb: F, ) where - F: FnMut(&mut Buffer, &[Buffer; NUM_SIDECHAIN_INPUTS]), + F: FnMut(usize, Option, &mut [f32]), { - assert_eq!(main_buffer.channels(), self.main_ring_buffers.len()); - - // Since the `StftHelper` object may move in between process calls, we need to make sure - // that these slices point to our ring buffers at the start of each call - unsafe { - self.main_block_buffer.with_raw_vec(|main_block_slices| { - assert_eq!(main_block_slices.len(), self.main_ring_buffers.len()); - for (channel_idx, channel_slice) in main_block_slices.iter_mut().enumerate() { - // SAFETY: This is equivalent to splitting on each channel, and these block - // slices will only be used here as part of the callback when the ring - // buffers are not mutably borrwed - *channel_slice = - &mut *(self.main_ring_buffers[channel_idx].as_mut_slice() as *mut _); - } - }); - for (sidechain_block_buffer, sidechain_ring_buffer) in self - .sidechain_block_buffers - .iter_mut() - .zip(self.sidechain_ring_buffers.iter_mut()) - { - sidechain_block_buffer.with_raw_vec(|sidechain_block_slices| { - assert_eq!(sidechain_block_slices.len(), sidechain_ring_buffer.len()); - for (channel_idx, channel_slice) in - sidechain_block_slices.iter_mut().enumerate() - { - *channel_slice = - &mut *(sidechain_ring_buffer[channel_idx].as_mut_slice() as *mut _); - } - }); - } - }; + assert_eq!(main_buffer.channels(), self.main_input_ring_buffers.len()); + assert_eq!(window_function.len(), self.main_input_ring_buffers[0].len()); // We'll copy samples from `*_buffer` into `*_ring_buffers` while simultaneously copying // already processed samples from `main_ring_buffers` in into `main_buffer` let main_buffer_len = main_buffer.len(); let num_channels = main_buffer.channels(); - let block_len = self.main_ring_buffers[0].len(); + let block_size = self.main_input_ring_buffers[0].len(); + let window_interval = block_size / overlap_times; let mut already_processed_samples = 0; while already_processed_samples < main_buffer_len { let remaining_samples = main_buffer_len - already_processed_samples; - let samples_until_next_block = block_len - self.current_pos; - let samples_to_process = samples_until_next_block.min(remaining_samples); + let samples_until_next_window = (window_interval - self.current_pos) % window_interval; + let samples_to_process = samples_until_next_window.min(remaining_samples); // Copy the input from `main_buffer` to the ring buffer while copying last block's // result from the buffer to `main_buffer` @@ -175,12 +151,20 @@ impl StftHelper { .get_unchecked_mut(channel_idx) .get_unchecked_mut(already_processed_samples + sample_offset) }; - let ring_buffer_sample = unsafe { - self.main_ring_buffers + let input_ring_buffer_sample = unsafe { + self.main_input_ring_buffers .get_unchecked_mut(channel_idx) .get_unchecked_mut(self.current_pos + sample_offset) }; - mem::swap(sample, ring_buffer_sample); + let output_ring_buffer_sample = unsafe { + self.main_output_ring_buffers + .get_unchecked_mut(channel_idx) + .get_unchecked_mut(self.current_pos + sample_offset) + }; + *input_ring_buffer_sample = *sample; + *sample = *output_ring_buffer_sample; + // Very important, or else we'll overlap-add ourselves into a feedback hell + *output_ring_buffer_sample = 0.0; } } @@ -208,16 +192,88 @@ impl StftHelper { } } - already_processed_samples += samples_to_process; - self.current_pos += samples_to_process; - // At this point we either have `already_processed_samples == main_buffer_len`, or - // `self.current_pos == block_len`. If it's the latter, then we can process a new block. - if self.current_pos == block_len { - process_cb(&mut self.main_block_buffer, &self.sidechain_block_buffers); + // `self.current_pos % window_interval == 0`. If it's the latter, then we can process a + // new block. + if samples_to_process == samples_until_next_window { + // Because we're processing in smaller windows, the input ring buffers sadly does + // not always contain the full contiguous range we're interested in because they map + // wrap around. Because premade FFT algorithms typically can't handle this, we'll + // start with copying - self.current_pos = 0; + // TODO: Sdiechain + + for (channel_idx, (input_ring_buffer, output_ring_buffer)) in self + .main_input_ring_buffers + .iter() + .zip(self.main_output_ring_buffers.iter_mut()) + .enumerate() + { + copy_ring_to_scratch_buffer( + &mut self.scratch_buffer, + self.current_pos, + input_ring_buffer, + ); + multiply_scratch_buffer(&mut self.scratch_buffer, window_function); + process_cb(channel_idx, None, &mut self.scratch_buffer); + + // The actual overlap-add part of the equation + add_scratch_to_ring_buffer( + &self.scratch_buffer, + self.current_pos, + output_ring_buffer, + ); + } } + + // Do this after handling the block or else we'll copy the wrong samples. + already_processed_samples += samples_to_process; + self.current_pos = (self.current_pos + samples_to_process) % block_size; } } } + +/// Copy data from the the specified ring buffer (borrowed from `self`) to the scratch buffers at +/// the current position. This is a free function because you cannot pass an immutable reference to +/// a field from `&self` to a `&mut self` method. +#[inline] +fn copy_ring_to_scratch_buffer( + scratch_buffer: &mut [f32], + current_pos: usize, + ring_buffer: &[f32], +) { + let block_size = scratch_buffer.len(); + let num_copy_before_wrap = block_size - current_pos; + scratch_buffer[0..num_copy_before_wrap].copy_from_slice(&ring_buffer[current_pos..block_size]); + scratch_buffer[num_copy_before_wrap..block_size].copy_from_slice(&ring_buffer[0..current_pos]); +} + +/// Multiply the scratch buffer by some window function. Also free function because you can't do +/// split borrows with methods. +#[inline] +fn multiply_scratch_buffer(scratch_buffer: &mut [f32], window_function: &[f32]) { + for (sample, window_sample) in scratch_buffer.iter_mut().zip(window_function) { + *sample *= window_sample; + } +} + +/// Add data from the scratch buffer to the specified ring buffer. When writing samples from this +/// ring buffer back to the host's outputs they must be cleared to prevent infinite feedback. +#[inline] +fn add_scratch_to_ring_buffer(scratch_buffer: &[f32], current_pos: usize, ring_buffer: &mut [f32]) { + // TODO: This could also use some SIMD + let block_size = scratch_buffer.len(); + let num_copy_before_wrap = block_size - current_pos; + for (scratch_sample, ring_sample) in scratch_buffer[0..num_copy_before_wrap] + .iter() + .zip(&mut ring_buffer[current_pos..block_size]) + { + *ring_sample += *scratch_sample; + } + for (scratch_sample, ring_sample) in scratch_buffer[num_copy_before_wrap..block_size] + .iter() + .zip(&mut ring_buffer[0..current_pos]) + { + *ring_sample += *scratch_sample; + } +}