Commit 67ac6107 authored by Bannier Delphine's avatar Bannier Delphine
Browse files

add dropout prediction file

parent 33fd0134
from keras.models import Model, Sequential
from keras import backend as K
def create_dropout_predict_function(model, dropout):
"""
Create a keras function to predict with dropout
model : keras model
dropout : fraction dropout to apply to all layers
Returns
predict_with_dropout : keras function for predicting with dropout
"""
# Load the config of the original model
conf = model.get_config()
# Add the specified dropout to all layers
for layer in conf['layers']:
# Dropout layers
if layer["class_name"]=="Dropout":
layer["config"]["rate"] = dropout
# Recurrent layers with dropout
elif "dropout" in layer["config"].keys():
layer["config"]["dropout"] = dropout
# Create a new model with specified dropout
if type(model)==Sequential:
# Sequential
model_dropout = Sequential.from_config(conf)
else:
# Functional
model_dropout = Model.from_config(conf)
model_dropout.set_weights(model.get_weights())
# Create a function to predict with the dropout on
predict_with_dropout = K.function([model_dropout.inputs,K.learning_phase()], model_dropout.outputs)
print(model_dropout.inputs)
return predict_with_dropout
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment