# Source code for persephone.model

""" Generic model for automatic speech recognition. """

import inspect
import itertools
import logging
import math
import os
from pathlib import Path
import sys
from typing import Callable, Optional, Union, Sequence, Set, List, Dict

import tensorflow as tf

from .preprocess import labels, feat_extract
from . import utils
from . import config
from .config import ENCODING
from .corpus import Corpus
from .exceptions import PersephoneException
from .corpus_reader import CorpusReader

# Shared Session configuration: don't log device placement, and let GPU
# memory usage grow on demand instead of pre-allocating all GPU memory.
allow_growth_config = tf.ConfigProto(log_device_placement=False)
allow_growth_config.gpu_options.allow_growth = True #pylint: disable=no-member

# Module-level logger named after this module.
logger = logging.getLogger(__name__) # type: ignore

def load_metagraph(model_path_prefix: Union[str, Path]) -> tf.train.Saver:
    """ Create a Saver for a model stored on disk.

    The prefix is the checkpoint path without the ".meta" extension
    (typically something like exp/<exp_num>/model/model_best). The returned
    Saver can then restore the graph inside a tf.Session.
    """

    meta_path = "{}.meta".format(model_path_prefix)
    return tf.train.import_meta_graph(meta_path)

def dense_to_human_readable(dense_repr: Sequence[Sequence[int]], index_to_label: Dict[int, str]) -> List[List[str]]:
    """ Map a dense decoding of the model output to label strings.

    Index 0 is the blank/padding index and is dropped; every other index in
    each utterance is looked up in index_to_label.
    """

    return [[index_to_label[idx] for idx in utterance if idx != 0]
            for utterance in dense_repr]

def decode_corpus(model_path_prefix: Union[str, Path],
                  corpus: Corpus,
                  *,
                  batch_size: int = 64,
                  feat_dir: Optional[Path]=None,
                  batch_x_name: str="batch_x:0",
                  batch_x_lens_name: str="batch_x_lens:0",
                  output_name: str="hyp_dense_decoded:0") -> List[List[str]]:
    """ Decode all the untranscribed utterances of a corpus.

    Collects the WAV paths for the corpus' untranscribed prefixes and
    delegates to decode(), using the corpus' own label set and feature type.
    """
    wav_dir = Path(corpus.tgt_dir) / "wav"
    wav_paths = [wav_dir / Path(prefix + ".wav")
                 for prefix in corpus.untranscribed_prefixes]
    return decode(
        model_path_prefix,
        wav_paths,
        label_set=corpus.labels,
        feature_type=corpus.feat_type,
        batch_size=batch_size,
        feat_dir=feat_dir,
        batch_x_name=batch_x_name,
        batch_x_lens_name=batch_x_lens_name,
        output_name=output_name)

def decode(model_path_prefix: Union[str, Path],
           input_paths: Sequence[Path],
           label_set: Set[str],
           *,
           feature_type: str = "fbank", #TODO Make this None and infer feature_type from dimension of NN input layer.
           batch_size: int = 64,
           feat_dir: Optional[Path]=None,
           batch_x_name: str="batch_x:0",
           batch_x_lens_name: str="batch_x_lens:0",
           output_name: str="hyp_dense_decoded:0") -> List[List[str]]:
    """Use an existing tensorflow model that exists on disk to decode
    WAV files.

    Args:
        model_path_prefix: The path to the saved tensorflow model.
                           This is the full prefix to the ".ckpt" file.
        input_paths: A sequence of `pathlib.Path`s to WAV files to put through
                     the model provided.
        label_set: The set of all the labels this model uses.
        feature_type: The type of features this model uses.
                      Note that this MUST match the type of features that the
                      model was trained on initially.
        feat_dir: Any files that require preprocessing will be
                  saved to the path specified by this.
        batch_x_name: The name of the tensorflow input for batch_x
        batch_x_lens_name: The name of the tensorflow input for batch_x_lens
        output_name: The name of the tensorflow output

    Returns:
        One human-readable transcript (a list of label strings) per input
        WAV, in batch order.

    Raises:
        PersephoneException: If input_paths is empty, or any WAV path does
                             not exist.
    """

    if not input_paths:
        raise PersephoneException("No untranscribed WAVs to transcribe.")

    model_path_prefix = str(model_path_prefix)

    # Fail fast before doing any expensive feature extraction.
    for p in input_paths:
        if not p.exists():
            raise PersephoneException(
                "The WAV file path {} does not exist".format(p)
            )

    preprocessed_file_paths = []
    for p in input_paths:
        prefix = p.stem
        # Check the "feat" directory as per the filesystem conventions of a Corpus
        feature_file_ext = ".{}.npy".format(feature_type)
        conventional_npy_location =  p.parent.parent / "feat" / (Path(prefix + feature_file_ext))
        if conventional_npy_location.exists():
            # don't need to preprocess it
            preprocessed_file_paths.append(conventional_npy_location)
        else:
            if not feat_dir:
                feat_dir = p.parent.parent / "feat"
            if not feat_dir.is_dir():
                os.makedirs(str(feat_dir))

            mono16k_wav_path = feat_dir / "{}.wav".format(prefix)
            feat_path = feat_dir / "{}.{}.npy".format(prefix, feature_type)
            # Resample/convert the WAV so feature extraction sees mono 16kHz.
            feat_extract.convert_wav(p, mono16k_wav_path)
            preprocessed_file_paths.append(feat_path)
    # Preprocess the files that weren't found in the features directory
    # as per the filesystem conventions.
    if feat_dir:
        feat_extract.from_dir(feat_dir, feature_type)

    fn_batches = utils.make_batches(preprocessed_file_paths, batch_size)
    indices_to_labels = labels.make_indices_to_labels(label_set)

    # Load the model and perform decoding.
    metagraph = load_metagraph(model_path_prefix)
    human_readable = [] # type: List[List[str]]
    with tf.Session() as sess:
        metagraph.restore(sess, model_path_prefix)

        # BUGFIX: decoding must happen inside the batch loop; previously the
        # feed/run was after the loop so only the final batch was transcribed.
        for fn_batch in fn_batches:
            batch_x, batch_x_lens = utils.load_batch_x(fn_batch)

            # TODO These placeholder names should be a backup if names from a newer
            # naming scheme aren't present. Otherwise this won't generalize to
            # different architectures.
            feed_dict = {batch_x_name: batch_x,
                         batch_x_lens_name: batch_x_lens}

            dense_decoded = sess.run(output_name, feed_dict=feed_dict)

            # Convert this batch's dense output to human-readable transcripts.
            human_readable.extend(
                dense_to_human_readable(dense_decoded, indices_to_labels))

    return human_readable

class Model:
    """ Generic model for our ASR tasks.

    Attributes:
        exp_dir: Path that the experiment directory is located
        corpus_reader: `CorpusReader` object that provides access to the
            corpus this model is being trained on.
        log_softmax: log softmax function
        batch_x: A batch of input features. ("x" is the typical
            notation in ML papers on this topic denoting model input)
        batch_x_lens: The lengths of each utterance. This is used by
            Tensorflow to know how much to pad utterances that are shorter
            than this length.
        batch_y: Reference labels for a batch ("y" is the typical
            notation in ML papers on this topic denoting training labels)
        optimizer: The gradient descent method being used. (Typically
            we use Adam because it has provided good results but any
            stochastic gradient descent method could be substituted here)
        ler: Label error rate.
        dense_decoded: Dense representation of the model transcription
            output.
        dense_ref: Dense representation of the reference transcription.
        saved_model_path: Path to where the Tensorflow model is being
            saved on disk.
    """

    def __init__(self, exp_dir: Union[Path, str],
                 corpus_reader: CorpusReader) -> None:
        # Normalize the experiment directory to a plain string path.
        self.exp_dir = str(exp_dir) if isinstance(exp_dir, Path) else exp_dir # type: str
        self.corpus_reader = corpus_reader
        # The remaining attributes are placeholders; concrete model
        # subclasses are expected to populate them with Tensorflow graph
        # nodes before train()/eval()/transcribe() are called.
        self.log_softmax = None
        self.batch_x = None
        self.batch_x_lens = None
        self.batch_y = None
        self.optimizer = None
        self.ler = None
        self.dense_decoded = None
        self.dense_ref = None
        # Set by train() once a checkpoint has been written.
        self.saved_model_path = "" # type: str
    def transcribe(self, restore_model_path: Optional[str]=None) -> None:
        """ Transcribes an untranscribed dataset. Similar to eval() except
        no reference translation is assumed, thus no LER is calculated.

        Args:
            restore_model_path: Optional checkpoint path to restore; if not
                given, falls back to self.saved_model_path.

        Raises:
            PersephoneException: If no model path is available to restore.
        """

        saver = tf.train.Saver()
        with tf.Session(config=allow_growth_config) as sess:
            if restore_model_path:
                saver.restore(sess, restore_model_path)
            else:
                if self.saved_model_path:
                    saver.restore(sess, self.saved_model_path)
                else:
                    raise PersephoneException("No model to use for transcription.")

            batch_gen = self.corpus_reader.untranscribed_batch_gen()

            hyp_batches = []
            for batch_i, batch in enumerate(batch_gen):

                batch_x, batch_x_lens, feat_fn_batch = batch

                feed_dict = {self.batch_x: batch_x,
                             self.batch_x_lens: batch_x_lens}

                [dense_decoded] = sess.run([self.dense_decoded], feed_dict=feed_dict)
                hyps = self.corpus_reader.human_readable(dense_decoded)

                # Prepare dir for transcription.
                # NOTE(review): this directory creation runs on every batch
                # iteration; it looks like it could be hoisted above the loop
                # — confirm before changing.
                hyps_dir = os.path.join(self.exp_dir, "transcriptions")
                if not os.path.isdir(hyps_dir):
                    os.mkdir(hyps_dir)

                hyp_batches.append((hyps,feat_fn_batch))

            # Write every hypothesis, paired with its source feature file
            # name, to a single hyps.txt in the transcriptions directory.
            with open(os.path.join(hyps_dir, "hyps.txt"), "w", encoding=ENCODING) as hyps_f:
                for hyp_batch, fn_batch in hyp_batches:
                    for hyp, fn in zip(hyp_batch, fn_batch):
                        print(fn, file=hyps_f)
                        print(" ".join(hyp), file=hyps_f)
                        print("", file=hyps_f)
def decode(self): model_path_prefix = Path(self.exp_dir) / "model" / "model_best.ckpt" prefixes = self.corpus_reader.corpus.untranscribed_prefixes input_paths = [self.corpus_reader.corpus.tgt_dir / "feat" / Path(p + ".wav") for p in prefixes] label_set = self.corpus_reader.corpus.labels feature_type = self.corpus_reader.corpus.feat_type batch_size = self.corpus_reader.batch_size batch_x_name = self.batch_x.name batch_x_lens_name = self.batch_x_lens.name output_name = self.dense_decoded.name return decode(model_path_prefix, input_paths, label_set, feature_type=feature_type, batch_size=batch_size, batch_x_name=batch_x_name, batch_x_lens_name=batch_x_lens_name, output_name=output_name) def eval(self, restore_model_path: Optional[str]=None) -> None: """ Evaluates the model on a test set.""" saver = tf.train.Saver() with tf.Session(config=allow_growth_config) as sess: if restore_model_path: logger.info("restoring model from %s", restore_model_path) saver.restore(sess, restore_model_path) else: assert self.saved_model_path, "{}".format(self.saved_model_path) logger.info("restoring model from %s", self.saved_model_path) saver.restore(sess, self.saved_model_path) test_x, test_x_lens, test_y = self.corpus_reader.test_batch() feed_dict = {self.batch_x: test_x, self.batch_x_lens: test_x_lens, self.batch_y: test_y} test_ler, dense_decoded, dense_ref = sess.run( [self.ler, self.dense_decoded, self.dense_ref], feed_dict=feed_dict) hyps, refs = self.corpus_reader.human_readable_hyp_ref( dense_decoded, dense_ref) # Log hypotheses hyps_dir = os.path.join(self.exp_dir, "test") if not os.path.isdir(hyps_dir): os.mkdir(hyps_dir) with open(os.path.join(hyps_dir, "hyps"), "w", encoding=ENCODING) as hyps_f: for hyp in hyps: print(" ".join(hyp), file=hyps_f) with open(os.path.join(hyps_dir, "refs"), "w", encoding=ENCODING) as refs_f: for ref in refs: print(" ".join(ref), file=refs_f) test_per = utils.batch_per(hyps, refs) if not math.isclose(test_per, test_ler, rel_tol=1e-07): 
logger.warning("The label error rate from Tensorflow doesn't exactly" "match the phoneme error rate calculated in persephone" "Tensorflow %f, Persephone %f", test_ler, test_per) with open(os.path.join(hyps_dir, "test_per"), "w", encoding=ENCODING) as per_f: print("LER: %f" % (test_ler), file=per_f) def output_best_scores(self, best_epoch_str: str) -> None: """Output best scores to the filesystem""" BEST_SCORES_FILENAME = "best_scores.txt" with open(os.path.join(self.exp_dir, BEST_SCORES_FILENAME), "w", encoding=ENCODING) as best_f: print(best_epoch_str, file=best_f, flush=True)
    def train(self, *, early_stopping_steps: int = 10, min_epochs: int = 30,
              max_valid_ler: float = 1.0, max_train_ler: float = 0.3,
              max_epochs: int = 100,
              restore_model_path: Optional[str]=None,
              epoch_callback: Optional[Callable[[Dict], None]]=None) -> None:
        """ Train the model.

        min_epochs: minimum number of epochs to run training for.
        max_epochs: maximum number of epochs to run training for.
        early_stopping_steps: Stop training after this number of steps
            if no LER improvement has been made.
        max_valid_ler: Maximum LER for the validation set.
            Training will continue until this is met or another stopping
            condition occurs.
        max_train_ler: Maximum LER for the training set.
            Training will continue until this is met or another stopping
            condition occurs.
        restore_model_path: The path to restore a model from.
        epoch_callback: A callback that is called at the end of each training
            epoch. The parameters passed to the callable will be the epoch
            number, the current training LER and the current validation LER.
            This can be useful for progress reporting.
        """
        logger.info("Training model")
        # Any real validation LER will be below this sentinel (LER <= 1.0
        # in the usual case), so the first epoch always becomes the record.
        best_valid_ler = 2.0
        steps_since_last_record = 0

        #Get information about training for the names of output files.
        frame = inspect.currentframe() # pylint: disable=deprecated-method
        # It was a mistake to deprecate this in Python 3.5
        if frame:
            # Record this call's argument values in train_description.txt so
            # the hyperparameters of the run are reproducible.
            args, _, _, values = inspect.getargvalues(frame)
            with open(os.path.join(self.exp_dir, "train_description.txt"), "w",
                      encoding=ENCODING) as desc_f:
                for arg in args:
                    if type(values[arg]) in [str, int, float] or isinstance(
                            values[arg], type(None)):
                        print("%s=%s" % (arg, values[arg]), file=desc_f)
                    else:
                        # Non-scalar arguments: dump their attribute dict.
                        print("%s=%s" % (arg, values[arg].__dict__), file=desc_f)

                print("num_train=%s" % (self.corpus_reader.num_train), file=desc_f)
                print("batch_size=%s" % (self.corpus_reader.batch_size), file=desc_f)
        else:
            logger.error("Couldn't find frame information, failed to write train_description.txt")

        # Load the validation set
        valid_x, valid_x_lens, valid_y = self.corpus_reader.valid_batch()

        saver = tf.train.Saver()

        with tf.Session(config=allow_growth_config) as sess:
            if restore_model_path:
                logger.info("Restoring model from path %s", restore_model_path)
                saver.restore(sess, restore_model_path)
            else:
                sess.run(tf.global_variables_initializer())

            # Prepare directory to output hypotheses to
            hyps_dir = os.path.join(self.exp_dir, "decoded")
            if not os.path.isdir(hyps_dir):
                os.mkdir(hyps_dir)

            best_epoch_str = None

            training_log_path = os.path.join(self.exp_dir, "train_log.txt")
            if os.path.exists(training_log_path):
                logger.error("Error, overwriting existing log file at path {}".format(training_log_path))
            with open(training_log_path, "w", encoding=ENCODING) as out_file:
                # Epochs are 1-indexed; loop exits via the break conditions
                # below (max_epochs or early stopping).
                for epoch in itertools.count(start=1):
                    print("\nexp_dir %s, epoch %d" % (self.exp_dir, epoch))
                    batch_gen = self.corpus_reader.train_batch_gen()

                    train_ler_total = 0
                    print("\tBatch...", end="")
                    for batch_i, batch in enumerate(batch_gen):
                        print("%d..." % batch_i, end="")
                        sys.stdout.flush()
                        batch_x, batch_x_lens, batch_y = batch

                        feed_dict = {self.batch_x: batch_x,
                                     self.batch_x_lens: batch_x_lens,
                                     self.batch_y: batch_y}

                        # One optimizer step; also fetch this batch's LER.
                        _, ler, = sess.run([self.optimizer, self.ler],
                                           feed_dict=feed_dict)

                        train_ler_total += ler
                    #else:
                    #    raise PersephoneException("No training data was provided."
                    #        " Check your batch generation.")

                    # Evaluate on the (single) validation batch.
                    feed_dict = {self.batch_x: valid_x,
                                 self.batch_x_lens: valid_x_lens,
                                 self.batch_y: valid_y}

                    try:
                        valid_ler, dense_decoded, dense_ref = sess.run(
                            [self.ler, self.dense_decoded, self.dense_ref],
                            feed_dict=feed_dict)
                    except tf.errors.ResourceExhaustedError:
                        import pprint
                        print("Ran out of memory allocating a batch:")
                        pprint.pprint(feed_dict)
                        logger.critical("Ran out of memory allocating a batch: %s",
                                        pprint.pformat(feed_dict))
                        raise

                    hyps, refs = self.corpus_reader.human_readable_hyp_ref(
                        dense_decoded, dense_ref)
                    # Log hypotheses
                    with open(os.path.join(hyps_dir, "epoch%d_hyps" % epoch),
                              "w", encoding=ENCODING) as hyps_f:
                        for hyp in hyps:
                            print(" ".join(hyp), file=hyps_f)
                    # References don't change between epochs; write them once.
                    if epoch == 1:
                        with open(os.path.join(hyps_dir, "refs"), "w",
                                  encoding=ENCODING) as refs_f:
                            for ref in refs:
                                print(" ".join(ref), file=refs_f)

                    valid_per = utils.batch_per(hyps, refs)

                    # batch_i is the index of the last training batch, so
                    # (batch_i + 1) is the number of batches this epoch.
                    epoch_str = "Epoch %d. Training LER: %f, validation LER: %f" % (
                        epoch, (train_ler_total / (batch_i + 1)), valid_ler)
                    print(epoch_str, flush=True, file=out_file)
                    if best_epoch_str is None:
                        best_epoch_str = epoch_str

                    # Call the callback here if it was defined
                    if epoch_callback:
                        epoch_callback({
                            "epoch": epoch,
                            "training_ler": (train_ler_total / (batch_i + 1)), # current training LER
                            "valid_ler": valid_ler, # Current validation LER
                        })

                    # Implement early stopping.
                    if valid_ler < best_valid_ler:
                        print("New best valid_ler", file=out_file)
                        best_valid_ler = valid_ler
                        best_epoch_str = epoch_str
                        steps_since_last_record = 0

                        # Save the model.
                        checkpoint_path = os.path.join(self.exp_dir, "model", "model_best.ckpt")
                        if not os.path.exists(os.path.dirname(checkpoint_path)):
                            os.mkdir(os.path.dirname(checkpoint_path))
                        saver.save(sess, checkpoint_path)
                        self.saved_model_path = checkpoint_path

                        # Output best hyps
                        with open(os.path.join(hyps_dir, "best_hyps"), "w",
                                  encoding=ENCODING) as hyps_f:
                            for hyp in hyps:
                                print(" ".join(hyp), file=hyps_f)
                    else:
                        print("Steps since last best valid_ler: %d" % (
                            steps_since_last_record), file=out_file)
                        steps_since_last_record += 1
                        if epoch >= max_epochs:
                            self.output_best_scores(best_epoch_str)
                            break
                        if steps_since_last_record >= early_stopping_steps:
                            if epoch >= min_epochs:
                                # Then we've done the minimum number of epochs.
                                if valid_ler <= max_valid_ler and ler <= max_train_ler:
                                    # Then training error has moved sufficiently
                                    # towards convergence.
                                    print("Stopping since best validation score hasn't been"
                                          " beaten in %d epochs and at least %d have been"
                                          " done. The valid ler (%d) is below %d and"
                                          " the train ler (%d) is below %d." % (
                                              early_stopping_steps, min_epochs,
                                              valid_ler, max_valid_ler,
                                              ler, max_train_ler),
                                          file=out_file, flush=True)
                                    self.output_best_scores(best_epoch_str)
                                    break
                                else:
                                    # Keep training because we haven't achieved
                                    # convergence.
                                    continue
                            else:
                                # Keep training because we haven't done the minimum
                                # number of epochs.
                                continue

        # Check we actually saved a checkpoint
        if not self.saved_model_path:
            raise PersephoneException(
                "No checkpoint was saved so model evaluation cannot be performed. "
                "This can happen if the validaion LER never converges.")

        # Finally, run evaluation on the test set.
        self.eval(restore_model_path=self.saved_model_path)