Commit f2408d41 authored by David Maxence's avatar David Maxence
Browse files

ajout data augmentation

parent ea890dc4
%% Cell type:code id:personal-marshall tags:
``` python
import numpy as np
import os
import pandas as pd
from scipy.io import wavfile
import librosa
from tqdm import tqdm
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import regularizers, activations
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Conv2D, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from datetime import datetime
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import cv2
```
%% Cell type:code id:intimate-property tags:
``` python
us8k_df = pd.read_pickle("us8k_df.pkl")
```
%% Cell type:code id:closed-abraham tags:
``` python
df = us8k_df.drop(['fold'],axis=1)
X = np.stack(df.melspectrogram.to_numpy())
X_dim = (128,128,1)
X = X.reshape(X.shape[0], *X_dim)
Y = np.array(df['label'])
Y = to_categorical(Y)
```
%% Cell type:code id:controlling-overhead tags:
``` python
Y.shape, X.shape
```
%%%% Output: execute_result
((8732, 10), (8732, 128, 128, 1))
%% Cell type:code id:dying-cooking tags:
``` python
X_new = np.zeros((8732,128,128,3))
```
%% Cell type:code id:increasing-annotation tags:
``` python
X_new.shape
```
%%%% Output: execute_result
(8732, 128, 128, 3)
%% Cell type:code id:rubber-hanging tags:
``` python
for i in range(len(X)):
X_new[i]=cv2.cvtColor(X[i], cv2.COLOR_GRAY2RGB)
```
%% Cell type:code id:after-driving tags:
``` python
X=X_new
```
%% Cell type:code id:illegal-partner tags:
``` python
X_train, X_test, Y_train, Y_test = train_test_split(X,Y,test_size=0.3,shuffle=True,stratify = Y)
X_val, X_test, Y_val,Y_test = train_test_split(X_test,Y_test,test_size=0.5,shuffle=True,stratify = Y_test)
```
%% Cell type:markdown id:french-assembly tags:
MOdèle utilisant mobile net
%% Cell type:code id:geographic-diving tags:
``` python
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)
```
%% Cell type:code id:pacific-registration tags:
``` python
IMG_SHAPE = (128,128,3)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')
```
%% Cell type:code id:included-mozambique tags:
``` python
prediction_layer = tf.keras.layers.Dense(10)
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
```
%% Cell type:code id:palestinian-crowd tags:
``` python
# model : MobileNet puis du dropout et une couche dense pour la prédiction
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = preprocess_input(inputs)
x = rescale(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
```
%% Cell type:code id:simple-rebecca tags:
``` python
model.summary()
```
%% Cell type:code id:pacific-correction tags:
``` python
base_learning_rate = 0.001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
```
%% Cell type:code id:durable-straight tags:
``` python
initial_epochs = 10
num_batch_size = 32
loss0, accuracy0 = model.evaluate(X_val,Y_val)
```
%% Cell type:code id:victorian-hawaiian tags:
``` python
log_dir = "logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir, histogram_freq = 1)
save_best = tf.keras.callbacks.ModelCheckpoint(filepath = "logs/checkpoints/", save_weights_only = True,
monitor = "val_accuracy", mode = "max", save_best_only = True)
```
%% Cell type:code id:adolescent-prerequisite tags:
``` python
model_fit = model.fit(X_train[:2000],Y_train[0:2000], epochs=initial_epochs,validation_data=(X_val[0:300],Y_val[0:300]),batch_size=num_batch_size,callbacks = [tensorboard_callback, save_best])
```
%% Cell type:code id:civic-trinity tags:
``` python
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment