Source code for midiogre.core.compositions

"""Composition module for chaining multiple MIDI transforms.

This module provides functionality to compose multiple MIDI transforms into a single
transform pipeline. The transforms are applied sequentially in the order they are provided.

Example:
    >>> from midiogre.augmentations import PitchShift, OnsetTimeShift
    >>> from midiogre.core import Compose
    >>> 
    >>> # Create a transform pipeline
    >>> transform = Compose([
    ...     PitchShift(max_shift=2, p=0.5),
    ...     OnsetTimeShift(max_shift=0.1, p=0.3)
    ... ])
    >>> 
    >>> # Apply transforms to MIDI data
    >>> transformed_midi = transform(midi_data)
"""

from pretty_midi import PrettyMIDI


class Compose:
    """Chain several MIDI transforms into a single callable pipeline.

    The transforms are applied one after another, in the order given; the
    output of each transform is passed as the input of the next one.

    Args:
        transforms (list or tuple): A sequence of MIDI transforms to be applied
            in order. Each transform should be a callable that takes a
            PrettyMIDI object as input and returns a transformed PrettyMIDI
            object.

    Raises:
        TypeError: If transforms is not a list or tuple.

    Example:
        >>> # Create transforms
        >>> pitch_shift = PitchShift(max_shift=2)
        >>> onset_shift = OnsetTimeShift(max_shift=0.1)
        >>>
        >>> # Compose transforms
        >>> transform = Compose([pitch_shift, onset_shift])
        >>>
        >>> # Apply to MIDI data
        >>> transformed_midi = transform(midi_data)
    """

    # Annotations are written as strings so they are not evaluated when the
    # class body is executed. The previous ``list or tuple`` annotation was a
    # boolean expression that evaluated to plain ``list``.
    def __init__(self, transforms: "list | tuple"):
        """Compose several MIDIOgre transforms together.

        Args:
            transforms (list or tuple): MIDIOgre transforms to be performed in
                the given order. The sequence is stored as-is (not copied).

        Raises:
            TypeError: If transforms is not a list or tuple.
        """
        if not isinstance(transforms, (list, tuple)):
            raise TypeError(
                "Transforms to be composed must be wrapped in a list or a tuple, got {}".format(type(transforms))
            )

        self.transforms = transforms

    def __len__(self) -> int:
        """Return the number of transforms in the composition.

        Returns:
            int: Number of transforms in the pipeline.
        """
        return len(self.transforms)

    def __call__(self, midi_data: "PrettyMIDI") -> "PrettyMIDI":
        """Apply all transforms sequentially to the MIDI data.

        Args:
            midi_data (PrettyMIDI): The MIDI data to transform.

        Returns:
            PrettyMIDI: The result of passing ``midi_data`` through every
            transform in sequence. With an empty pipeline the input is
            returned unchanged.
        """
        for transform in self.transforms:
            # Each transform receives the output of the previous one.
            midi_data = transform(midi_data)
        return midi_data