I’d like to train a model directly on 8kHz data and compare its performance to a model trained on the same 8kHz data upsampled to 16kHz. To achieve this I need to allow DeepSpeech to consume 8kHz data directly for training (and avoid upsampling).
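For the upsampled baseline I plan to generate 16kHz copies of the 8kHz recordings roughly like this (a minimal sketch that shells out to sox, the same tool native_client/python/client.py uses; the directory layout is a placeholder):

```python
import subprocess
from pathlib import Path

def upsample_to_16k(src_dir, dst_dir):
    """Write a 16kHz mono copy of every WAV under src_dir into dst_dir via sox."""
    Path(dst_dir).mkdir(parents=True, exist_ok=True)
    for wav in Path(src_dir).glob('*.wav'):
        out = Path(dst_dir) / wav.name
        # --rate/--channels are output format options, as in client.py's sox_cmd
        subprocess.check_call(['sox', str(wav),
                               '--rate', '16000', '--channels', '1', str(out)])

upsample_to_16k('data/8khz', 'data/16khz_upsampled')
```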
My modifications are below (from git diff):
diff --git a/examples/mic_vad_streaming/mic_vad_streaming.py b/examples/mic_vad_streaming/mic_vad_streaming.py
index 6e7f499..a1e1907 100755
--- a/examples/mic_vad_streaming/mic_vad_streaming.py
+++ b/examples/mic_vad_streaming/mic_vad_streaming.py
@@ -16,7 +16,7 @@ class Audio(object):
FORMAT = pyaudio.paInt16
# Network/VAD rate-space
- RATE_PROCESS = 16000
+ RATE_PROCESS = 8000
CHANNELS = 1
BLOCKS_PER_SECOND = 50
@@ -66,7 +66,7 @@ class Audio(object):
return resample16.tostring()
def read_resampled(self):
- """Return a block of audio data resampled to 16000hz, blocking if necessary."""
+ """Return a block of audio data resampled to 8000Hz, blocking if necessary."""
return self.resample(data=self.buffer_queue.get(),
input_rate=self.input_rate)
@@ -189,7 +189,7 @@ def main(ARGS):
if __name__ == '__main__':
BEAM_WIDTH = 500
- DEFAULT_SAMPLE_RATE = 16000
+ DEFAULT_SAMPLE_RATE = 8000
LM_ALPHA = 0.75
LM_BETA = 1.85
N_FEATURES = 26
diff --git a/examples/vad_transcriber/wavTranscriber.py b/examples/vad_transcriber/wavTranscriber.py
index 2735879..7671121 100644
--- a/examples/vad_transcriber/wavTranscriber.py
+++ b/examples/vad_transcriber/wavTranscriber.py
@@ -46,7 +46,7 @@ Returns a list [Inference, Inference Time, Audio Length]
'''
def stt(ds, audio, fs):
inference_time = 0.0
- audio_length = len(audio) * (1 / 16000)
+ audio_length = len(audio) * (1 / 8000)
# Run Deepspeech
logging.debug('Running inference...')
@@ -94,7 +94,7 @@ Returns tuple of
def vad_segment_generator(wavFile, aggressiveness):
logging.debug("Caught the wav file @: %s" % (wavFile))
audio, sample_rate, audio_length = wavSplit.read_wave(wavFile)
- assert sample_rate == 16000, "Only 16000Hz input WAV files are supported for now!"
+ assert sample_rate == 8000, "Only 8000Hz input WAV files are supported for now!"
vad = webrtcvad.Vad(int(aggressiveness))
frames = wavSplit.frame_generator(30, audio, sample_rate)
frames = list(frames)
diff --git a/examples/vad_transcriber/wavTranscription.md b/examples/vad_transcriber/wavTranscription.md
index 0d03c47..8f0a684 100644
--- a/examples/vad_transcriber/wavTranscription.md
+++ b/examples/vad_transcriber/wavTranscription.md
@@ -34,9 +34,9 @@ sample_rec.wav 13.710 20.797 5.593
```
-**Note:** Only `wav` files with a 16kHz sample rate are supported for now, you can convert your files to the appropriate format with ffmpeg if available on your system.
+**Note:** Only `wav` files with an 8kHz sample rate are supported for now, you can convert your files to the appropriate format with ffmpeg if available on your system.
- ffmpeg -i infile.mp3 -ar 16000 -ac 1 outfile.wav
+ ffmpeg -i infile.mp3 -ar 8000 -ac 1 outfile.wav
### 2. Minimalistic GUI
diff --git a/native_client/client.cc b/native_client/client.cc
index f1148eb..42f0abd 100644
--- a/native_client/client.cc
+++ b/native_client/client.cc
@@ -124,7 +124,7 @@ GetAudioBuffer(const char* path)
// Resample/reformat the audio so we can pass it through the MFCC functions
sox_signalinfo_t target_signal = {
- 16000, // Rate
+ 8000, // Rate
1, // Channels
16, // Precision
SOX_UNSPEC, // Length
@@ -163,8 +163,8 @@ GetAudioBuffer(const char* path)
res.sample_rate = (int)output->signal.rate;
- if ((int)input->signal.rate < 16000) {
- fprintf(stderr, "Warning: original sample rate (%d) is lower than 16kHz. Up-sampling might produce erratic speech recognition.\n", (int)input->signal.rate);
+ if ((int)input->signal.rate < 8000) {
+ fprintf(stderr, "Warning: original sample rate (%d) is lower than 8kHz. Up-sampling might produce erratic speech recognition.\n", (int)input->signal.rate);
}
// Setup the effects chain to decode/resample
@@ -210,7 +210,7 @@ GetAudioBuffer(const char* path)
#endif // NO_SOX
#ifdef NO_SOX
- // FIXME: Hack and support only 16kHz mono 16-bits PCM
+ // FIXME: Hack and support only 8kHz mono 16-bits PCM
FILE* wave = fopen(path, "r");
size_t rv;
@@ -229,7 +229,7 @@ GetAudioBuffer(const char* path)
assert(audio_format == 1); // 1 is PCM
assert(num_channels == 1); // MONO
- assert(sample_rate == 16000); // 16000 Hz
+ assert(sample_rate == 8000); // 8000 Hz
assert(bits_per_sample == 16); // 16 bits per sample
fprintf(stderr, "audio_format=%d\n", audio_format);
diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index 74d83e5..e3fd490 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -39,7 +39,7 @@
//TODO: infer batch size from model/use dynamic batch size
constexpr unsigned int BATCH_SIZE = 1;
-constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
+constexpr unsigned int DEFAULT_SAMPLE_RATE = 8000;
constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py
index f4923f8..edd2a89 100644
--- a/native_client/python/__init__.py
+++ b/native_client/python/__init__.py
@@ -37,7 +37,7 @@ class Model(object):
def sttWithMetadata(self, *args, **kwargs):
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, *args, **kwargs)
- def setupStream(self, pre_alloc_frames=150, sample_rate=16000):
+ def setupStream(self, pre_alloc_frames=150, sample_rate=8000):
status, ctx = deepspeech.impl.SetupStream(self._impl,
aPreAllocFrames=pre_alloc_frames,
aSampleRate=sample_rate)
diff --git a/native_client/python/client.py b/native_client/python/client.py
index d52fa6b..dbeabcb 100644
--- a/native_client/python/client.py
+++ b/native_client/python/client.py
@@ -41,15 +41,15 @@ N_CONTEXT = 9
def convert_samplerate(audio_path):
- sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate 16000 --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path))
+ sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate 8000 --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path))
try:
output = subprocess.check_output(shlex.split(sox_cmd), stderr=subprocess.PIPE)
except subprocess.CalledProcessError as e:
raise RuntimeError('SoX returned non-zero status: {}'.format(e.stderr))
except OSError as e:
- raise OSError(e.errno, 'SoX not found, use 16kHz files or install it: {}'.format(e.strerror))
+ raise OSError(e.errno, 'SoX not found, use 8kHz files or install it: {}'.format(e.strerror))
- return 16000, np.frombuffer(output, np.int16)
+ return 8000, np.frombuffer(output, np.int16)
def metadata_to_string(metadata):
@@ -98,13 +98,13 @@ def main():
fin = wave.open(args.audio, 'rb')
fs = fin.getframerate()
- if fs != 16000:
- print('Warning: original sample rate ({}) is different than 16kHz. Resampling might produce erratic speech recognition.'.format(fs), file=sys.stderr)
+ if fs != 8000:
+ print('Warning: original sample rate ({}) is different than 8kHz. Resampling might produce erratic speech recognition.'.format(fs), file=sys.stderr)
fs, audio = convert_samplerate(args.audio)
else:
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
- audio_length = fin.getnframes() * (1/16000)
+ audio_length = fin.getnframes() * (1/8000)
fin.close()
print('Running inference.', file=sys.stderr)
diff --git a/stats.py b/stats.py
index 466b78e..798b4a8 100644
--- a/stats.py
+++ b/stats.py
@@ -9,7 +9,7 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True)
- parser.add_argument("--sample-rate", type=int, default=16000, required=False, help="Audio sample rate")
+ parser.add_argument("--sample-rate", type=int, default=8000, required=False, help="Audio sample rate")
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
args = parser.parse_args()
diff --git a/util/config.py b/util/config.py
index 6204b99..0f84df7 100644
--- a/util/config.py
+++ b/util/config.py
@@ -70,7 +70,7 @@ def initialize_globals():
# doc/Geometry.md
# Number of MFCC features
- c.n_input = 26 # TODO: Determine this programmatically from the sample rate
+ c.n_input = 13 # TODO: Determine this programmatically from the sample rate
# The number of frames in the context
c.n_context = 9 # TODO: Determine the optimal value using a validation data set
diff --git a/util/flags.py b/util/flags.py
index 3f634f4..bc000bb 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -22,7 +22,7 @@ def create_flags():
f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
- f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model')
+ f.DEFINE_integer('audio_sample_rate', 8000, 'sample rate value expected by model')
# Global Constants
# ================
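One sanity check on the deepspeech.cc change above: DEFAULT_WINDOW_LENGTH and DEFAULT_WINDOW_STEP are derived from DEFAULT_SAMPLE_RATE, so the window sizes halve at 8kHz. My own arithmetic, as a quick check that both still come out as whole sample counts:

```python
# Window constants derived as in native_client/deepspeech.cc
for rate in (16000, 8000):
    window_length = int(rate * 0.032)  # DEFAULT_WINDOW_LENGTH (samples)
    window_step = int(rate * 0.02)     # DEFAULT_WINDOW_STEP (samples)
    print(rate, window_length, window_step)
# 16000 -> 512-sample windows, 320-sample steps
#  8000 -> 256-sample windows, 160-sample steps
```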
Is there anything else I need to do to train directly on 8kHz data? Do I need to recompile the binaries if I call training like this?
python3 DeepSpeech.py <parameters>
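For reference, a fuller but hypothetical version of that call, with placeholder CSV paths; --audio_sample_rate 8000 should be redundant given the changed default in util/flags.py, but I’d pass it explicitly to make the intent obvious:

    python3 DeepSpeech.py --train_files train.csv --dev_files dev.csv --test_files test.csv --audio_sample_rate 8000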
I have changed the n_input parameter to 13 as described in doc/Geometry.rst. Is there anything else I definitely need to take into account in terms of geometry?
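As an offline cross-check of that geometry, I’m verifying that the feature matrix width matches n_input. This sketch uses the python_speech_features package purely as a reference implementation (not necessarily what the training pipeline itself computes) and a hypothetical 8kHz test file:

```python
import wave
import numpy as np
from python_speech_features import mfcc

with wave.open('sample_8khz.wav', 'rb') as fin:  # hypothetical test file
    rate = fin.getframerate()
    signal = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

assert rate == 8000
# winlen/winstep mirror --feature_win_len 32 ms and --feature_win_step 20 ms;
# numcep is what n_input in util/config.py has to agree with
features = mfcc(signal, samplerate=rate, winlen=0.032, winstep=0.02, numcep=13)
print(features.shape)  # (num_frames, 13): the second dim should equal n_input
```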
Thanks!