利用 Android 手机和 YAMNet ML 模型进行声音分类(二)

文 / MLGDE George Soloupis

这是教程的第 2 部分,介绍了如何利用出色的 YAMNet 机器学习 模型将手机麦克风录制的声音分类为 500 多种类别(第 1 部分)。

我们已在上一部分中解释过模型架构,对模型进行了基准测试与格式转换,最后得到一个 tflite 文件,该文件可在 TensorFlow Hub 中下载并在手机内使用。该模型文件没有 元数据,因此应用只能使用解释器进行推断。

步骤如下:

  1. 将手机麦克风录下的声音转换为浮点数组。
  2. 然后将数组传递至解释器,该解释器使用的是储存在 assets 文件夹中的模型文件。
  3. 解释器将生成 3 种输出。得分、嵌入向量和声谱图。
  4. 输出的得分可用于获取推断的前 K 个类。
  5. 位于前列的类将显示在屏幕上。

您可以在该 GitHub 代码库 中找到相应的代码。以及模型和可执行 Colab Notebook 的相关信息,在该笔记本中,您可以借助 Tensorflow 和 Tensorflow Lite 解释器,使用 .wav 文件运行推断。

步入正题

收集声音的过程很简单。您可以使用 AudioRecord 类开始录制声音并生成 ByteArrayOutputStream 和 ArrayList。

fun startRecording() {
        mRecorder = AudioRecord.Builder().setAudioSource(AUDIO_SOURCE)
            .setAudioFormat(AUDIO_FORMAT)
            .setBufferSizeInBytes(BUFFER_SIZE)
            .build()

        done = false
        mRecording = true
        mRecorder?.startRecording()
        mThread = Thread(readAudio)
        mThread!!.start()
}

请注意我们提高麦克风增益的方式:

private val readAudio = Runnable {
        var readBytes: Int
        buffer = ShortArray(BUFFER_SIZE)
        while (mRecording) {
            readBytes = mRecorder!!.read(buffer, 0, BUFFER_SIZE)

            //Higher volume of microphone
            //https://stackoverflow.com/questions/25441166/how-to-adjust-microphone-sensitivity-while-recording-audio-in-android
            if (readBytes > 0) {
                for (i in 0 until readBytes) {
                    buffer[i] = Math.min(
                        (buffer[i] * 6.7).toInt(),
                        Short.MAX_VALUE.toInt()
                    ).toShort()
                }
            }
            if (readBytes != AudioRecord.ERROR_INVALID_OPERATION) {
                for (s in buffer) {

                    // Add all values to arraylist
                    bufferForInference.add(s)

                    writeShort(mPcmStream, s)
                }
            }
        }
}

您可以在 ListeningRecorder 类中找到所有声音实现的合集。

必须将输入标准化为 -1 和 1 之间的浮点数。为进行标准化,我们只需要将所有值除以 2**16,而我们的代码中则为除以 32768(对于 16 位整数而言,其范围为 -32k 到 +32k)。

val floatsForInference = FloatArray(arrayListShorts.size)
for ((index, value) in arrayListShorts.withIndex()) {
    floatsForInference[index] = (value / 32768F)
}

得到的 FloatArray 将传递至 YamnetModelExecutor 类,而推断会在 execute 函数中完成。

fun execute(floatsInput: FloatArray): Pair<ArrayList<String>, ArrayList<Float>> {

        predictTime = System.currentTimeMillis()
        val inputSize = floatsInput.size // ~2 seconds of sound
        //Log.i("YAMNET_INPUT_SIZE", inputSize.toString())

        val inputValues = floatsInput//FloatArray(inputSize)

        val inputs = arrayOf<Any>(inputValues)
        val outputs = HashMap<Int, Any>()

        // Outputs of yamnet model with tflite and for 2 seconds .wav file
        // scores(4, 521) emmbedings(4, 1024) spectogram(240, 64)
        val arrayScores = Array(4) { FloatArray(521) { 0f } }
        val arrayEmbeddings = Array(4) { FloatArray(1024) { 0f } }
        val arraySpectograms = Array(240) { FloatArray(64) { 0f } }

        outputs[0] = arrayScores
        outputs[1] = arrayEmbeddings
        outputs[2] = arraySpectograms

        try {
            interpreter.runForMultipleInputsOutputs(inputs, outputs)
        } catch (e: Exception) {
            Log.e("EXCEPTION", e.toString())
        }

        ..............................
        ..............................

        return Pair(finalListOfOutputs, listOfMaximumValues) // ArrayList<String>
    }

将手机调整为重复收集时长为 2 秒的声音。在 YAMNet 模型的第一部分中,模型会将输入特征分帧成具有 50% 重叠且长度为 0.96 秒的示例。因此在我们的示例模型中输出了 4 个得分数组。然后求坐标轴 0 处的平均值。

val arrayMeanScores = FloatArray(521) { 0f }
        for (i in 0 until 521) {
            // 求这 4 个数组在坐标轴 0 处的平均值
            arrayMeanScores[i] = arrayListOf(
                arrayScores[0][i],
                arrayScores[1][i],
                arrayScores[2][i],
                arrayScores[3][i]
            ).average().toFloat()
        }

模型类以 .txt 文件的形式提供,并存储在 assets 文件夹中。您可以使用 TensorFlow 支持库 轻松将其转换为数组列表。

将下方代码添加至 build.gradle 文件:

implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly-SNAPSHOT')

获取类的数组列表:

val labels = FileUtil.loadLabels(context, "classes.txt")

打印文本文件的第一个值:

Speech
Child speech, kid speaking
Conversation
Narration, monologue
Babbling
Speech synthesizer
Shout
Bellow
Whoop
Yell
Children shouting
Screaming
Whispering
Laughter
Baby laughter
Giggle
Snicker
Belly laugh
………………………

得到 521 个类平均分数的数组列表后,我们找出了前 10 个类及其标签:

fun execute(floatsInput: FloatArray): Pair<ArrayList<String>, ArrayList<Float>> {

        .................................
        .................................
        
        val listOfArrayMeanScores = arrayMeanScores.toCollection(ArrayList())

        val listOfMaximumValues = arrayListOf<Float>()
        for (i in 0 until 10) {
            val number = listOfArrayMeanScores.max() ?: 0f
            listOfMaximumValues.add(number)
            listOfArrayMeanScores.remove(number)
        }

        val listOfMaxIndices = arrayListOf<Int>()
        for (i in 0 until 10) {
            for (k in arrayMeanScores.indices) {
                if (listOfMaximumValues[i] == arrayMeanScores[k]) {
                    listOfMaxIndices.add(k)
                }
            }

        }
        val finalListOfOutputs = arrayListOf<String>()
        for (i in listOfMaxIndices.indices) {
            finalListOfOutputs.add(labels.get(listOfMaxIndices.get(i)))
        }

        return Pair(finalListOfOutputs, listOfMaximumValues) // ArrayList<String>
    }

所以最后我们得到了前 10 个类对应的概率和名称。随后这些值将被传递至主界面,并在屏幕上显示(由于限制,屏幕将只能显示 5 个类):

您可以在 视频 中了解如何使用本应用,访问 项目链接 了解了解详情。

该项目用 Kotlin 语言编写,并且:

  • 使用 TensorFlow 支持库
  • 使用 TensorFlow Lite 解释器及:
  • 数据绑定
  • 支持协程的 MVVM
  • Koin DI

未来计划

  • 对音频录制进行调优,让录制时长小于 2 秒,并寻找能够获得具有较高准确性且能取得更快结果的最佳时长。
  • 向 tflite 文件添加元数据 (Metadata),以结合 ML 绑定一同使用

至此,本教程结束。希望您阅读愉快,并使用 TensorFlow Lite 将所学知识运用到实际应用中。访问 TensorFlow Hub,获取多样的模型文件!了解更多信息或者想要 贡献内容

感谢 Sayak Paul 和 Le Viet Gia Khanh 的评审和支持。

关于作者

George Soloupis,我从一名药剂师转行成为了 Android 开发工程师。目前积极活跃于 Google 的移动操作系统 TensorFlow Lite 机器学习工作组。

原文:Classification of sounds using android mobile phone and the YAMNet ML model
中文:TensorFlow 公众号