File tree

5 files changed

+325
-6
lines changed

5 files changed

+325
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@
2525
class MetadataBuilder(_ABC):
2626
"""Abstract base class for metadata builders."""
2727

28-
@abc.abstractmethod
29-
def save_model_with_metadata(self, filepath: str):
30-
"""Saves the model with metadata."""
31-
3228
@abc.abstractmethod
3329
def get_metadata(self):
3430
"""Returns the current metadata as a dictionary."""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from google.protobuf import json_format
18+
from typing import Optional, List, Dict, Any, Tuple
19+
20+
from google.cloud.aiplatform.explain.metadata import metadata_builder
21+
from google.cloud.aiplatform.compat.types import (
22+
explanation_metadata_v1beta1 as explanation_metadata,
23+
)
24+
25+
26+
class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder):
27+
"""Class for generating metadata for a model built with TF 2.X Keras API."""
28+
29+
def __init__(
30+
self,
31+
model_path: str,
32+
signature_name: Optional[str] = None,
33+
outputs_to_explain: Optional[List[str]] = None,
34+
**kwargs
35+
) -> None:
36+
"""Initializes a SavedModelMetadataBuilder object.
37+
38+
Args:
39+
model_path:
40+
Required. Path to load the saved model from.
41+
signature_name:
42+
Optional. Name of the signature to be explained. Inputs and
43+
outputs of this signature will be written in the metadata. If not
44+
provided, the default signature will be used.
45+
outputs_to_explain:
46+
Optional. List of output names to explain. Only single output is
47+
supported for now. Hence, the list should contain one element.
48+
This parameter is required if the model signature (provided via
49+
signature_name) specifies multiple outputs.
50+
**kwargs:
51+
Any keyword arguments to be passed to tf.saved_model.save() function.
52+
53+
Raises:
54+
ValueError if outputs_to_explain contains more than 1 element.
55+
ImportError if tf is not imported.
56+
"""
57+
if outputs_to_explain and len(outputs_to_explain) > 1:
58+
raise ValueError(
59+
'"outputs_to_explain" can only contain 1 element.\n'
60+
"Got: %s" % len(outputs_to_explain)
61+
)
62+
self._explain_output = outputs_to_explain
63+
self._saved_model_args = kwargs
64+
65+
try:
66+
import tensorflow as tf
67+
except ImportError:
68+
raise ImportError(
69+
"Tensorflow is not installed and is required to load saved model. "
70+
'Please install the SDK using "pip install google-cloud-aiplatform[full]"'
71+
)
72+
73+
if not signature_name:
74+
signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
75+
self._loaded_model = tf.saved_model.load(model_path)
76+
self._inputs, self._outputs = self._infer_metadata_entries_from_model(
77+
signature_name
78+
)
79+
80+
def _infer_metadata_entries_from_model(
81+
self, signature_name: str
82+
) -> Tuple[
83+
Dict[str, explanation_metadata.ExplanationMetadata.InputMetadata],
84+
Dict[str, explanation_metadata.ExplanationMetadata.OutputMetadata],
85+
]:
86+
"""Infers metadata inputs and outputs.
87+
88+
Args:
89+
signature_name:
90+
Required. Name of the signature to be explained. Inputs and outputs of this signature will be written in the metadata. If not provided, the default signature will be used.
91+
92+
Returns:
93+
Inferred input metadata and output metadata from the model.
94+
95+
Raises:
96+
ValueError if specified name is not found in signature outputs.
97+
"""
98+
99+
loaded_sig = self._loaded_model.signatures[signature_name]
100+
_, input_sig = loaded_sig.structured_input_signature
101+
output_sig = loaded_sig.structured_outputs
102+
input_mds = {}
103+
for name, tensor_spec in input_sig.items():
104+
input_mds[name] = explanation_metadata.ExplanationMetadata.InputMetadata(
105+
input_tensor_name=name,
106+
modality=None if tensor_spec.dtype.is_floating else "categorical",
107+
)
108+
109+
output_mds = {}
110+
for name in output_sig:
111+
if not self._explain_output or self._explain_output[0] == name:
112+
output_mds[
113+
name
114+
] = explanation_metadata.ExplanationMetadata.OutputMetadata(
115+
output_tensor_name=name,
116+
)
117+
break
118+
else:
119+
raise ValueError(
120+
"Specified output name cannot be found in given signature outputs."
121+
)
122+
return input_mds, output_mds
123+
124+
def get_metadata(self) -> Dict[str, Any]:
125+
"""Returns the current metadata as a dictionary.
126+
127+
Returns:
128+
Json format of the explanation metadata.
129+
"""
130+
current_md = explanation_metadata.ExplanationMetadata(
131+
inputs=self._inputs, outputs=self._outputs,
132+
)
133+
return json_format.MessageToDict(current_md._pb)
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,16 @@
2929
with io.open(readme_filename, encoding="utf-8") as readme_file:
3030
readme = readme_file.read()
3131

32-
tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
32+
tensorboard_extra_require = [
33+
"tensorflow >=2.3.0, <=2.5.0",
34+
"grpcio~=1.34.0",
35+
"six~=1.15.0",
36+
]
3337
metadata_extra_require = ["pandas >= 1.0.0"]
34-
full_extra_require = tensorboard_extra_require + metadata_extra_require
38+
xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
39+
full_extra_require = list(
40+
set(tensorboard_extra_require + metadata_extra_require + xai_extra_require)
41+
)
3542
testing_extra_require = full_extra_require + ["grpcio-testing"]
3643

3744

@@ -69,6 +76,7 @@
6976
"metadata": metadata_extra_require,
7077
"tensorboard": tensorboard_extra_require,
7178
"testing": testing_extra_require,
79+
"xai": xai_extra_require,
7280
},
7381
python_requires=">=3.6",
7482
scripts=[],
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2020 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
19+
import tensorflow as tf
20+
import numpy as np
21+
22+
from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder
23+
24+
25+
class SavedModelMetadataBuilderTest(tf.test.TestCase):
26+
def test_get_metadata_sequential(self):
27+
# Set up for the sequential.
28+
self.seq_model = tf.keras.models.Sequential()
29+
self.seq_model.add(tf.keras.layers.Dense(32, activation="relu", input_dim=10))
30+
self.seq_model.add(tf.keras.layers.Dense(32, activation="relu"))
31+
self.seq_model.add(tf.keras.layers.Dense(1, activation="sigmoid"))
32+
self.saved_model_path = self.get_temp_dir()
33+
tf.saved_model.save(self.seq_model, self.saved_model_path)
34+
35+
builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
36+
self.saved_model_path
37+
)
38+
generated_md = builder.get_metadata()
39+
expected_md = {
40+
"outputs": {"dense_2": {"outputTensorName": "dense_2"}},
41+
"inputs": {"dense_input": {"inputTensorName": "dense_input"}},
42+
}
43+
assert expected_md == generated_md
44+
45+
def test_get_metadata_functional(self):
46+
inputs1 = tf.keras.Input(shape=(10,), name="model_input1")
47+
inputs2 = tf.keras.Input(shape=(10,), name="model_input2")
48+
x = tf.keras.layers.Dense(32, activation="relu")(inputs1)
49+
x = tf.keras.layers.Dense(32, activation="relu")(x)
50+
x = tf.keras.layers.concatenate([x, inputs2])
51+
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
52+
fun_model = tf.keras.Model(
53+
inputs=[inputs1, inputs2], outputs=outputs, name="fun"
54+
)
55+
model_dir = self.get_temp_dir()
56+
tf.saved_model.save(fun_model, model_dir)
57+
builder = saved_model_metadata_builder.SavedModelMetadataBuilder(model_dir)
58+
generated_md = builder.get_metadata()
59+
expected_md = {
60+
"inputs": {
61+
"model_input1": {"inputTensorName": "model_input1"},
62+
"model_input2": {"inputTensorName": "model_input2"},
63+
},
64+
"outputs": {"dense_2": {"outputTensorName": "dense_2"}},
65+
}
66+
assert expected_md == generated_md
67+
68+
def test_get_metadata_subclassed_model(self):
69+
class MyModel(tf.keras.Model):
70+
def __init__(self, num_classes=2):
71+
super(MyModel, self).__init__(name="my_model")
72+
self.num_classes = num_classes
73+
self.dense_1 = tf.keras.layers.Dense(32, activation="relu")
74+
self.dense_2 = tf.keras.layers.Dense(num_classes, activation="sigmoid")
75+
76+
def call(self, inputs):
77+
x = self.dense_1(inputs)
78+
return self.dense_2(x)
79+
80+
subclassed_model = MyModel()
81+
subclassed_model.compile(loss="categorical_crossentropy")
82+
np.random.seed(0)
83+
x_train = np.random.random((1, 100))
84+
y_train = np.random.randint(2, size=(1, 2))
85+
subclassed_model.fit(x_train, y_train, batch_size=1, epochs=1)
86+
model_dir = self.get_temp_dir()
87+
tf.saved_model.save(subclassed_model, model_dir)
88+
89+
builder = saved_model_metadata_builder.SavedModelMetadataBuilder(model_dir)
90+
generated_md = builder.get_metadata()
91+
expected_md = {
92+
"inputs": {"input_1": {"inputTensorName": "input_1"}},
93+
"outputs": {"output_1": {"outputTensorName": "output_1"}},
94+
}
95+
assert expected_md == generated_md
96+
97+
def test_non_keras_model(self):
98+
class CustomModuleWithOutputName(tf.Module):
99+
def __init__(self):
100+
super(CustomModuleWithOutputName, self).__init__()
101+
self.v = tf.Variable(1.0)
102+
103+
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
104+
def __call__(self, x):
105+
return {"custom_output_name": x * self.v}
106+
107+
module_output = CustomModuleWithOutputName()
108+
call_output = module_output.__call__.get_concrete_function(
109+
tf.TensorSpec(None, tf.float32)
110+
)
111+
model_dir = self.get_temp_dir()
112+
tf.saved_model.save(
113+
module_output, model_dir, signatures={"serving_default": call_output}
114+
)
115+
116+
builder = saved_model_metadata_builder.SavedModelMetadataBuilder(model_dir)
117+
generated_md = builder.get_metadata()
118+
expected_md = {
119+
"inputs": {"x": {"inputTensorName": "x"}},
120+
"outputs": {
121+
"custom_output_name": {"outputTensorName": "custom_output_name"}
122+
},
123+
}
124+
assert expected_md == generated_md
125+
126+
def test_model_with_feature_column(self):
127+
feature_columns = [
128+
tf.feature_column.embedding_column(
129+
tf.feature_column.categorical_column_with_vocabulary_list(
130+
"mode", ["fixed", "normal", "reversible"]
131+
),
132+
dimension=8,
133+
),
134+
tf.feature_column.numeric_column("age"),
135+
]
136+
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
137+
138+
model = tf.keras.Sequential(
139+
[
140+
feature_layer,
141+
tf.keras.layers.Dense(128, activation="relu"),
142+
tf.keras.layers.Dense(1),
143+
]
144+
)
145+
146+
model.compile(
147+
optimizer="adam",
148+
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
149+
metrics=["accuracy"],
150+
)
151+
152+
model.fit(
153+
{"age": np.array([20, 1]), "mode": np.array(["fixed", "normal"])},
154+
np.array([0, 1]),
155+
)
156+
model_dir = self.get_temp_dir()
157+
tf.saved_model.save(model, model_dir)
158+
builder = saved_model_metadata_builder.SavedModelMetadataBuilder(model_dir)
159+
generated_md = builder.get_metadata()
160+
expected_md = {
161+
"inputs": {
162+
"age": {"inputTensorName": "age", "modality": "categorical"},
163+
"mode": {"inputTensorName": "mode", "modality": "categorical"},
164+
},
165+
"outputs": {"output_1": {"outputTensorName": "output_1"}},
166+
}
167+
assert expected_md == generated_md

0 commit comments

Comments
 (0)