# Automatically reload imported modules that are changed outside this notebook
%load_ext autoreload
%autoreload 2
# More pixels in figures
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams["figure.dpi"] = 200
# Init PRNG with fixed seed for reproducibility
import numpy as np
np_rng = np.random.default_rng(1)
import tensorflow as tf
tf.random.set_seed(np_rng.integers(0, tf.int64.max))
2020-11-21
In this example, we take a different approach to training language vectors (embeddings) compared to common-voice-embeddings.
Previously, we trained a neural network on a classification task and used one of its layers as the representation for the different classes.
This time, we train the neural network directly on the language vector task by maximizing the angular distance between vectors of different classes.
We'll be using the approach described by G. Gelly and J.L. Gauvain.
We will continue with the same 4-language Common Voice data as in the previous examples.
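As a reminder, the angular distance between two vectors is usually defined in terms of their cosine similarity; one common definition (a generic formula, not necessarily the exact formulation used by the loss implementation below) is

$$
\cos\theta_{uv} = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert},
\qquad
d(u, v) = \frac{1}{\pi} \arccos\left(\cos\theta_{uv}\right) \in [0, 1].
$$

Maximizing $d(u, v)$ for vectors of different languages pushes them apart on the unit sphere, while minimizing it for vectors of the same language pulls them together.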
import urllib.parse
from IPython.display import display, Markdown
languages = """
et
mn
ta
tr
""".split()
languages = sorted(l.strip() for l in languages)
display(Markdown("### Languages"))
display(Markdown('\n'.join("* `{}`".format(l) for l in languages)))
bcp47_validator_url = 'https://schneegans.de/lv/?tags='
display(Markdown("See [this tool]({}) for a description of the BCP-47 language codes."
.format(bcp47_validator_url + urllib.parse.quote('\n'.join(languages)))))
import os
from lidbox.meta import (
common_voice,
generate_label2target,
verify_integrity,
read_audio_durations,
random_oversampling_on_split
)
workdir = "/data/exp/cv4-angular-lstm"
datadir = "/mnt/data/speech/common-voice/downloads/2020/cv-corpus"
print("work dir:", workdir)
print("data source dir:", datadir)
print()
os.makedirs(workdir, exist_ok=True)
assert os.path.isdir(datadir), datadir + " does not exist"
dirs = sorted((f for f in os.scandir(datadir) if f.is_dir()), key=lambda f: f.name)
print(datadir)
for d in dirs:
if d.name in languages:
print(' ', d.name)
for f in os.scandir(d):
print(' ', f.name)
missing_languages = set(languages) - set(d.name for d in dirs)
assert missing_languages == set(), "missing languages: {}".format(missing_languages)
meta = common_voice.load_all(datadir, languages)
meta, lang2target = generate_label2target(meta)
print("\nsize of all metadata", meta.shape)
meta = meta.dropna()
print("after dropping NaN rows", meta.shape)
print("verifying integrity")
verify_integrity(meta)
print("ok\n")
print("reading audio durations")
meta["duration"] = read_audio_durations(meta)
print("balancing the label distributions")
meta = random_oversampling_on_split(meta, "train")
work dir: /data/exp/cv4-angular-lstm
data source dir: /mnt/data/speech/common-voice/downloads/2020/cv-corpus

/mnt/data/speech/common-voice/downloads/2020/cv-corpus
  et
    validated.tsv
    invalidated.tsv
    other.tsv
    dev.tsv
    train.tsv
    clips
    test.tsv
    reported.tsv
  mn
    validated.tsv
    invalidated.tsv
    other.tsv
    dev.tsv
    train.tsv
    clips
    test.tsv
    reported.tsv
  ta
    validated.tsv
    invalidated.tsv
    other.tsv
    dev.tsv
    train.tsv
    clips
    test.tsv
    reported.tsv
  tr
    validated.tsv
    invalidated.tsv
    other.tsv
    dev.tsv
    train.tsv
    clips
    test.tsv
    reported.tsv

size of all metadata (23842, 6)
after dropping NaN rows (23842, 6)
verifying integrity
ok

reading audio durations
balancing the label distributions
Most of the preprocessing will be as in common-voice-embeddings, but this time we will not be training on samples of varying length, since tf.keras.Model.fit assumes the size of the training set does not change between epochs. This could probably be fixed by writing a custom training loop, but we won't be doing that here. Instead, every signal is divided into fixed-length chunks.
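To make the fixed-length chunking concrete, here is a minimal sketch of the idea. This is an illustration only; the pipeline below uses lidbox's ds_steps.create_signal_chunks, and reading the 800 ms offset as the hop between chunk start times is an assumption of this sketch.

def chunk_signal(signal, rate, chunk_ms=3200, offset_ms=800):
    # Cut a signal into chunks of chunk_ms milliseconds, starting a new chunk
    # every offset_ms milliseconds and dropping leftover samples at the end.
    chunk_len = int(rate * chunk_ms / 1000)
    offset = int(rate * offset_ms / 1000)
    return [signal[begin:begin+chunk_len]
            for begin in range(0, len(signal) - chunk_len + 1, offset)]

# For example, a 5 second signal at 16 kHz yields 3 chunks of 3.2 seconds each.
assert len(chunk_signal([0.0] * (5 * 16000), 16000)) == 3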
import scipy.signal
from lidbox.features import audio, cmvn
import lidbox.data.steps as ds_steps
TF_AUTOTUNE = tf.data.experimental.AUTOTUNE
def metadata_to_dataset_input(meta):
return {
"id": tf.constant(meta.index, tf.string),
"path": tf.constant(meta.path, tf.string),
"label": tf.constant(meta.label, tf.string),
"target": tf.constant(meta.target, tf.int32),
"split": tf.constant(meta.split, tf.string),
"is_copy": tf.constant(meta.is_copy, tf.bool),
}
def read_mp3(x):
s, r = audio.read_mp3(x["path"])
out_rate = 16000
s = audio.resample(s, r, out_rate)
s = audio.peak_normalize(s, dBFS=-3.0)
s = audio.remove_silence(s, out_rate)
return dict(x, signal=s, sample_rate=out_rate)
def random_filter(x):
def scipy_filter(s, N=10):
b = np_rng.normal(0, 1, N)
return scipy.signal.lfilter(b, 1.0, s).astype(np.float32), b
s, _ = tf.numpy_function(
scipy_filter,
[x["signal"]],
[tf.float32, tf.float64],
name="np_random_filter")
s = tf.cast(s, tf.float32)
s = audio.peak_normalize(s, dBFS=-3.0)
return dict(x, signal=s)
def random_speed_change(ds):
return ds_steps.random_signal_speed_change(ds, min=0.9, max=1.1, flag="is_copy")
def create_signal_chunks(ds):
ds = ds_steps.repeat_too_short_signals(ds, 3200)
ds = ds_steps.create_signal_chunks(ds, 3200, 800)
return ds
def batch_extract_features(x):
with tf.device("GPU"):
signals, rates = x["signal"], x["sample_rate"]
S = audio.spectrograms(signals, rates[0])
S = audio.linear_to_mel(S, rates[0])
S = tf.math.log(S + 1e-6)
S = cmvn(S, normalize_variance=False)
return dict(x, logmelspec=S)
def pipeline_from_meta(data, split):
if split == "train":
data = data.sample(frac=1, random_state=np_rng.bit_generator)
ds = (tf.data.Dataset
.from_tensor_slices(metadata_to_dataset_input(data))
.map(read_mp3, num_parallel_calls=TF_AUTOTUNE))
if split == "train":
return (ds
.apply(random_speed_change)
.cache(os.path.join(cachedir, "data", split))
.prefetch(100)
.map(random_filter, num_parallel_calls=TF_AUTOTUNE)
.apply(create_signal_chunks)
.batch(100)
.map(batch_extract_features, num_parallel_calls=TF_AUTOTUNE)
.unbatch())
else:
return (ds
.apply(create_signal_chunks)
.batch(100)
.map(batch_extract_features, num_parallel_calls=TF_AUTOTUNE)
.unbatch()
.cache(os.path.join(cachedir, "data", split))
.prefetch(100))
cachedir = os.path.join(workdir, "cache")
os.makedirs(os.path.join(cachedir, "data"))
split2ds = {split: pipeline_from_meta(meta[meta["split"]==split], split)
for split in meta.split.unique()}
2020-11-21 23:05:37.173 I lidbox.data.steps: Repeating all signals until they are at least 3200 ms
2020-11-21 23:05:37.214 I lidbox.data.steps: Dividing every signal in the dataset into new signals by creating signal chunks of length 3200 ms and offset 800 ms. Maximum amount of padding allowed in the last chunk is 0 ms.
2020-11-21 23:05:37.883 I lidbox.data.steps: Applying random resampling to signals with a random speed ratio chosen uniformly at random from [0.900, 1.100]
2020-11-21 23:05:37.993 I lidbox.data.steps: Repeating all signals until they are at least 3200 ms
2020-11-21 23:05:38.000 I lidbox.data.steps: Dividing every signal in the dataset into new signals by creating signal chunks of length 3200 ms and offset 800 ms. Maximum amount of padding allowed in the last chunk is 0 ms.
2020-11-21 23:05:38.165 I lidbox.data.steps: Repeating all signals until they are at least 3200 ms
2020-11-21 23:05:38.175 I lidbox.data.steps: Dividing every signal in the dataset into new signals by creating signal chunks of length 3200 ms and offset 800 ms. Maximum amount of padding allowed in the last chunk is 0 ms.
for split, ds in split2ds.items():
print("filling", split, "cache")
_ = ds_steps.consume(ds, log_interval=5000)
filling test cache
2020-11-21 23:05:38.335 I lidbox.data.steps: Exhausting the dataset iterator by iterating over all elements, log_interval = 5000
2020-11-21 23:05:48.470 I lidbox.data.steps: 5000 done, 493.371 elements per second.
2020-11-21 23:05:57.115 I lidbox.data.steps: 10000 done, 578.409 elements per second.
2020-11-21 23:06:05.999 I lidbox.data.steps: 15000 done, 562.872 elements per second.
2020-11-21 23:06:17.073 I lidbox.data.steps: 20000 done, 451.544 elements per second.
2020-11-21 23:06:19.066 I lidbox.data.steps: 21857 done, 932.154 elements per second.
filling train cache
2020-11-21 23:06:19.070 I lidbox.data.steps: Exhausting the dataset iterator by iterating over all elements, log_interval = 5000
2020-11-21 23:06:31.611 I lidbox.data.steps: 5000 done, 398.694 elements per second.
2020-11-21 23:06:42.470 I lidbox.data.steps: 10000 done, 460.478 elements per second.
2020-11-21 23:06:53.318 I lidbox.data.steps: 15000 done, 460.963 elements per second.
2020-11-21 23:07:03.804 I lidbox.data.steps: 20000 done, 476.864 elements per second.
2020-11-21 23:07:14.153 I lidbox.data.steps: 25000 done, 483.173 elements per second.
2020-11-21 23:07:24.768 I lidbox.data.steps: 30000 done, 471.060 elements per second.
2020-11-21 23:07:35.201 I lidbox.data.steps: 35000 done, 479.282 elements per second.
2020-11-21 23:07:45.518 I lidbox.data.steps: 40000 done, 484.693 elements per second.
2020-11-21 23:07:52.277 I lidbox.data.steps: 43762 done, 556.673 elements per second.
filling dev cache
2020-11-21 23:07:52.280 I lidbox.data.steps: Exhausting the dataset iterator by iterating over all elements, log_interval = 5000
2020-11-21 23:08:02.294 I lidbox.data.steps: 5000 done, 499.328 elements per second.
2020-11-21 23:08:10.649 I lidbox.data.steps: 10000 done, 598.444 elements per second.
2020-11-21 23:08:22.067 I lidbox.data.steps: 15000 done, 437.951 elements per second.
2020-11-21 23:08:27.765 I lidbox.data.steps: 20000 done, 877.561 elements per second.
2020-11-21 23:08:31.175 I lidbox.data.steps: 21967 done, 577.076 elements per second.
lidbox implements both the model and the angular proximity loss function used in the reference paper.
The loss function aims to maximize the cosine distance between language vectors of different languages and to minimize it between vectors of the same language.
Reference vectors are generated for each class such that all reference vectors are mutually orthogonal.
In addition, we'll add random channel dropout to avoid overfitting on noise, as in the common-voice-small example.
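As a rough illustration of mutually orthogonal reference directions (a minimal sketch, not necessarily how SparseAngularProximity constructs them), distinct rows of an identity matrix give one unit vector per language, each orthogonal to all the others:

def make_reference_directions(num_labels, embedding_dim):
    # One reference direction per language: distinct rows of an identity
    # matrix are orthogonal unit vectors by construction.
    refs = np.zeros((num_labels, embedding_dim), dtype=np.float32)
    refs[np.arange(num_labels), np.arange(num_labels)] = 1.0
    return refs

refs = make_reference_directions(4, 800)
assert np.allclose(refs @ refs.T, np.eye(4))

During training, the loss can then reward each language vector for pointing towards the reference direction of its own language and penalize it for pointing towards the others.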
from lidbox.models import ap_lstm
from lidbox.losses import SparseAngularProximity
def create_model(num_freq_bins=40, num_labels=len(lang2target)):
m = ap_lstm.create(
input_shape=[None, num_freq_bins],
num_outputs=num_labels,
num_lstm_units=200,
channel_dropout_rate=0.8)
m.compile(
loss=SparseAngularProximity(num_labels, m.output.shape[1]),
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))
return m
model = create_model()
model.summary()
Model: "angular_proximity_lstm" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input (InputLayer) [(None, None, 40)] 0 __________________________________________________________________________________________________ channel_dropout (SpatialDropout (None, None, 40) 0 input[0][0] __________________________________________________________________________________________________ blstm_1 (Bidirectional) (None, None, 400) 385600 channel_dropout[0][0] __________________________________________________________________________________________________ blstm_2 (Bidirectional) (None, None, 400) 961600 blstm_1[0][0] __________________________________________________________________________________________________ tf_op_layer_Mul (TensorFlowOpLa [(None, None, 400)] 0 blstm_1[0][0] __________________________________________________________________________________________________ tf_op_layer_Mul_1 (TensorFlowOp [(None, None, 400)] 0 blstm_2[0][0] __________________________________________________________________________________________________ blstm_concat (Concatenate) (None, None, 800) 0 tf_op_layer_Mul[0][0] tf_op_layer_Mul_1[0][0] __________________________________________________________________________________________________ avg_over_time (GlobalAveragePoo (None, 800) 0 blstm_concat[0][0] __________________________________________________________________________________________________ tf_op_layer_Square (TensorFlowO [(None, 800)] 0 avg_over_time[0][0] __________________________________________________________________________________________________ tf_op_layer_Sum (TensorFlowOpLa [(None, 1)] 0 tf_op_layer_Square[0][0] __________________________________________________________________________________________________ tf_op_layer_Maximum (TensorFlow [(None, 1)] 0 tf_op_layer_Sum[0][0] __________________________________________________________________________________________________ tf_op_layer_Rsqrt (TensorFlowOp [(None, 1)] 0 tf_op_layer_Maximum[0][0] __________________________________________________________________________________________________ tf_op_layer_Mul_2 (TensorFlowOp [(None, 800)] 0 avg_over_time[0][0] tf_op_layer_Rsqrt[0][0] ================================================================================================== Total params: 1,347,200 Trainable params: 1,347,200 Non-trainable params: 0 __________________________________________________________________________________________________
callbacks = [
tf.keras.callbacks.TensorBoard(
log_dir=os.path.join(cachedir, "tensorboard", model.name),
update_freq="epoch",
write_images=True,
profile_batch=0,
),
tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=10,
),
tf.keras.callbacks.ModelCheckpoint(
os.path.join(cachedir, "model", model.name),
monitor='val_loss',
save_weights_only=True,
save_best_only=True,
verbose=1,
),
]
def as_model_input(x):
return x["logmelspec"], x["target"]
train_ds = split2ds["train"].map(as_model_input).shuffle(5000)
dev_ds = split2ds["dev"].map(as_model_input)
history = model.fit(
train_ds.batch(32),
validation_data=dev_ds.batch(32),
callbacks=callbacks,
verbose=2,
epochs=100)
Epoch 1/100
Epoch 00001: val_loss improved from inf to 9.98566, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 104s - loss: 11.0557 - val_loss: 9.9857
Epoch 2/100
Epoch 00002: val_loss improved from 9.98566 to 9.37180, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 103s - loss: 10.2950 - val_loss: 9.3718
Epoch 3/100
Epoch 00003: val_loss improved from 9.37180 to 9.22063, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 104s - loss: 9.9719 - val_loss: 9.2206
Epoch 4/100
Epoch 00004: val_loss improved from 9.22063 to 8.88304, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 106s - loss: 9.8204 - val_loss: 8.8830
Epoch 5/100
Epoch 00005: val_loss improved from 8.88304 to 8.77573, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 9.6827 - val_loss: 8.7757
Epoch 6/100
Epoch 00006: val_loss improved from 8.77573 to 8.60583, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 104s - loss: 9.5355 - val_loss: 8.6058
Epoch 7/100
Epoch 00007: val_loss improved from 8.60583 to 8.59336, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 9.4165 - val_loss: 8.5934
Epoch 8/100
Epoch 00008: val_loss improved from 8.59336 to 8.54300, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 9.3556 - val_loss: 8.5430
Epoch 9/100
Epoch 00009: val_loss improved from 8.54300 to 8.50729, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 9.2963 - val_loss: 8.5073
Epoch 10/100
Epoch 00010: val_loss improved from 8.50729 to 8.35023, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 104s - loss: 9.2043 - val_loss: 8.3502
Epoch 11/100
Epoch 00011: val_loss improved from 8.35023 to 8.27633, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 104s - loss: 9.1600 - val_loss: 8.2763
Epoch 12/100
Epoch 00012: val_loss did not improve from 8.27633
1368/1368 - 105s - loss: 9.1104 - val_loss: 8.2821
Epoch 13/100
Epoch 00013: val_loss did not improve from 8.27633
1368/1368 - 104s - loss: 9.0874 - val_loss: 8.3520
Epoch 14/100
Epoch 00014: val_loss improved from 8.27633 to 8.26146, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 9.0329 - val_loss: 8.2615
Epoch 15/100
Epoch 00015: val_loss did not improve from 8.26146
1368/1368 - 105s - loss: 8.9947 - val_loss: 8.3461
Epoch 16/100
Epoch 00016: val_loss did not improve from 8.26146
1368/1368 - 105s - loss: 8.9641 - val_loss: 8.3906
Epoch 17/100
Epoch 00017: val_loss did not improve from 8.26146
1368/1368 - 105s - loss: 8.9377 - val_loss: 8.5105
Epoch 18/100
Epoch 00018: val_loss did not improve from 8.26146
1368/1368 - 105s - loss: 8.9057 - val_loss: 8.2923
Epoch 19/100
Epoch 00019: val_loss did not improve from 8.26146
1368/1368 - 105s - loss: 8.8708 - val_loss: 8.3061
Epoch 20/100
Epoch 00020: val_loss did not improve from 8.26146
1368/1368 - 105s - loss: 8.8503 - val_loss: 8.3104
Epoch 21/100
Epoch 00021: val_loss improved from 8.26146 to 8.04027, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 8.8293 - val_loss: 8.0403
Epoch 22/100
Epoch 00022: val_loss improved from 8.04027 to 7.99640, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 104s - loss: 8.8023 - val_loss: 7.9964
Epoch 23/100
Epoch 00023: val_loss improved from 7.99640 to 7.94184, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 8.8169 - val_loss: 7.9418
Epoch 24/100
Epoch 00024: val_loss improved from 7.94184 to 7.93043, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 8.7631 - val_loss: 7.9304
Epoch 25/100
Epoch 00025: val_loss improved from 7.93043 to 7.82285, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 8.7701 - val_loss: 7.8228
Epoch 26/100
Epoch 00026: val_loss improved from 7.82285 to 7.78633, saving model to /data/exp/cv4-angular-lstm/cache/model/angular_proximity_lstm
1368/1368 - 105s - loss: 8.7213 - val_loss: 7.7863
Epoch 27/100
Epoch 00027: val_loss did not improve from 7.78633
1368/1368 - 104s - loss: 8.7471 - val_loss: 7.9764
Epoch 28/100
Epoch 00028: val_loss did not improve from 7.78633
1368/1368 - 104s - loss: 8.7360 - val_loss: 7.8850
Epoch 29/100
Epoch 00029: val_loss did not improve from 7.78633
1368/1368 - 104s - loss: 8.7446 - val_loss: 7.9119
Epoch 30/100
Epoch 00030: val_loss did not improve from 7.78633
1368/1368 - 103s - loss: 8.7536 - val_loss: 7.8326
Epoch 31/100
Epoch 00031: val_loss did not improve from 7.78633
1368/1368 - 105s - loss: 8.7375 - val_loss: 7.9426
Epoch 32/100
Epoch 00032: val_loss did not improve from 7.78633
1368/1368 - 105s - loss: 8.7093 - val_loss: 7.8477
Epoch 33/100
Epoch 00033: val_loss did not improve from 7.78633
1368/1368 - 105s - loss: 8.7404 - val_loss: 7.8556
Epoch 34/100
Epoch 00034: val_loss did not improve from 7.78633
1368/1368 - 105s - loss: 8.7481 - val_loss: 7.8946
Epoch 35/100
Epoch 00035: val_loss did not improve from 7.78633
1368/1368 - 104s - loss: 8.7423 - val_loss: 7.9347
Epoch 36/100
Epoch 00036: val_loss did not improve from 7.78633
1368/1368 - 105s - loss: 8.7312 - val_loss: 7.7906
The angular proximity loss function uses one reference direction for each language, such that the directions are mutually orthogonal. By selecting the closest reference direction for every predicted language vector, the model can be used as an end-to-end classifier.
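A minimal sketch of that decision rule, assuming the language vectors and reference directions are given as NumPy arrays (the notebook itself uses model.loss.predict below, whose score convention may differ):

def predict_closest_reference(language_vectors, references):
    # Cosine similarity of every language vector to every reference direction;
    # the predicted language is the index of the most similar direction.
    v = language_vectors / np.linalg.norm(language_vectors, axis=1, keepdims=True)
    r = references / np.linalg.norm(references, axis=1, keepdims=True)
    return (v @ r.T).argmax(axis=1)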
import pandas as pd
from lidbox.util import predict_with_model, classification_report
from lidbox.visualize import draw_confusion_matrix
def load_trained_model():
model = create_model()
model.load_weights(os.path.join(cachedir, "model", model.name))
return model
def display_classification_report(report):
for m in ("avg_detection_cost", "avg_equal_error_rate", "accuracy"):
print("{}: {:.3f}".format(m, report[m]))
lang_metrics = pd.DataFrame.from_dict(
{k: v for k, v in report.items() if k in lang2target})
lang_metrics["mean"] = lang_metrics.mean(axis=1)
display(lang_metrics.T)
fig, ax = draw_confusion_matrix(report["confusion_matrix"], lang2target)
model = load_trained_model()
def predict_with_ap_loss(x):
with tf.device("GPU"):
# Generate language vector for input spectra
language_vector = model(x["input"], training=False)
# Predict languages by computing distances to reference directions
return x["id"], model.loss.predict(language_vector)
chunk2pred = predict_with_model(
model=model,
ds=split2ds["test"].map(lambda x: dict(x, input=x["logmelspec"])).batch(128),
predict_fn=predict_with_ap_loss)
We divided all samples into 3.2 second chunks, so all predictions are still per chunk. Let's merge the chunk predictions by taking the average over all chunks of each sample.
chunk2pred
| id | prediction |
|---|---|
| common_voice_et_18031888-000001 | [-1.4813205, -2.2343419, -1.4084598, -0.8021669] |
| common_voice_et_18031888-000002 | [-1.0812643, -2.0329118, -2.169351, -1.1059672] |
| common_voice_et_18031888-000003 | [-0.7330606, -2.128548, -1.874414, -1.7434976] |
| common_voice_et_18031888-000004 | [-0.75235176, -1.972779, -2.058972, -1.808581] |
| common_voice_et_18031889-000001 | [-0.7298598, -2.0762086, -1.5253065, -1.9248742] |
| ... | ... |
| common_voice_tr_22462713-000001 | [-1.9833467, -0.7498288, -1.965956, -1.6560669] |
| common_voice_tr_22474271-000001 | [-1.7910697, -1.7766209, -0.43017593, -1.8017486] |
| common_voice_tr_22474274-000001 | [-2.0003417, -1.4935714, -0.7071863, -1.6069621] |
| common_voice_tr_22477339-000001 | [-1.4122769, -1.6087383, -2.2520356, -0.8622882] |
| common_voice_tr_22498670-000001 | [-1.4916769, -1.8612272, -0.6381675, -2.0136924] |

21857 rows × 1 columns
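Before running the actual merge, here is a toy illustration of what per-utterance averaging means. This is not the merge_chunk_predictions implementation used next, and the utt_a/utt_b ids are made up; the real ids are formed as <utterance id>-<chunk number>, as in the table above.

toy = pd.DataFrame({
    "id": ["utt_a-000001", "utt_a-000002", "utt_b-000001"],
    "prediction": [np.array([0.1, 0.9]), np.array([0.3, 0.7]), np.array([0.8, 0.2])],
})
# Strip the chunk suffix, then average the prediction vectors of each utterance.
toy["utt_id"] = toy["id"].str.rsplit("-", n=1).str[0]
merged = {utt: np.mean(np.stack(group["prediction"].tolist()), axis=0)
          for utt, group in toy.groupby("utt_id")}
# merged == {'utt_a': array([0.2, 0.8]), 'utt_b': array([0.8, 0.2])}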
from lidbox.util import merge_chunk_predictions
utt2pred = merge_chunk_predictions(chunk2pred)
utt2pred
| id | prediction |
|---|---|
| common_voice_et_18031888 | [-1.0119994, -2.0921452, -1.8777992, -1.3650532] |
| common_voice_et_18031889 | [-1.4600016, -2.0907407, -1.028938, -1.6188763] |
| common_voice_et_18031891 | [-1.5245651, -1.776621, -2.0704782, -0.73580074] |
| common_voice_et_18038135 | [-1.3484797, -1.2462838, -2.1541154, -1.5248653] |
| common_voice_et_18038136 | [-0.8494967, -1.8381963, -2.0412564, -1.3015724] |
| ... | ... |
| common_voice_tr_22462713 | [-1.9833467, -0.7498288, -1.965956, -1.6560669] |
| common_voice_tr_22474271 | [-1.7910697, -1.7766209, -0.43017593, -1.8017486] |
| common_voice_tr_22474274 | [-2.0003417, -1.4935714, -0.7071863, -1.6069621] |
| common_voice_tr_22477339 | [-1.4122769, -1.6087383, -2.2520356, -0.8622882] |
| common_voice_tr_22498670 | [-1.4916769, -1.8612272, -0.6381675, -2.0136924] |

7569 rows × 1 columns
test_meta = meta[meta["split"]=="test"].join(utt2pred, how="outer")
assert not test_meta.isna().any(axis=None), "failed to join predictions"
true_sparse = test_meta.target.to_numpy(np.int32)
pred_dense = np.stack(test_meta.prediction)
report = classification_report(true_sparse, pred_dense, lang2target)
display_classification_report(report)
avg_detection_cost: 0.191
avg_equal_error_rate: 0.168
accuracy: 0.683
| | precision | recall | f1-score | support | equal_error_rate |
|---|---|---|---|---|---|
| et | 0.848820 | 0.710028 | 0.773246 | 2483.00 | 0.140779 |
| mn | 0.823422 | 0.497238 | 0.620048 | 1810.00 | 0.182150 |
| ta | 0.583302 | 0.938339 | 0.719401 | 1638.00 | 0.104873 |
| tr | 0.549320 | 0.591575 | 0.569665 | 1638.00 | 0.245996 |
| mean | 0.701216 | 0.684295 | 0.670590 | 1892.25 | 0.168449 |
from lidbox.util import model2function
extractor = model2function(load_trained_model())
print("extractor:", str(extractor))
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
extractor: ConcreteFunction <lambda>(x)
  Args:
    x: float32 Tensor, shape=(None, None, 40)
  Returns:
    float32 Tensor, shape=(None, 800)
from lidbox.visualize import plot_embedding_vector
def is_not_copy(x):
return not x["is_copy"]
def batch_extract_embeddings(x):
with tf.device("GPU"):
return dict(x, embedding=extractor(x["logmelspec"]))
embedding_demo_ds = (split2ds["train"]
.filter(is_not_copy)
.take(12)
.batch(1)
.map(batch_extract_embeddings)
.unbatch())
for x in embedding_demo_ds.as_numpy_iterator():
print(x["id"].decode("utf-8"), x["embedding"].shape)
plot_embedding_vector(x["embedding"], figsize=(10, 0.2))
common_voice_tr_18237877-000001 (800,)
common_voice_tr_18237877-000002 (800,)
common_voice_tr_18237877-000003 (800,)
common_voice_tr_19104341-000001 (800,)
common_voice_tr_19104341-000002 (800,)
common_voice_tr_19104341-000003 (800,)
common_voice_mn_18584433-000001 (800,)
common_voice_mn_18584433-000002 (800,)
common_voice_mn_18584433-000003 (800,)
common_voice_ta_20395073-000001 (800,)
common_voice_ta_20395073-000002 (800,)
common_voice_ta_20395073-000003 (800,)
We'll now extend the existing feature extraction pipeline by adding a step that extracts language vectors with the trained model. In addition, we merge the chunks of each sample by summing its chunk vectors component-wise; the resulting vector is then L2-normalized.
from sklearn.preprocessing import normalize
from lidbox.util import predictions_to_dataframe
# Merge chunk vectors by taking the sum over each component and L2-normalizing the result
def sum_and_normalize(pred):
v = np.stack(pred).sum(axis=0)
v = normalize(v.reshape((1, -1)), axis=1)
return np.squeeze(v)
def ds_to_embeddings(ds):
to_pair = lambda x: (x["id"], x["embedding"])
ds = (ds
.batch(128)
.map(batch_extract_embeddings, num_parallel_calls=TF_AUTOTUNE)
.unbatch()
.map(to_pair, num_parallel_calls=TF_AUTOTUNE))
ids = []
embeddings = []
for id, embedding in ds.as_numpy_iterator():
ids.append(id.decode("utf-8"))
embeddings.append(embedding.astype(np.float32))
df = predictions_to_dataframe(ids, embeddings)
return merge_chunk_predictions(df, merge_rows_fn=sum_and_normalize)
embeddings_by_split = (ds_to_embeddings(ds) for ds in split2ds.values())
m = meta.join(pd.concat(embeddings_by_split, verify_integrity=True), how="outer")
assert not m.prediction.isna().any(axis=None), "Missing embeddings, some rows contained NaN values"
meta = m.rename(columns={"prediction": "embedding"})
Now, let's extract all embeddings and integer targets into NumPy arrays and preprocess them with scikit-learn.
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from lidbox.embed.sklearn_utils import PLDA
def embeddings_as_numpy_data(df):
X = np.stack(df.embedding.values).astype(np.float32)
y = df.target.to_numpy(dtype=np.int32)
return X, y
def random_sample(X, y, sample_size_ratio):
N = X.shape[0]
sample_size = int(sample_size_ratio*N)
sample_idx = np_rng.choice(np.arange(N), size=sample_size, replace=False)
return X[sample_idx], y[sample_idx]
def pca_3d_scatterplot_by_label(data, targets, split_name):
target2lang = {t: l for l, t in lang2target.items()}
df = pd.DataFrame.from_dict({
"x": data[:,0],
"y": data[:,1],
"z": data[:,2],
"lang": [target2lang[t] for t in targets],
})
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
for lang, g in df.groupby("lang"):
ax.scatter(g.x, g.y, g.z, label=lang)
ax.legend()
ax.set_title("3D PCA scatter plot of {} set language vectors".format(split_name))
plt.show()
train_X, train_y = embeddings_as_numpy_data(meta[meta["split"]=="train"])
print("training vectors", train_X.shape, train_y.shape)
test_X, test_y = embeddings_as_numpy_data(meta[meta["split"]=="test"])
print("test vectors", test_X.shape, test_y.shape)
# Standardize all vectors using training set statistics
scaler = StandardScaler()
scaler.fit(train_X)
train_X = scaler.transform(train_X)
test_X = scaler.transform(test_X)
# Reduce dimensions
pre_shape = train_X.shape
plda = PLDA()
plda.fit(train_X, train_y)
train_X = plda.transform(train_X)
test_X = plda.transform(test_X)
print("PLDA reduced dimensions from {} to {}".format(pre_shape, train_X.shape))
# L2-normalize vectors to surface of a unit sphere
train_X = normalize(train_X)
test_X = normalize(test_X)
# Map vectors to 3D with PCA, select 10% samples, plot vectors
pca = PCA(n_components=3, whiten=False)
pca.fit(train_X)
X, y = random_sample(pca.transform(train_X), train_y, 0.1)
pca_3d_scatterplot_by_label(X, y, "training")
X, y = random_sample(pca.transform(test_X), test_y, 0.1)
pca_3d_scatterplot_by_label(X, y, "test")
training vectors (16728, 800) (16728,)
test vectors (7569, 800) (7569,)
PLDA reduced dimensions from (16728, 800) to (16728, 3)
from sklearn.naive_bayes import GaussianNB
from lidbox.util import classification_report
# Fit classifier
clf = GaussianNB()
clf.fit(train_X, train_y)
# Predict scores on test set with classifier and compute metrics
test_pred = clf.predict_log_proba(test_X)
# Clamp -infs to -100
test_pred = np.maximum(-100, test_pred)
report = classification_report(test_y, test_pred, lang2target)
display_classification_report(report)
avg_detection_cost: 0.171
avg_equal_error_rate: 0.157
accuracy: 0.733
| | precision | recall | f1-score | support | equal_error_rate |
|---|---|---|---|---|---|
| et | 0.851627 | 0.790576 | 0.819967 | 2483.00 | 0.124263 |
| mn | 0.851100 | 0.555801 | 0.672460 | 1810.00 | 0.192568 |
| ta | 0.836235 | 0.791819 | 0.813421 | 1638.00 | 0.098297 |
| tr | 0.507309 | 0.783883 | 0.615975 | 1638.00 | 0.212949 |
| mean | 0.761568 | 0.730520 | 0.730456 | 1892.25 | 0.157019 |
Compared to the results from our previous examples, we were unable to get better results by training an RNN-based model with the angular proximity loss function. However, the PCA scatter plots suggest that language vectors of the same class are much closer to each other than the ones we extracted from the x-vector model.
In any case, we might need much larger datasets before we can reliably compare the x-vector model and the LSTM model we used here.