Adjustment of DeepSpeech to train an 8 kHz model

I’d like to train an 8 kHz model directly on 8 kHz data and compare its performance to a model trained on 8 kHz data upsampled to 16 kHz. To achieve this, I need to allow DeepSpeech to use 8 kHz data directly for training (and avoid upsampling).

The modifications are shown below (from git diff):

diff --git a/examples/mic_vad_streaming/mic_vad_streaming.py b/examples/mic_vad_streaming/mic_vad_streaming.py
index 6e7f499..a1e1907 100755
--- a/examples/mic_vad_streaming/mic_vad_streaming.py
+++ b/examples/mic_vad_streaming/mic_vad_streaming.py
@@ -16,7 +16,7 @@ class Audio(object):
 
     FORMAT = pyaudio.paInt16
     # Network/VAD rate-space
-    RATE_PROCESS = 16000
+    RATE_PROCESS = 8000
     CHANNELS = 1
     BLOCKS_PER_SECOND = 50
 
@@ -66,7 +66,7 @@ class Audio(object):
         return resample16.tostring()
 
     def read_resampled(self):
-        """Return a block of audio data resampled to 16000hz, blocking if necessary."""
+        """Return a block of audio data resampled to 8000Hz, blocking if necessary."""
         return self.resample(data=self.buffer_queue.get(),
                              input_rate=self.input_rate)
 
@@ -189,7 +189,7 @@ def main(ARGS):
 
 if __name__ == '__main__':
     BEAM_WIDTH = 500
-    DEFAULT_SAMPLE_RATE = 16000
+    DEFAULT_SAMPLE_RATE = 8000
     LM_ALPHA = 0.75
     LM_BETA = 1.85
     N_FEATURES = 26
diff --git a/examples/vad_transcriber/wavTranscriber.py b/examples/vad_transcriber/wavTranscriber.py
index 2735879..7671121 100644
--- a/examples/vad_transcriber/wavTranscriber.py
+++ b/examples/vad_transcriber/wavTranscriber.py
@@ -46,7 +46,7 @@ Returns a list [Inference, Inference Time, Audio Length]
 '''
 def stt(ds, audio, fs):
     inference_time = 0.0
-    audio_length = len(audio) * (1 / 16000)
+    audio_length = len(audio) * (1 / 8000)
 
     # Run Deepspeech
     logging.debug('Running inference...')
@@ -94,7 +94,7 @@ Returns tuple of
 def vad_segment_generator(wavFile, aggressiveness):
     logging.debug("Caught the wav file @: %s" % (wavFile))
     audio, sample_rate, audio_length = wavSplit.read_wave(wavFile)
-    assert sample_rate == 16000, "Only 16000Hz input WAV files are supported for now!"
+    assert sample_rate == 8000, "Only 8000Hz input WAV files are supported for now!"
     vad = webrtcvad.Vad(int(aggressiveness))
     frames = wavSplit.frame_generator(30, audio, sample_rate)
     frames = list(frames)
diff --git a/examples/vad_transcriber/wavTranscription.md b/examples/vad_transcriber/wavTranscription.md
index 0d03c47..8f0a684 100644
--- a/examples/vad_transcriber/wavTranscription.md
+++ b/examples/vad_transcriber/wavTranscription.md
@@ -34,9 +34,9 @@ sample_rec.wav                 13.710               20.797               5.593
 
 ```
 
-**Note:** Only `wav` files with a 16kHz sample rate are supported for now, you can convert your files to the appropriate format with ffmpeg if available on your system.
+**Note:** Only `wav` files with a 8kHz sample rate are supported for now, you can convert your files to the appropriate format with ffmpeg if available on your system.
 
-    ffmpeg -i infile.mp3  -ar 16000 -ac 1  outfile.wav
+    ffmpeg -i infile.mp3  -ar 8000 -ac 1  outfile.wav
 
 ### 2. Minimalistic GUI
 
diff --git a/native_client/client.cc b/native_client/client.cc
index f1148eb..42f0abd 100644
--- a/native_client/client.cc
+++ b/native_client/client.cc
@@ -124,7 +124,7 @@ GetAudioBuffer(const char* path)
 
   // Resample/reformat the audio so we can pass it through the MFCC functions
   sox_signalinfo_t target_signal = {
-      16000, // Rate
+      8000, // Rate
       1, // Channels
       16, // Precision
       SOX_UNSPEC, // Length
@@ -163,8 +163,8 @@ GetAudioBuffer(const char* path)
 
   res.sample_rate = (int)output->signal.rate;
 
-  if ((int)input->signal.rate < 16000) {
-    fprintf(stderr, "Warning: original sample rate (%d) is lower than 16kHz. Up-sampling might produce erratic speech recognition.\n", (int)input->signal.rate);
+  if ((int)input->signal.rate < 8000) {
+    fprintf(stderr, "Warning: original sample rate (%d) is lower than 8kHz. Up-sampling might produce erratic speech recognition.\n", (int)input->signal.rate);
   }
 
   // Setup the effects chain to decode/resample
@@ -210,7 +210,7 @@ GetAudioBuffer(const char* path)
 #endif // NO_SOX
 
 #ifdef NO_SOX
-  // FIXME: Hack and support only 16kHz mono 16-bits PCM
+  // FIXME: Hack and support only 8kHz mono 16-bits PCM
   FILE* wave = fopen(path, "r");
 
   size_t rv;
@@ -229,7 +229,7 @@ GetAudioBuffer(const char* path)
 
   assert(audio_format == 1); // 1 is PCM
   assert(num_channels == 1); // MONO
-  assert(sample_rate == 16000); // 16000 Hz
+  assert(sample_rate == 8000); // 8000 Hz
   assert(bits_per_sample == 16); // 16 bits per sample
 
   fprintf(stderr, "audio_format=%d\n", audio_format);
diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index 74d83e5..e3fd490 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -39,7 +39,7 @@
 //TODO: infer batch size from model/use dynamic batch size
 constexpr unsigned int BATCH_SIZE = 1;
 
-constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
+constexpr unsigned int DEFAULT_SAMPLE_RATE = 8000;
 constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
 constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
 
diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py
index f4923f8..edd2a89 100644
--- a/native_client/python/__init__.py
+++ b/native_client/python/__init__.py
@@ -37,7 +37,7 @@ class Model(object):
     def sttWithMetadata(self, *args, **kwargs):
         return deepspeech.impl.SpeechToTextWithMetadata(self._impl, *args, **kwargs)
 
-    def setupStream(self, pre_alloc_frames=150, sample_rate=16000):
+    def setupStream(self, pre_alloc_frames=150, sample_rate=8000):
         status, ctx = deepspeech.impl.SetupStream(self._impl,
                                                   aPreAllocFrames=pre_alloc_frames,
                                                   aSampleRate=sample_rate)
diff --git a/native_client/python/client.py b/native_client/python/client.py
index d52fa6b..dbeabcb 100644
--- a/native_client/python/client.py
+++ b/native_client/python/client.py
@@ -41,15 +41,15 @@ N_CONTEXT = 9
 
 
 def convert_samplerate(audio_path):
-    sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate 16000 --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path))
+    sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate 8000 --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path))
     try:
         output = subprocess.check_output(shlex.split(sox_cmd), stderr=subprocess.PIPE)
     except subprocess.CalledProcessError as e:
         raise RuntimeError('SoX returned non-zero status: {}'.format(e.stderr))
     except OSError as e:
-        raise OSError(e.errno, 'SoX not found, use 16kHz files or install it: {}'.format(e.strerror))
+        raise OSError(e.errno, 'SoX not found, use 8kHz files or install it: {}'.format(e.strerror))
 
-    return 16000, np.frombuffer(output, np.int16)
+    return 8000, np.frombuffer(output, np.int16)
 
 
 def metadata_to_string(metadata):
@@ -98,13 +98,13 @@ def main():
 
     fin = wave.open(args.audio, 'rb')
     fs = fin.getframerate()
-    if fs != 16000:
-        print('Warning: original sample rate ({}) is different than 16kHz. Resampling might produce erratic speech recognition.'.format(fs), file=sys.stderr)
+    if fs != 8000:
+        print('Warning: original sample rate ({}) is different than 8kHz. Resampling might produce erratic speech recognition.'.format(fs), file=sys.stderr)
         fs, audio = convert_samplerate(args.audio)
     else:
         audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
 
-    audio_length = fin.getnframes() * (1/16000)
+    audio_length = fin.getnframes() * (1/8000)
     fin.close()
 
     print('Running inference.', file=sys.stderr)
diff --git a/stats.py b/stats.py
index 466b78e..798b4a8 100644
--- a/stats.py
+++ b/stats.py
@@ -9,7 +9,7 @@ def main():
     parser = argparse.ArgumentParser()
 
     parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True)
-    parser.add_argument("--sample-rate", type=int, default=16000, required=False, help="Audio sample rate")
+    parser.add_argument("--sample-rate", type=int, default=8000, required=False, help="Audio sample rate")
     parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
     parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
     args = parser.parse_args()
diff --git a/util/config.py b/util/config.py
index 6204b99..0f84df7 100644
--- a/util/config.py
+++ b/util/config.py
@@ -70,7 +70,7 @@ def initialize_globals():
     # doc/Geometry.md
 
     # Number of MFCC features
-    c.n_input = 26 # TODO: Determine this programmatically from the sample rate
+    c.n_input = 13 # TODO: Determine this programmatically from the sample rate
 
     # The number of frames in the context
     c.n_context = 9 # TODO: Determine the optimal value using a validation data set
diff --git a/util/flags.py b/util/flags.py
index 3f634f4..bc000bb 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -22,7 +22,7 @@ def create_flags():
 
     f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
     f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
-    f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model')
+    f.DEFINE_integer('audio_sample_rate', 8000, 'sample rate value expected by model')
 
     # Global Constants
     # ================

Is there anything else I need to do to train directly on 8kHz data? Do I need to recompile the binaries if I call training like this?

python3 DeepSpeech.py <parameters>

I have changed the n_input parameter to 13 as described in doc/Geometry.rst. Is there anything else I definitely need to take into account in terms of geometry?
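
For reference, this is how I sanity-check the sample-rate-dependent geometry (a quick Python sketch based on the constants in the diff above; the 13 MFCC features at 8 kHz are my assumption from doc/Geometry.rst, not a verified value):

# Sample-rate-dependent geometry, mirroring the constants changed above
# (util/flags.py and native_client/deepspeech.cc).
sample_rate = 8000          # --audio_sample_rate
feature_win_len_ms = 32     # --feature_win_len
feature_win_step_ms = 20    # --feature_win_step

# Window sizes in samples, as in DEFAULT_WINDOW_LENGTH / DEFAULT_WINDOW_STEP
window_length = int(sample_rate * feature_win_len_ms / 1000)  # 256 samples
window_step = int(sample_rate * feature_win_step_ms / 1000)   # 160 samples

n_input = 13  # my assumption: half the 26 MFCC features used at 16 kHz
print(window_length, window_step, n_input)  # 256 160 13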

Thanks!

For training, you don’t need the changes under native_client/. You do need them afterwards, for running inference.

That seems complete enough, but I don’t recall whether there are other variables that are sample-rate dependent.

You have 8 kHz 16-bit PCM?

Thanks!

Yes, exactly, the format is as follows:
RIFF (little-endian) data, WAVE audio, Microsoft PCM, 16 bit, mono 8000 Hz

I use sox for the conversion.

Conversion from what to what?

It should be possible to train with 8kHz, but there might be some hacking required, as we don’t yet really support that.

My original input data is in MP3 format, so I need to run:

sox <input file> --bits 16 --channels 1 --rate 8000 --encoding signed-integer --endian little --compression 0.0 --no-dither <output file>

in order to get the PCM files. Here I am following the command from client.py, changing only the sampling rate parameter.
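
To double-check the converted files, I use a small sketch with Python’s standard wave module (the file name is just a placeholder):

import wave

# Verify that the converted file matches what the modified DeepSpeech expects
with wave.open('outfile.wav', 'rb') as f:
    assert f.getframerate() == 8000  # 8 kHz sample rate
    assert f.getnchannels() == 1     # mono
    assert f.getsampwidth() == 2     # 16-bit PCM (2 bytes per sample)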

Would you have a hint on how to find the places where hacking would be required? For now I have searched for all occurrences of 16000 and adjusted them accordingly.
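
This is roughly how I did the sweep (a crude sketch; it only catches the literal 16000 and misses values derived from the rate, such as window lengths in samples):

import pathlib

# Walk the source tree and report every line containing a hard-coded 16000
for path in pathlib.Path('.').rglob('*'):
    if not path.is_file() or path.suffix not in ('.py', '.cc', '.md'):
        continue
    for lineno, line in enumerate(path.read_text(errors='ignore').splitlines(), 1):
        if '16000' in line:
            print('{}:{}: {}'.format(path, lineno, line.strip()))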

MP3 at 8 kHz?

No, you need to check for yourself.

Yes, exactly. I want to target the transcription of voice over the telephone; the recordings are MP3 files.

The problem is that I only have 16 kHz PCM training data. So, to get similar training data, I first save this 16 kHz data as 8 kHz MP3 and then convert it to WAV (because we can only feed WAV data to DeepSpeech). This mirrors the inference process: I will run inference on MP3 audio, which needs to be converted to PCM files (preferably 8 kHz, to minimize the audio modification; this is what I am working on now).
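
In code, my preparation pipeline looks roughly like this (a sketch assuming sox is installed with MP3 support; the file names are placeholders):

import subprocess

def to_telephone_quality(wav_16k, mp3_8k, wav_8k):
    # Downsample to 8 kHz mono and encode as MP3 to mimic the telephone recordings
    subprocess.check_call(['sox', wav_16k, '--rate', '8000', '--channels', '1', mp3_8k])
    # Decode back to 8 kHz 16-bit mono PCM WAV, since DeepSpeech only takes WAV input
    subprocess.check_call(['sox', mp3_8k, '--bits', '16', '--channels', '1',
                           '--rate', '8000', '--encoding', 'signed-integer', wav_8k])

to_telephone_quality('clip_16k.wav', 'clip_8k.mp3', 'clip_8k.wav')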