”工欲善其事,必先利其器。“—孔子《论语.录灵公》
首页 > 编程 > Tensorflow 音乐预测

Tensorflow 音乐预测

发布于2024-11-08
浏览:592

Tensorflow music prediction

在本文中,我将展示如何使用张量流来预测音乐风格。
在我的示例中,我比较了电子音乐和古典音乐。

您可以在我的github上找到代码:
https://github.com/victordalet/sound_to_partition


I - 数据集

第一步,您需要创建一个数据集文件夹,并在里面添加一个音乐风格文件夹,例如我添加一个 techno 文件夹和 classic 文件夹,其中放置我的 wav 歌曲。

II - 火车

我创建一个训练文件,参数 max_epochs 需要完成。

修改构造函数中与数据集文件夹中的目录对应的类。

在加载和处理方法中,我从不同的目录检索wav文件并获取频谱图。

出于训练目的,我使用 Keras 卷积和模型。

import os
import sys
from typing import List

import librosa
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.image import resize



class Train:

    def __init__(self):
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.data_dir: str = 'dataset'
        self.classes: List[str] = ['techno','classic']
        self.max_epochs: int = int(sys.argv[1])

    @staticmethod
    def load_and_preprocess_data(data_dir, classes, target_shape=(128, 128)):
        data = []
        labels = []

        for i, class_name in enumerate(classes):
            class_dir = os.path.join(data_dir, class_name)
            for filename in os.listdir(class_dir):
                if filename.endswith('.wav'):
                    file_path = os.path.join(class_dir, filename)
                    audio_data, sample_rate = librosa.load(file_path, sr=None)
                    mel_spectrogram = librosa.feature.melspectrogram(y=audio_data, sr=sample_rate)
                    mel_spectrogram = resize(np.expand_dims(mel_spectrogram, axis=-1), target_shape)
                    data.append(mel_spectrogram)
                    labels.append(i)

        return np.array(data), np.array(labels)

    def create_model(self):
        data, labels = self.load_and_preprocess_data(self.data_dir, self.classes)
        labels = to_categorical(labels, num_classes=len(self.classes))  # Convert labels to one-hot encoding
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(data, labels, test_size=0.2,
                                                                                random_state=42)

        input_shape = self.X_train[0].shape
        input_layer = Input(shape=input_shape)
        x = Conv2D(32, (3, 3), activation='relu')(input_layer)
        x = MaxPooling2D((2, 2))(x)
        x = Conv2D(64, (3, 3), activation='relu')(x)
        x = MaxPooling2D((2, 2))(x)
        x = Flatten()(x)
        x = Dense(64, activation='relu')(x)
        output_layer = Dense(len(self.classes), activation='softmax')(x)
        self.model = Model(input_layer, output_layer)

        self.model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

    def train_model(self):
        self.model.fit(self.X_train, self.y_train, epochs=self.max_epochs, batch_size=32,
                       validation_data=(self.X_test, self.y_test))
        test_accuracy = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        print(test_accuracy[1])

    def save_model(self):
        self.model.save('weight.h5')


if __name__ == '__main__':
    train = Train()
    train.create_model()
    train.train_model()
    train.save_model()

III-测试

为了测试和使用该模型,我创建了此类来检索权重并预测音乐的风格。

不要忘记将正确的类添加到构造函数中。

from typing import List

import librosa
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.image import resize
import tensorflow as tf



class Test:

    def __init__(self, audio_file_path: str):
        self.model = load_model('weight.h5')
        self.target_shape = (128, 128)
        self.classes: List[str] = ['techno','classic']
        self.audio_file_path: str = audio_file_path

    def test_audio(self, file_path, model):
        audio_data, sample_rate = librosa.load(file_path, sr=None)
        mel_spectrogram = librosa.feature.melspectrogram(y=audio_data, sr=sample_rate)
        mel_spectrogram = resize(np.expand_dims(mel_spectrogram, axis=-1), self.target_shape)
        mel_spectrogram = tf.reshape(mel_spectrogram, (1,)   self.target_shape   (1,))

        predictions = model.predict(mel_spectrogram)

        class_probabilities = predictions[0]

        predicted_class_index = np.argmax(class_probabilities)

        return class_probabilities, predicted_class_index

    def test(self):
        class_probabilities, predicted_class_index = self.test_audio(self.audio_file_path, self.model)

        for i, class_label in enumerate(self.classes):
            probability = class_probabilities[i]
            print(f'Class: {class_label}, Probability: {probability:.4f}')

        predicted_class = self.classes[predicted_class_index]
        accuracy = class_probabilities[predicted_class_index]
        print(f'The audio is classified as: {predicted_class}')
        print(f'Accuracy: {accuracy:.4f}')
版本声明 本文转载于:https://dev.to/victordalet/tensorflow-music-prediction-4i6f?1如有侵犯,请联系[email protected]删除
最新教程 更多>
  • 为什么Microsoft Visual C ++无法正确实现两台模板的实例?
    为什么Microsoft Visual C ++无法正确实现两台模板的实例?
    [2明确担心Microsoft Visual C(MSVC)在正确实现两相模板实例化方面努力努力。该机制的哪些具体方面无法按预期运行?背景:说明:的初始Syntax检查在范围中受到限制。它未能检查是否存在声明名称的存在,导致名称缺乏正确的声明时会导致编译问题。为了说明这一点,请考虑以下示例:一个符合...
    编程 发布于2025-02-19
  • 如何可靠地检查MySQL表中的列存在?
    如何可靠地检查MySQL表中的列存在?
    在mySQL中确定列中的列存在,验证表中的列存在与与之相比有点困惑其他数据库系统。常用的方法:如果存在(从信息_schema.columns select * * where table_name ='prefix_topic'和column_name =&...
    编程 发布于2025-02-19
  • Java是否允许多种返回类型:仔细研究通用方法?
    Java是否允许多种返回类型:仔细研究通用方法?
    在java中的多个返回类型:一个误解介绍,其中foo是自定义类。该方法声明似乎拥有两种返回类型:列表和E。但是,情况确实如此吗?通用方法:拆开神秘 [方法仅具有单一的返回类型。相反,它采用机制,如钻石符号“ ”。分解方法签名: :本节定义了一个通用类型参数,E。它表示该方法接受扩展FOO类的任何...
    编程 发布于2025-02-19
  • 为什么使用固定定位时,为什么具有100%网格板柱的网格超越身体?
    为什么使用固定定位时,为什么具有100%网格板柱的网格超越身体?
    网格超过身体,用100%grid-template-columns 问题:考虑以下CSS和HTML: position:fixed; grid-template-columns:40%60%; grid-gap:5px; 背景:#eee; 当位置未固定时,网格将正确显示。但是,当...
    编程 发布于2025-02-19
  • 在没有密码提示的情况下,如何在Ubuntu上安装MySQL?
    在没有密码提示的情况下,如何在Ubuntu上安装MySQL?
    在ubuntu 使用debconf-set-selections 在安装过程中避免密码提示mysql root用户。这需要以下步骤: sudo debconf-set-selections
    编程 发布于2025-02-19
  • 如何使用组在MySQL中旋转数据?
    如何使用组在MySQL中旋转数据?
    在关系数据库中使用mysql组使用mysql组来调整查询结果。在这里,我们面对一个共同的挑战:使用组的组将数据从基于行的基于列的基于列的转换。通过子句以及条件汇总函数,例如总和或情况。让我们考虑以下查询: select d.data_timestamp, sum(data_id = 1 tata...
    编程 发布于2025-02-19
  • \“(1)vs.(;;):编译器优化是否消除了性能差异?\”
    \“(1)vs.(;;):编译器优化是否消除了性能差异?\”
    答案:在大多数现代编译器中,while(1)和(1)和(;;)之间没有性能差异。 说明: perl: S-> 7 8 unstack v-> 4 -e语法ok 在GCC中,两者都循环到相同的汇编代码中,如下所示:。 globl t_时 t_时: .l2: movl $ .lc0,�i ...
    编程 发布于2025-02-19
  • 为什么箭头函数在IE11中引起语法错误?如何修复它们?
    为什么箭头函数在IE11中引起语法错误?如何修复它们?
    为什么arrow functions在IE 11 中引起语法错误。 IE 11不支持箭头函数,导致语法错误。这使用传统函数语法来定义与原始箭头函数相同的逻辑。 IE 11现在将正确识别并执行代码。
    编程 发布于2025-02-19
  • 'exec()
    'exec()
    Exec对本地变量的影响: exec function,python staple,用于动态代码执行的python staple,提出一个有趣的Query:它可以在函数中更新局部变量吗?在Python 3中,以下代码代码无法更新本地变量,如人们所期望的:代替预期的'3',它令人震...
    编程 发布于2025-02-19
  • 如何从Python中的字符串中删除表情符号:固定常见错误的初学者指南?
    如何从Python中的字符串中删除表情符号:固定常见错误的初学者指南?
    从python 导入编解码器 导入 text = codecs.decode('这狗\ u0001f602'.encode('utf-8'),'utf-8') 印刷(文字)#带有表情符号 emoji_pattern = re.compile(“ [”...
    编程 发布于2025-02-19
  • 如何检查对象是否具有Python中的特定属性?
    如何检查对象是否具有Python中的特定属性?
    方法来确定对象属性存在寻求一种方法来验证对象中特定属性的存在。考虑以下示例,其中尝试访问不确定属性会引起错误: >>> a = someClass() >>> A.property Trackback(最近的最新电话): 文件“ ”,第1行, AttributeError:SomeClass实...
    编程 发布于2025-02-19
  • 如何在JavaScript对象中动态设置键?
    如何在JavaScript对象中动态设置键?
    如何为JavaScript对象变量创建动态键,尝试为JavaScript对象创建动态键,使用此Syntax jsObj['key' i] = 'example' 1;将不起作用。正确的方法采用方括号:他们维持一个长度属性,该属性反映了数字属性(索引)和一个数字属性的数量。标准对象没有模仿这...
    编程 发布于2025-02-19
  • 版本5.6.5之前,使用current_timestamp与时间戳列的current_timestamp与时间戳列有什么限制?
    版本5.6.5之前,使用current_timestamp与时间戳列的current_timestamp与时间戳列有什么限制?
    在默认值中使用current_timestamp或mysql版本中的current_timestamp或在5.6.5 这种限制源于遗产实现的关注,这些限制需要为Current_timestamp功能提供特定的实现。消息和相关问题 current_timestamp值: 创建表`foo`( `...
    编程 发布于2025-02-19
  • 如何使用PHP从XML文件中有效地检索属性值?
    如何使用PHP从XML文件中有效地检索属性值?
    从php 您的目标可能是检索“ varnum”属性值,其中提取数据的传统方法可能会使您留下PHP陷入困境。使用simplexmlelement :: attributes()函数提供了简单的解决方案。此函数可访问对XML元素作为关联数组的属性: - > attributes()为$ attr...
    编程 发布于2025-02-19
  • 哪种方法更有效地用于点 - 填点检测:射线跟踪或matplotlib \的路径contains_points?
    哪种方法更有效地用于点 - 填点检测:射线跟踪或matplotlib \的路径contains_points?
    在Python 射线tracing方法 matplotlib路径对象表示多边形。它检查给定点是否位于定义路径内。 This function is often faster than the ray tracing approach, as seen in the code snippet pr...
    编程 发布于2025-02-19

免责声明: 提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发到邮箱:[email protected] 我们会第一时间内为您处理。

Copyright© 2022 湘ICP备2022001581号-3