File tree

6 files changed

+788
-21
lines changed

6 files changed

+788
-21
lines changed
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,17 @@ def _dashboard_uri(self) -> Optional[str]:
173173
url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}"
174174
return url
175175

176+
def _log_job_state(self):
    """Log the job's class name, resource name and current state.

    The name and state are read from ``self._gca_resource`` and emitted
    as one info-level message through the module logger.
    """
    resource = self._gca_resource
    message = "%s %s current state:\n%s" % (
        self.__class__.__name__,
        resource.name,
        resource.state,
    )
    _LOGGER.info(message)
186+
176187
def _block_until_complete(self):
177188
"""Helper method to block and check on job until complete.
178189
@@ -190,26 +201,13 @@ def _block_until_complete(self):
190201
while self.state not in _JOB_COMPLETE_STATES:
191202
current_time = time.time()
192203
if current_time - previous_time >= log_wait:
193-
_LOGGER.info(
194-
"%s %s current state:\n%s"
195-
% (
196-
self.__class__.__name__,
197-
self._gca_resource.name,
198-
self._gca_resource.state,
199-
)
200-
)
204+
self._log_job_state()
201205
log_wait = min(log_wait * multiplier, max_wait)
202206
previous_time = current_time
203207
time.sleep(wait)
204208

205-
_LOGGER.info(
206-
"%s %s current state:\n%s"
207-
% (
208-
self.__class__.__name__,
209-
self._gca_resource.name,
210-
self._gca_resource.state,
211-
)
212-
)
209+
self._log_job_state()
210+
213211
# Error is only populated when the job state is
214212
# JOB_STATE_FAILED or JOB_STATE_CANCELLED.
215213
if self._gca_resource.state in _JOB_ERROR_STATES:
@@ -845,6 +843,63 @@ def __init__(
845843
project=project, location=location
846844
)
847845

846+
self._logged_web_access_uris = set()
847+
848+
@property
def web_access_uris(self) -> Dict[str, Union[str, Dict[str, str]]]:
    """Re-sync the runnable job and return its current web access uris.

    Returns:
        (Dict[str, Union[str, Dict[str, str]]]):
            Web access uris of the runnable job, as produced by
            ``_get_web_access_uris``.
    """
    # Sync before reading so the uris come from the freshly fetched job.
    self._sync_gca_resource()
    latest_uris = self._get_web_access_uris()
    return latest_uris
860+
861+
@abc.abstractmethod
862+
def _get_web_access_uris(self):
863+
"""Helper method to get the web access uris of the runnable job"""
864+
pass
865+
866+
@abc.abstractmethod
867+
def _log_web_access_uris(self):
868+
"""Helper method to log the web access uris of the runnable job"""
869+
pass
870+
871+
def _block_until_complete(self):
    """Poll the runnable job until it reaches a completed state.

    While waiting, the job state is logged at a growing interval and
    ``_log_web_access_uris`` is invoked so web access uris get reported.

    Raises:
        RuntimeError: If job failed or cancelled.
    """

    # Small starting values so failures surface fast.
    poll_interval = 5  # seconds slept between state checks; never scaled
    log_interval = 5  # seconds to wait before the next state log line
    max_log_interval = 60 * 5  # state log lines are at most 5 minutes apart
    backoff = 2  # log interval doubles after each state log line

    last_logged = time.time()
    while self.state not in _JOB_COMPLETE_STATES:
        now = time.time()
        if now - last_logged >= log_interval:
            self._log_job_state()
            log_interval = min(log_interval * backoff, max_log_interval)
            last_logged = now
        # NOTE(review): assumed to run on every poll rather than only when
        # the state is logged -- confirm against the original indentation.
        self._log_web_access_uris()
        time.sleep(poll_interval)

    self._log_job_state()

    # Error is only populated when the job state is
    # JOB_STATE_FAILED or JOB_STATE_CANCELLED.
    if self._gca_resource.state in _JOB_ERROR_STATES:
        raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)

    _LOGGER.log_action_completed_against_resource("run", "completed", self)
902+
848903
@abc.abstractmethod
849904
def run(self) -> None:
850905
pass
@@ -1046,6 +1101,26 @@ def network(self) -> Optional[str]:
10461101
self._assert_gca_resource_is_available()
10471102
return self._gca_resource.job_spec.network
10481103

1104+
def _get_web_access_uris(self) -> Dict[str, str]:
1105+
"""Helper method to get the web access uris of the custom job
1106+
1107+
Returns:
1108+
(Dict[str, str]):
1109+
Web access uris of the custom job.
1110+
"""
1111+
return dict(self._gca_resource.web_access_uris)
1112+
1113+
def _log_web_access_uris(self):
1114+
"""Helper method to log the web access uris of the custom job"""
1115+
1116+
for worker, uri in self._get_web_access_uris().items():
1117+
if uri not in self._logged_web_access_uris:
1118+
_LOGGER.info(
1119+
"%s %s access the interactive shell terminals for the custom job:\n%s:\n%s"
1120+
% (self.__class__.__name__, self._gca_resource.name, worker, uri,),
1121+
)
1122+
self._logged_web_access_uris.add(uri)
1123+
10491124
@classmethod
10501125
def from_local_script(
10511126
cls,
@@ -1250,6 +1325,7 @@ def run(
12501325
network: Optional[str] = None,
12511326
timeout: Optional[int] = None,
12521327
restart_job_on_worker_restart: bool = False,
1328+
enable_web_access: bool = False,
12531329
tensorboard: Optional[str] = None,
12541330
sync: bool = True,
12551331
) -> None:
@@ -1271,6 +1347,10 @@ def run(
12711347
gets restarted. This feature can be used by
12721348
distributed training jobs that are not resilient
12731349
to workers leaving and joining a job.
1350+
enable_web_access (bool):
1351+
Whether you want Vertex AI to enable interactive shell access
1352+
to training containers.
1353+
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
12741354
tensorboard (str):
12751355
Optional. The name of a Vertex AI
12761356
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
@@ -1304,6 +1384,9 @@ def run(
13041384
restart_job_on_worker_restart=restart_job_on_worker_restart,
13051385
)
13061386

1387+
if enable_web_access:
1388+
self._gca_resource.job_spec.enable_web_access = enable_web_access
1389+
13071390
if tensorboard:
13081391
v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob()
13091392
v1beta1_gca_resource._pb.MergeFromString(
@@ -1588,13 +1671,46 @@ def network(self) -> Optional[str]:
15881671
self._assert_gca_resource_is_available()
15891672
return getattr(self._gca_resource.trial_job_spec, "network")
15901673

1674+
def _get_web_access_uris(self) -> Dict[str, Dict[str, str]]:
1675+
"""Helper method to get the web access uris of the hyperparameter job
1676+
1677+
Returns:
1678+
(Dict[str, Dict[str, str]]):
1679+
Web access uris of the hyperparameter job.
1680+
"""
1681+
web_access_uris = dict()
1682+
for trial in self.trials:
1683+
web_access_uris[trial.id] = web_access_uris.get(trial.id, dict())
1684+
for worker, uri in trial.web_access_uris.items():
1685+
web_access_uris[trial.id][worker] = uri
1686+
return web_access_uris
1687+
1688+
def _log_web_access_uris(self):
1689+
"""Helper method to log the web access uris of the hyperparameter job"""
1690+
1691+
for trial_id, trial_web_access_uris in self._get_web_access_uris().items():
1692+
for worker, uri in trial_web_access_uris.items():
1693+
if uri not in self._logged_web_access_uris:
1694+
_LOGGER.info(
1695+
"%s %s access the interactive shell terminals for trial - %s:\n%s:\n%s"
1696+
% (
1697+
self.__class__.__name__,
1698+
self._gca_resource.name,
1699+
trial_id,
1700+
worker,
1701+
uri,
1702+
),
1703+
)
1704+
self._logged_web_access_uris.add(uri)
1705+
15911706
@base.optional_sync()
15921707
def run(
15931708
self,
15941709
service_account: Optional[str] = None,
15951710
network: Optional[str] = None,
15961711
timeout: Optional[int] = None, # seconds
15971712
restart_job_on_worker_restart: bool = False,
1713+
enable_web_access: bool = False,
15981714
tensorboard: Optional[str] = None,
15991715
sync: bool = True,
16001716
) -> None:
@@ -1616,6 +1732,10 @@ def run(
16161732
gets restarted. This feature can be used by
16171733
distributed training jobs that are not resilient
16181734
to workers leaving and joining a job.
1735+
enable_web_access (bool):
1736+
Whether you want Vertex AI to enable interactive shell access
1737+
to training containers.
1738+
https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
16191739
tensorboard (str):
16201740
Optional. The name of a Vertex AI
16211741
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
@@ -1649,6 +1769,9 @@ def run(
16491769
restart_job_on_worker_restart=restart_job_on_worker_restart,
16501770
)
16511771

1772+
if enable_web_access:
1773+
self._gca_resource.trial_job_spec.enable_web_access = enable_web_access
1774+
16521775
if tensorboard:
16531776
v1beta1_gca_resource = (
16541777
gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()

0 commit comments

Comments
 (0)