---
language: en
license: mit
tags:
- audio
- audio-classification
- musical-instruments
- wav2vec2
- transformers
- pytorch
datasets:
- custom
metrics:
- accuracy
- roc_auc
model-index:
- name: epoch_musical_instruments_identification_2
  results:
  - task:
      type: audio-classification
      name: Musical Instrument Classification
    metrics:
    - type: accuracy
      value: 0.9333
      name: Accuracy
    - type: roc_auc
      value: 0.9859
      name: ROC AUC (Macro)
    - type: loss
      value: 1.0639
      name: Validation Loss
base_model:
- facebook/wav2vec2-base-960h
---

# Musical Instrument Classification Model

This model is a fine-tuned version of [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) for musical instrument classification. It can identify 9 different musical instruments from audio recordings with high accuracy.

## Model Description

- **Model type:** Audio Classification
- **Base model:** facebook/wav2vec2-base-960h
- **Language:** Audio (no specific language)
- **License:** MIT
- **Fine-tuned on:** Custom musical instrument dataset (200 samples per class)

## Performance

Results on the evaluation set after 5 epochs of training:

- **Final Accuracy:** 93.33%
- **Final ROC AUC (Macro):** 98.59%
- **Final Validation Loss:** 1.064
- **Evaluation Runtime:** 14.18 seconds
- **Evaluation Speed:** 25.39 samples/second

### Training Progress

| Epoch | Training Loss | Validation Loss | ROC AUC | Accuracy |
|-------|---------------|-----------------|---------|----------|
| 1     | 1.9872        | 1.8875          | 0.9248  | 0.6639   |
| 2     | 1.8652        | 1.4793          | 0.9799  | 0.8000   |
| 3     | 1.3868        | 1.2311          | 0.9861  | 0.8194   |
| 4     | 1.3242        | 1.1121          | 0.9827  | 0.9250   |
| 5     | 1.1869        | 1.0639          | 0.9859  | 0.9333   |

## Supported Instruments

The model can classify the following 9 musical instruments:

1. **Acoustic Guitar**
2. **Bass Guitar**
3. **Drum Set**
4. **Electric Guitar**
5. **Flute**
6. **Hi-Hats**
7. **Keyboard**
8. **Trumpet**
9. **Violin**
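
For reference, the class list above can be written out as a label mapping. The ordering here (alphabetical, matching the list) is an assumption; the authoritative mapping ships with the model's config as `id2label`/`label2id`:

```python
# Hypothetical label mapping for the 9 supported classes; the authoritative
# ordering lives in the model's config (id2label), so treat this as a sketch.
INSTRUMENTS = [
    "Acoustic Guitar", "Bass Guitar", "Drum Set", "Electric Guitar",
    "Flute", "Hi-Hats", "Keyboard", "Trumpet", "Violin",
]
id2label = {i: name for i, name in enumerate(INSTRUMENTS)}
label2id = {name: i for i, name in id2label.items()}
```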

## Usage

### Quick Start with Pipeline

```python
from transformers import pipeline
import torchaudio

# Load the classification pipeline
classifier = pipeline("audio-classification", model="Bhaveen/epoch_musical_instruments_identification_2")

# Load the audio, mix down to mono, resample to 16 kHz, truncate to 3 seconds
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).mean(dim=0).numpy()[:48000]

# Classify the audio
result = classifier(audio)
print(result)
```

### Using Transformers Directly

```python
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import torchaudio
import torch

# Load model and feature extractor
model_name = "Bhaveen/epoch_musical_instruments_identification_2"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)

# Load the audio, mix down to mono, resample to 16 kHz, truncate to 3 seconds
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).mean(dim=0).numpy()[:48000]

# Extract features and make a prediction
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(predictions, dim=-1)

print(f"Predicted instrument: {model.config.id2label[predicted_class.item()]}")
```

## Training Details

### Dataset and Preprocessing

- **Custom dataset** with audio recordings of 9 musical instruments
- **Train/Test Split:** 80/20 using file numbering (files numbered below 160 for training)
- **Data Balancing:** Random oversampling applied to minority classes
- **Audio Preprocessing:**
  - Resampling to 16,000 Hz
  - Fixed length of 48,000 samples (3 seconds)
  - Truncation of longer audio files
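
The fixed-length step above can be sketched as follows. Truncation matches the card; zero-padding clips shorter than 3 seconds is an assumption, since the card only mentions truncation:

```python
import numpy as np

TARGET_LEN = 48_000  # 3 seconds at 16,000 Hz

def fix_length(audio: np.ndarray, target_len: int = TARGET_LEN) -> np.ndarray:
    """Truncate longer clips; zero-pad shorter ones (padding is an assumption)."""
    if len(audio) >= target_len:
        return audio[:target_len]
    return np.pad(audio, (0, target_len - len(audio)))

print(fix_length(np.zeros(100_000)).shape)  # (48000,)
```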

### Training Configuration

```python
# Training hyperparameters
batch_size = 1
gradient_accumulation_steps = 4  # effective batch size of 4
learning_rate = 5e-6
num_train_epochs = 5
warmup_steps = 50
weight_decay = 0.02
```

### Model Architecture

- **Base Model:** facebook/wav2vec2-base-960h
- **Classification Head:** Added for 9-class classification
- **Parameters:** ~95M trainable parameters
- **Features:** Wav2Vec2 audio representations with a fine-tuned classification layer

## Technical Specifications

- **Audio Format:** WAV files
- **Sample Rate:** 16,000 Hz
- **Input Length:** 3 seconds (48,000 samples)
- **Model Framework:** PyTorch + Transformers
- **Inference Device:** GPU recommended (CUDA)

## Evaluation Metrics

The model is evaluated with the following metrics:

- **Accuracy:** Standard classification accuracy
- **ROC AUC:** Macro-averaged ROC AUC with a one-vs-rest approach
- **Multi-class Classification:** Softmax probabilities over all 9 instrument classes
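
A minimal sketch of how these metrics can be computed with scikit-learn on dummy data; `multi_class="ovr"` with `average="macro"` matches the macro-averaged one-vs-rest description, but the authors' exact evaluation code is not shown here:

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(9), 5)         # dummy labels, every class present
probs = rng.dirichlet(np.ones(9), size=45)  # dummy softmax outputs (rows sum to 1)

acc = accuracy_score(labels, probs.argmax(axis=1))
auc = roc_auc_score(labels, probs, multi_class="ovr", average="macro")
print(f"accuracy={acc:.4f}, roc_auc_macro={auc:.4f}")
```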

## Limitations and Considerations

1. **Audio Duration:** The model expects 3-second clips; longer audio is truncated, and shorter clips may degrade performance
2. **Single Instrument Focus:** Optimized for single-instrument classification; mixed instruments may produce uncertain results
3. **Audio Quality:** Performance depends on audio quality and recording conditions
4. **Sample Rate:** Input must be resampled to 16 kHz for optimal performance
5. **Domain Specificity:** Trained on specific instrument recordings; may not generalize to all variants or playing styles

## Training Environment

- **Platform:** Google Colab
- **GPU:** CUDA-enabled device
- **Libraries:**
  - transformers==4.28.1
  - torchaudio==0.12
  - datasets
  - evaluate
  - imblearn

## Model Files

The repository contains:
- Model weights and configuration
- Feature extractor configuration
- Training logs and metrics
- Label mappings (id2label, label2id)

---

*Model trained as part of a hackathon project*