Open
Show file tree
Hide file tree
Changes from all commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Failed to load files.
Original file line numberDiff line numberDiff line change
Expand Up@@ -31,8 +31,8 @@
K = tf.keras.backend
l = tf.keras.layers

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object
serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object


class _TestHelper(object):
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -31,8 +31,8 @@
K = tf.keras.backend
l = tf.keras.layers

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object
serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object


class _TestHelper(object):
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -525,7 +525,7 @@ def _wrap_fixed_range(
'init_min': init_min,
'init_max': init_max,
'narrow_range': narrow_range})
return tf.keras.utils.serialize_keras_object(config)
return tf.keras.utils.legacy.serialize_keras_object(config)


def _is_serialized_node_data(nested):
Expand DownExpand Up@@ -601,8 +601,9 @@ def fix_input_output_range(
init_min=input_min,
init_max=input_max,
narrow_range=narrow_range)
serialized_fixed_input_quantizer = tf.keras.utils.serialize_keras_object(
fixed_input_quantizer)
serialized_fixed_input_quantizer = (
tf.keras.utils.legacy.serialize_keras_object(fixed_input_quantizer)
)

if _is_functional_model(model):
input_layer_list = _nested_to_flatten_node_data_list(config['input_layers'])
Expand DownExpand Up@@ -685,8 +686,9 @@ def remove_input_range(model):
"""
config = model.get_config()
no_input_quantizer = quantizers.NoQuantizer()
serialized_input_quantizer = tf.keras.utils.serialize_keras_object(
no_input_quantizer)
serialized_input_quantizer = tf.keras.utils.legacy.serialize_keras_object(
no_input_quantizer
)

if _is_functional_model(model):
input_layer_list = _nested_to_flatten_node_data_list(config['input_layers'])
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -23,8 +23,8 @@

import tensorflow as tf

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object
serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object


class QuantizeAnnotate(tf.keras.layers.Wrapper):
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -29,8 +29,8 @@
keras = tf.keras
activations = tf.keras.activations
K = tf.keras.backend
deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object
serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object

QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation
MovingAverageQuantizer = quantizers.MovingAverageQuantizer
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -259,13 +259,16 @@ def get_output_quantizers(self, layer):

def get_config(self):
return {
'config': tf.keras.utils.serialize_keras_object(self.config),
'config': tf.keras.utils.legacy.serialize_keras_object(self.config),
'num_bits': self.num_bits,
'init_min': self.init_min,
'init_max': self.init_max,
'narrow_range': self.narrow_range}
'narrow_range': self.narrow_range,
}

@classmethod
def from_config(cls, config):
config['config'] = tf.keras.utils.deserialize_keras_object(config['config'])
config['config'] = tf.keras.utils.legacy.deserialize_keras_object(
config['config']
)
return cls(**config)
Original file line numberDiff line numberDiff line change
Expand Up@@ -27,8 +27,8 @@

from tensorflow_model_optimization.python.core.quantization.keras import quantizers

serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object
deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object


class QuantizeLayer(tf.keras.layers.Layer):
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -34,8 +34,8 @@
from tensorflow_model_optimization.python.core.keras import utils
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object
serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object


class QuantizeWrapper(tf.keras.layers.Wrapper):
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -26,8 +26,8 @@
from tensorflow_model_optimization.python.core.keras import compat
from tensorflow_model_optimization.python.core.quantization.keras import quantizers

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object
serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object


@parameterized.parameters(
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -242,12 +242,13 @@ def testSerializeDeserialize(self):
sparsity = pruning_schedule.ConstantSparsity(0.7, 10, 20, 10)

config = sparsity.get_config()
sparsity_deserialized = tf.keras.utils.deserialize_keras_object(
sparsity_deserialized = tf.keras.utils.legacy.deserialize_keras_object(
config,
custom_objects={
'ConstantSparsity': pruning_schedule.ConstantSparsity,
'PolynomialDecay': pruning_schedule.PolynomialDecay
})
'PolynomialDecay': pruning_schedule.PolynomialDecay,
},
)

self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__)

Expand DownExpand Up@@ -278,12 +279,13 @@ def testSerializeDeserialize(self):
sparsity = pruning_schedule.PolynomialDecay(0.2, 0.6, 10, 20, 5, 10)

config = sparsity.get_config()
sparsity_deserialized = tf.keras.utils.deserialize_keras_object(
sparsity_deserialized = tf.keras.utils.legacy.deserialize_keras_object(
config,
custom_objects={
'ConstantSparsity': pruning_schedule.ConstantSparsity,
'PolynomialDecay': pruning_schedule.PolynomialDecay
})
'PolynomialDecay': pruning_schedule.PolynomialDecay,
},
)

self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__)

Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -318,7 +318,7 @@ def from_config(cls, config):
config = config.copy()

pruning_schedule = config.pop('pruning_schedule')
deserialize_keras_object = keras.utils.deserialize_keras_object # pylint: disable=g-import-not-at-top
deserialize_keras_object = keras.utils.legacy.deserialize_keras_object # pylint: disable=g-import-not-at-top
# TODO(pulkitb): This should ideally be fetched from pruning_schedule,
# which should maintain a list of all the pruning_schedules.
custom_objects = {
Expand Down