|
15 | 15 | # limitations under the License.
|
16 | 16 | #
|
17 | 17 |
|
18 |
| -from importlib import reload |
19 | 18 | import importlib.util
|
20 | 19 | import json
|
| 20 | +import sys |
21 | 21 | import threading
|
22 | 22 | from typing import List, Optional
|
23 | 23 |
|
@@ -75,6 +75,10 @@ def _create_mock_plugin(
|
75 | 75 | return mock_plugin
|
76 | 76 |
|
77 | 77 |
|
| 78 | +def _find_child_modules(root_module): |
| 79 | +return [module for module in sys.modules.keys() if module.startswith(root_module)] |
| 80 | + |
| 81 | + |
78 | 82 | @pytest.fixture
|
79 | 83 | def tf_profile_plugin_mock():
|
80 | 84 | """Mock the tensorboard profile plugin"""
|
@@ -203,10 +207,6 @@ def testSetup(self):
|
203 | 207 |
|
204 | 208 | assert server_mock.call_count == 1
|
205 | 209 |
|
206 |
| -def testSetupRaiseImportError(self): |
207 |
| -with mock..dict("sys.modules", {"tensorflow": None}): |
208 |
| -self.assertRaises(ImportError, TFProfiler.setup) |
209 |
| - |
210 | 210 | def testPostSetupChecksFail(self):
|
211 | 211 | tf_profiler.environment_variables.cluster_spec = {}
|
212 | 212 | assert not TFProfiler.post_setup_check()
|
@@ -359,13 +359,26 @@ def start_response(status, headers):
|
359 | 359 |
|
360 | 360 | # Initializer tests
|
361 | 361 | class TestInitializer(unittest.TestCase):
|
362 |
| -# Tests for building the plugin |
363 |
| -def test_init_failed_import(self): |
364 |
| -with mock..dict( |
365 |
| -"sys.modules", |
366 |
| -{"google.cloud.aiplatform.training_utils.cloud_profiler.initializer": None}, |
| 362 | +def testImportError(self): |
| 363 | +# Unloads any of the cloud profiler sub-modules |
| 364 | +for mod in _find_child_modules( |
| 365 | +"google.cloud.aiplatform.training_utils.cloud_profiler" |
367 | 366 | ):
|
368 |
| -self.assertRaises(ImportError, reload, training_utils.cloud_profiler) |
| 367 | +del sys.modules[mod] |
| 368 | + |
| 369 | +# Modules to be mocked out |
| 370 | +for mock_module in [ |
| 371 | +"tensorflow", |
| 372 | +"tensorboard_plugin_profile.profile_plugin", |
| 373 | +"werkzeug", |
| 374 | +]: |
| 375 | +with self.subTest(): |
| 376 | +with mock..dict("sys.modules", {mock_module: None}): |
| 377 | +with self.assertRaises(ImportError) as cm: |
| 378 | +importlib.import_module( |
| 379 | +"google.cloud.aiplatform.training_utils.cloud_profiler" |
| 380 | +) |
| 381 | +assert "Could not load the cloud profiler" in cm.exception.msg |
369 | 382 |
|
370 | 383 | def test_build_plugin_fail_initialize(self):
|
371 | 384 | plugin = _create_mock_plugin()
|
|
0 commit comments