File tree

4 files changed

+258
-1
lines changed

4 files changed

+258
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from google.protobuf import json_format
18+
from typing import Any, Dict, List, Optional
19+
20+
from google.cloud.aiplatform.compat.types import (
21+
explanation_metadata_v1beta1 as explanation_metadata,
22+
)
23+
from google.cloud.aiplatform.explain.metadata import metadata_builder
24+
25+
try:
26+
import tensorflow.compat.v1 as tf
27+
except ImportError:
28+
raise ImportError(
29+
"Tensorflow is not installed and is required to load saved model. "
30+
'Please install the SDK using "pip install google-cloud-aiplatform[full]"'
31+
)
32+
33+
34+
class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder):
35+
"""Metadata builder class that accepts a TF1 saved model."""
36+
37+
def __init__(
38+
self,
39+
model_path: str,
40+
tags: Optional[List[str]] = None,
41+
signature_name: Optional[str] = None,
42+
outputs_to_explain: Optional[List[str]] = None,
43+
) -> None:
44+
"""Initializes a SavedModelMetadataBuilder object.
45+
46+
Args:
47+
model_path:
48+
Required. Path to load the saved model from.
49+
tags:
50+
Optional. Tags to identify the model graph. If None or empty, TensorFlow's default serving tag will be used.
51+
signature_name:
52+
Optional. Name of the signature to be explained. Inputs and
53+
outputs of this signature will be written in the metadata. If not
54+
provided, the default signature will be used.
55+
outputs_to_explain:
56+
Optional. List of output names to explain. Only single output is
57+
supported for now. Hence, the list should contain one element.
58+
This parameter is required if the model signature (provided via
59+
signature_name) specifies multiple outputs.
60+
61+
Raises:
62+
ValueError if outputs_to_explain contains more than 1 element or signature contains multiple outputs.
63+
"""
64+
if outputs_to_explain:
65+
if len(outputs_to_explain) > 1:
66+
raise ValueError(
67+
"Only one output is supported at the moment. "
68+
f"Received: {outputs_to_explain}."
69+
)
70+
self._output_to_explain = next(iter(outputs_to_explain))
71+
72+
if not signature_name:
73+
signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
74+
self._tags = tags or [tf.saved_model.tag_constants.SERVING]
75+
self._graph = tf.Graph()
76+
77+
with self.graph.as_default():
78+
self._session = tf.Session(graph=self.graph)
79+
self._metagraph_def = tf.saved_model.loader.load(
80+
sess=self.session, tags=self._tags, export_dir=model_path
81+
)
82+
if signature_name not in self._metagraph_def.signature_def:
83+
raise ValueError(
84+
f"Serving sigdef key {signature_name} not in " "the signature def."
85+
)
86+
serving_sigdef = self._metagraph_def.signature_def[signature_name]
87+
if not outputs_to_explain:
88+
if len(serving_sigdef.outputs) > 1:
89+
raise ValueError(
90+
"The signature contains multiple outputs. Specify "
91+
'an output via "outputs_to_explain" parameter.'
92+
)
93+
self._output_to_explain = next(iter(serving_sigdef.outputs.keys()))
94+
95+
self._inputs = _create_input_metadata_from_signature(serving_sigdef.inputs)
96+
self._outputs = _create_output_metadata_from_signature(
97+
serving_sigdef.outputs, self._output_to_explain
98+
)
99+
100+
@property
101+
def graph(self) -> tf.Graph:
102+
return self._graph
103+
104+
@property
105+
def session(self) -> tf.Session:
106+
return self._session
107+
108+
def get_metadata(self) -> Dict[str, Any]:
109+
"""Returns the current metadata as a dictionary.
110+
111+
Returns:
112+
Json format of the explanation metadata.
113+
"""
114+
current_md = explanation_metadata.ExplanationMetadata(
115+
inputs=self._inputs, outputs=self._outputs,
116+
)
117+
return json_format.MessageToDict(current_md._pb)
118+
119+
120+
def _create_input_metadata_from_signature(
121+
signature_inputs: Dict[str, tf.Tensor]
122+
) -> Dict[str, explanation_metadata.ExplanationMetadata.InputMetadata]:
123+
"""Creates InputMetadata from signature inputs.
124+
125+
Args:
126+
signature_inputs:
127+
Required. Inputs of the signature to be explained. If not provided, the default signature will be used.
128+
129+
Returns:
130+
Inferred input metadata from the model.
131+
"""
132+
input_mds = {}
133+
for key, tensor in signature_inputs.items():
134+
input_mds[key] = explanation_metadata.ExplanationMetadata.InputMetadata(
135+
input_tensor_name=tensor.name
136+
)
137+
return input_mds
138+
139+
140+
def _create_output_metadata_from_signature(
141+
signature_outputs: Dict[str, tf.Tensor], output_to_explain: Optional[str] = None,
142+
) -> Dict[str, explanation_metadata.ExplanationMetadata.OutputMetadata]:
143+
"""Creates OutputMetadata from signature inputs.
144+
145+
Args:
146+
signature_outputs:
147+
Required. Inputs of the signature to be explained. If not provided, the default signature will be used.
148+
output_to_explain:
149+
Optional. Output name to explain.
150+
151+
Returns:
152+
Inferred output metadata from the model.
153+
"""
154+
output_mds = {}
155+
for key, tensor in signature_outputs.items():
156+
if not output_to_explain or output_to_explain == key:
157+
output_mds[key] = explanation_metadata.ExplanationMetadata.OutputMetadata(
158+
output_tensor_name=tensor.name
159+
)
160+
return output_mds
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2020 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import tensorflow.compat.v1 as tf
19+
20+
from google.cloud.aiplatform.explain.metadata.tf.v1 import saved_model_metadata_builder
21+
22+
23+
class SavedModelMetadataBuilderTF1Test(tf.test.TestCase):
24+
def _set_up(self):
25+
self.sess = tf.Session(graph=tf.Graph())
26+
with self.sess.graph.as_default():
27+
self.x = tf.placeholder(shape=[None, 10], dtype=tf.float32, name="inp")
28+
weights = tf.constant(1.0, shape=(10, 2), name="weights")
29+
bias_weight = tf.constant(1.0, shape=(2,), name="bias")
30+
self.linear_layer = tf.add(tf.matmul(self.x, weights), bias_weight)
31+
self.prediction = tf.nn.relu(self.linear_layer)
32+
# save the model
33+
self.model_path = self.get_temp_dir()
34+
builder = tf.saved_model.builder.SavedModelBuilder(self.model_path)
35+
tensor_info_x = tf.saved_model.utils.build_tensor_info(self.x)
36+
tensor_info_pred = tf.saved_model.utils.build_tensor_info(self.prediction)
37+
tensor_info_lin = tf.saved_model.utils.build_tensor_info(self.linear_layer)
38+
prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
39+
inputs={"x": tensor_info_x},
40+
outputs={"y": tensor_info_pred},
41+
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
42+
)
43+
double_output_signature = tf.saved_model.signature_def_utils.build_signature_def(
44+
inputs={"x": tensor_info_x},
45+
outputs={"y": tensor_info_pred, "lin": tensor_info_lin},
46+
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
47+
)
48+
49+
builder.add_meta_graph_and_variables(
50+
self.sess,
51+
[tf.saved_model.tag_constants.SERVING],
52+
signature_def_map={
53+
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature,
54+
"double": double_output_signature,
55+
},
56+
)
57+
builder.save()
58+
59+
def test_get_metadata_correct_inputs(self):
60+
self._set_up()
61+
md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
62+
self.model_path, tags=[tf.saved_model.tag_constants.SERVING]
63+
)
64+
expected_md = {
65+
"inputs": {"x": {"inputTensorName": "inp:0"}},
66+
"outputs": {"y": {"outputTensorName": "Relu:0"}},
67+
}
68+
69+
assert md_builder.get_metadata() == expected_md
70+
71+
def test_get_metadata_double_output(self):
72+
self._set_up()
73+
md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
74+
self.model_path, signature_name="double", outputs_to_explain=["lin"]
75+
)
76+
77+
expected_md = {
78+
"inputs": {"x": {"inputTensorName": "inp:0"}},
79+
"outputs": {"lin": {"outputTensorName": "Add:0"}},
80+
}
81+
82+
assert md_builder.get_metadata() == expected_md
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder
2323

2424

25-
class SavedModelMetadataBuilderTest(tf.test.TestCase):
25+
class SavedModelMetadataBuilderTF2Test(tf.test.TestCase):
2626
def test_get_metadata_sequential(self):
2727
# Set up for the sequential.
2828
self.seq_model = tf.keras.models.Sequential()

0 commit comments

Comments
 (0)