File tree

6 files changed

+72
-40
lines changed

6 files changed

+72
-40
lines changed
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18-
try:
19-
import google.cloud.aiplatform.training_utils.cloud_profiler.initializer as initializer
20-
except ImportError as err:
21-
raise ImportError(
22-
"Could not load the cloud profiler. To use the profiler, "
23-
'install the SDK using "pip install google-cloud-aiplatform[cloud-profiler]"'
24-
) from err
18+
from google.cloud.aiplatform.training_utils.cloud_profiler import initializer
2519

2620
"""
2721
Initialize the cloud profiler for tensorflow.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import_error_msg = (
19+
"Could not load the cloud profiler. To use the profiler, "
20+
"install the SDK using 'pip install google-cloud-aiplatform[cloud-profiler]'"
21+
)
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818
import logging
1919
import threading
2020
from typing import Optional, Type
21-
from werkzeug import serving
21+
22+
from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils
23+
24+
try:
25+
from werkzeug import serving
26+
except ImportError as err:
27+
raise ImportError(cloud_profiler_utils.import_error_msg) from err
28+
2229

2330
from google.cloud.aiplatform.training_utils import environment_variables
2431
from google.cloud.aiplatform.training_utils.cloud_profiler import webserver
@@ -27,6 +34,7 @@
2734
tf_profiler,
2835
)
2936

37+
3038
# Mapping of available plugins to use
3139
_AVAILABLE_PLUGINS = {"tensorflow": tf_profiler.TFProfiler}
3240

Original file line numberDiff line numberDiff line change
@@ -17,14 +17,23 @@
1717

1818
"""A plugin to handle remote tensoflow profiler sessions for Vertex AI."""
1919

20+
from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils
21+
22+
try:
23+
import tensorflow as tf
24+
from tensorboard_plugin_profile.profile_plugin import ProfilePlugin
25+
except ImportError as err:
26+
raise ImportError(cloud_profiler_utils.import_error_msg) from err
27+
2028
import argparse
2129
from collections import namedtuple
2230
import importlib.util
2331
import json
2432
import logging
25-
import tensorboard.plugins.base_plugin as tensorboard_base_plugin
2633
from typing import Callable, Dict, Optional
2734
from urllib import parse
35+
36+
import tensorboard.plugins.base_plugin as tensorboard_base_plugin
2837
from werkzeug import Response
2938

3039
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
@@ -54,8 +63,6 @@ def _get_tf_versioning() -> Optional[Version]:
5463
Returns:
5564
A version object if finding the version was successful, None otherwise.
5665
"""
57-
import tensorflow as tf
58-
5966
version = tf.__version__
6067

6168
versioning = version.split(".")
@@ -269,8 +276,6 @@ class TFProfiler(base_plugin.BasePlugin):
269276

270277
def __init__(self):
271278
"""Build a TFProfiler object."""
272-
from tensorboard_plugin_profile.profile_plugin import ProfilePlugin
273-
274279
context = _create_profiling_context()
275280
self._profile_request_sender: profile_uploader.ProfileRequestSender = tensorboard_api.create_profile_request_sender()
276281
self._profile_plugin: ProfilePlugin = ProfilePlugin(context)
@@ -317,20 +322,7 @@ def capture_profile_wrapper(
317322

318323
@staticmethod
319324
def setup() -> None:
320-
"""Sets up the plugin.
321-
322-
Raises:
323-
ImportError: Tensorflow could not be imported.
324-
"""
325-
try:
326-
import tensorflow as tf
327-
except ImportError as err:
328-
raise ImportError(
329-
"Could not import tensorflow for profile usage. "
330-
"To use profiler, install the SDK using "
331-
'"pip install google-cloud-aiplatform[cloud_profiler]"'
332-
) from err
333-
325+
"""Sets up the plugin."""
334326
tf.profiler.experimental.server.start(
335327
int(environment_variables.tf_profiler_port)
336328
)
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@
3636
tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
3737
metadata_extra_require = ["pandas >= 1.0.0"]
3838
xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
39-
profiler_extra_require = ["tensorboard-plugin-profile", "tensorflow >=2.4.0"]
39+
profiler_extra_require = [
40+
"tensorboard-plugin-profile >= 2.4.0",
41+
"werkzeug >= 2.0.0",
42+
"tensorflow >=2.4.0",
43+
]
4044

4145
full_extra_require = list(
4246
set(tensorboard_extra_require + metadata_extra_require + xai_extra_require)
@@ -84,7 +88,7 @@
8488
"tensorboard": tensorboard_extra_require,
8589
"testing": testing_extra_require,
8690
"xai": xai_extra_require,
87-
"cloud_profiler": profiler_extra_require,
91+
"cloud-profiler": profiler_extra_require,
8892
},
8993
python_requires=">=3.6",
9094
scripts=[],
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
# limitations under the License.
1616
#
1717

18-
from importlib import reload
1918
import importlib.util
2019
import json
20+
import sys
2121
import threading
2222
from typing import List, Optional
2323

@@ -75,6 +75,10 @@ def _create_mock_plugin(
7575
return mock_plugin
7676

7777

78+
def _find_child_modules(root_module):
79+
return [module for module in sys.modules.keys() if module.startswith(root_module)]
80+
81+
7882
@pytest.fixture
7983
def tf_profile_plugin_mock():
8084
"""Mock the tensorboard profile plugin"""
@@ -203,10 +207,6 @@ def testSetup(self):
203207

204208
assert server_mock.call_count == 1
205209

206-
def testSetupRaiseImportError(self):
207-
with mock..dict("sys.modules", {"tensorflow": None}):
208-
self.assertRaises(ImportError, TFProfiler.setup)
209-
210210
def testPostSetupChecksFail(self):
211211
tf_profiler.environment_variables.cluster_spec = {}
212212
assert not TFProfiler.post_setup_check()
@@ -359,13 +359,26 @@ def start_response(status, headers):
359359

360360
# Initializer tests
361361
class TestInitializer(unittest.TestCase):
362-
# Tests for building the plugin
363-
def test_init_failed_import(self):
364-
with mock..dict(
365-
"sys.modules",
366-
{"google.cloud.aiplatform.training_utils.cloud_profiler.initializer": None},
362+
def testImportError(self):
363+
# Unloads any of the cloud profiler sub-modules
364+
for mod in _find_child_modules(
365+
"google.cloud.aiplatform.training_utils.cloud_profiler"
367366
):
368-
self.assertRaises(ImportError, reload, training_utils.cloud_profiler)
367+
del sys.modules[mod]
368+
369+
# Modules to be mocked out
370+
for mock_module in [
371+
"tensorflow",
372+
"tensorboard_plugin_profile.profile_plugin",
373+
"werkzeug",
374+
]:
375+
with self.subTest():
376+
with mock..dict("sys.modules", {mock_module: None}):
377+
with self.assertRaises(ImportError) as cm:
378+
importlib.import_module(
379+
"google.cloud.aiplatform.training_utils.cloud_profiler"
380+
)
381+
assert "Could not load the cloud profiler" in cm.exception.msg
369382

370383
def test_build_plugin_fail_initialize(self):
371384
plugin = _create_mock_plugin()

0 commit comments

Comments
 (0)