Previous commit
Next commit
refactor: Addressing PR review comments.
  • Loading branch information
@taiseiak
taiseiak committedJan 4, 2022
commit cb3d243d8431d7b7ccf28370078bb92ad3b6dffa
Original file line numberDiff line numberDiff line change
Expand Up@@ -14,14 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple, Union
import os
from typing import Dict, List, Optional, Tuple, Union

try:
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import dtypes as lit_dtypes
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.api import dtypes as lit_dtypes
from lit_nlp import notebook
except ImportError:
raise ImportError(
Expand DownExpand Up@@ -113,7 +113,7 @@ def __init__(
@property
def attribution_explainer(
self,
) -> Union["AttributionExplainer", None]: # noqa: F821
) -> Optional["AttributionExplainer"]: # noqa: F821
"""Gets the attribution explainer property if set."""
return self._attribution_explainer

Expand DownExpand Up@@ -164,6 +164,12 @@ def output_spec(self) -> lit_types.Spec:
return output_spec_dict

def _load_model(self, model: str):
"""Loads a TensorFlow saved model and populates the input and output signature attributes of the class.
Args:
model: Required. A string reference to a TensorFlow saved model directory.
Raises:
ValueError if the model has more than one input tensor or more than one output tensor.
"""
self._loaded_model = tf.saved_model.load(model)
serving_default = self._loaded_model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
Expand All@@ -180,10 +186,15 @@ def _load_model(self, model: str):
def _set_up_attribution_explainer(
self, model: str, attribution_method: str = "integrated_gradients"
):
"""Populates the attribution explainer attribute of the class."""
try:
import explainable_ai_sdk
from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder
except ImportError:
print(
"Skipping explanations because the Explainable AI SDK is not installed."
'Please install the SDK using "pip install explainable-ai-sdk"'
)
return

builder = SavedModelMetadataBuilder(model)
Expand Down