@@ -173,6 +173,17 @@ def _dasard_uri(self) -> Optional[str]:
|
173 | 173 | url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}"
|
174 | 174 | return url
|
175 | 175 |
|
| 176 | +def _log_job_state(self): |
| 177 | +"""Helper method to log job state.""" |
| 178 | +_LOGGER.info( |
| 179 | +"%s %s current state:\n%s" |
| 180 | +% ( |
| 181 | +self.__class__.__name__, |
| 182 | +self._gca_resource.name, |
| 183 | +self._gca_resource.state, |
| 184 | +) |
| 185 | +) |
| 186 | + |
176 | 187 | def _block_until_complete(self):
|
177 | 188 | """Helper method to block and check on job until complete.
|
178 | 189 |
|
@@ -190,26 +201,13 @@ def _block_until_complete(self):
|
190 | 201 | while self.state not in _JOB_COMPLETE_STATES:
|
191 | 202 | current_time = time.time()
|
192 | 203 | if current_time - previous_time >= log_wait:
|
193 |
| -_LOGGER.info( |
194 |
| -"%s %s current state:\n%s" |
195 |
| -% ( |
196 |
| -self.__class__.__name__, |
197 |
| -self._gca_resource.name, |
198 |
| -self._gca_resource.state, |
199 |
| -) |
200 |
| -) |
| 204 | +self._log_job_state() |
201 | 205 | log_wait = min(log_wait * multiplier, max_wait)
|
202 | 206 | previous_time = current_time
|
203 | 207 | time.sleep(wait)
|
204 | 208 |
|
205 |
| -_LOGGER.info( |
206 |
| -"%s %s current state:\n%s" |
207 |
| -% ( |
208 |
| -self.__class__.__name__, |
209 |
| -self._gca_resource.name, |
210 |
| -self._gca_resource.state, |
211 |
| -) |
212 |
| -) |
| 209 | +self._log_job_state() |
| 210 | + |
213 | 211 | # Error is only populated when the job state is
|
214 | 212 | # JOB_STATE_FAILED or JOB_STATE_CANCELLED.
|
215 | 213 | if self._gca_resource.state in _JOB_ERROR_STATES:
|
@@ -845,6 +843,63 @@ def __init__(
|
845 | 843 | project=project, location=location
|
846 | 844 | )
|
847 | 845 |
|
| 846 | +self._logged_web_access_uris = set() |
| 847 | + |
| 848 | +@property |
| 849 | +def web_access_uris(self) -> Dict[str, Union[str, Dict[str, str]]]: |
| 850 | +"""Fetch the runnable job again and return the latest web access uris. |
| 851 | +
|
| 852 | +Returns: |
| 853 | +(Dict[str, Union[str, Dict[str, str]]]): |
| 854 | +Web access uris of the runnable job. |
| 855 | +""" |
| 856 | + |
| 857 | +# Fetch the Job again for most up-to-date web access uris |
| 858 | +self._sync_gca_resource() |
| 859 | +return self._get_web_access_uris() |
| 860 | + |
| 861 | +@abc.abstractmethod |
| 862 | +def _get_web_access_uris(self): |
| 863 | +"""Helper method to get the web access uris of the runnable job""" |
| 864 | +pass |
| 865 | + |
| 866 | +@abc.abstractmethod |
| 867 | +def _log_web_access_uris(self): |
| 868 | +"""Helper method to log the web access uris of the runnable job""" |
| 869 | +pass |
| 870 | + |
| 871 | +def _block_until_complete(self): |
| 872 | +"""Helper method to block and check on runnable job until complete. |
| 873 | +
|
| 874 | +Raises: |
| 875 | +RuntimeError: If job failed or cancelled. |
| 876 | +""" |
| 877 | + |
| 878 | +# Used these numbers so failures surface fast |
| 879 | +wait = 5 # start at five seconds |
| 880 | +log_wait = 5 |
| 881 | +max_wait = 60 * 5 # 5 minute wait |
| 882 | +multiplier = 2 # scale wait by 2 every iteration |
| 883 | + |
| 884 | +previous_time = time.time() |
| 885 | +while self.state not in _JOB_COMPLETE_STATES: |
| 886 | +current_time = time.time() |
| 887 | +if current_time - previous_time >= log_wait: |
| 888 | +self._log_job_state() |
| 889 | +log_wait = min(log_wait * multiplier, max_wait) |
| 890 | +previous_time = current_time |
| 891 | +self._log_web_access_uris() |
| 892 | +time.sleep(wait) |
| 893 | + |
| 894 | +self._log_job_state() |
| 895 | + |
| 896 | +# Error is only populated when the job state is |
| 897 | +# JOB_STATE_FAILED or JOB_STATE_CANCELLED. |
| 898 | +if self._gca_resource.state in _JOB_ERROR_STATES: |
| 899 | +raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error) |
| 900 | +else: |
| 901 | +_LOGGER.log_action_completed_against_resource("run", "completed", self) |
| 902 | + |
848 | 903 | @abc.abstractmethod
|
849 | 904 | def run(self) -> None:
|
850 | 905 | pass
|
@@ -1046,6 +1101,26 @@ def network(self) -> Optional[str]:
|
1046 | 1101 | self._assert_gca_resource_is_available()
|
1047 | 1102 | return self._gca_resource.job_spec.network
|
1048 | 1103 |
|
| 1104 | +def _get_web_access_uris(self) -> Dict[str, str]: |
| 1105 | +"""Helper method to get the web access uris of the custom job |
| 1106 | +
|
| 1107 | +Returns: |
| 1108 | +(Dict[str, str]): |
| 1109 | +Web access uris of the custom job. |
| 1110 | +""" |
| 1111 | +return dict(self._gca_resource.web_access_uris) |
| 1112 | + |
| 1113 | +def _log_web_access_uris(self): |
| 1114 | +"""Helper method to log the web access uris of the custom job""" |
| 1115 | + |
| 1116 | +for worker, uri in self._get_web_access_uris().items(): |
| 1117 | +if uri not in self._logged_web_access_uris: |
| 1118 | +_LOGGER.info( |
| 1119 | +"%s %s access the interactive shell terminals for the custom job:\n%s:\n%s" |
| 1120 | +% (self.__class__.__name__, self._gca_resource.name, worker, uri,), |
| 1121 | +) |
| 1122 | +self._logged_web_access_uris.add(uri) |
| 1123 | + |
1049 | 1124 | @classmethod
|
1050 | 1125 | def from_local_script(
|
1051 | 1126 | cls,
|
@@ -1250,6 +1325,7 @@ def run(
|
1250 | 1325 | network: Optional[str] = None,
|
1251 | 1326 | timeout: Optional[int] = None,
|
1252 | 1327 | restart_job_on_worker_restart: bool = False,
|
| 1328 | +enable_web_access: bool = False, |
1253 | 1329 | tensorboard: Optional[str] = None,
|
1254 | 1330 | sync: bool = True,
|
1255 | 1331 | ) -> None:
|
@@ -1271,6 +1347,10 @@ def run(
|
1271 | 1347 | gets restarted. This feature can be used by
|
1272 | 1348 | distributed training jobs that are not resilient
|
1273 | 1349 | to workers leaving and joining a job.
|
| 1350 | +enable_web_access (bool): |
| 1351 | +Whether you want Vertex AI to enable interactive shell access |
| 1352 | +to training containers. |
| 1353 | +https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell |
1274 | 1354 | tensorboard (str):
|
1275 | 1355 | Optional. The name of a Vertex AI
|
1276 | 1356 | [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
|
@@ -1304,6 +1384,9 @@ def run(
|
1304 | 1384 | restart_job_on_worker_restart=restart_job_on_worker_restart,
|
1305 | 1385 | )
|
1306 | 1386 |
|
| 1387 | +if enable_web_access: |
| 1388 | +self._gca_resource.job_spec.enable_web_access = enable_web_access |
| 1389 | + |
1307 | 1390 | if tensorboard:
|
1308 | 1391 | v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob()
|
1309 | 1392 | v1beta1_gca_resource._pb.MergeFromString(
|
@@ -1588,13 +1671,46 @@ def network(self) -> Optional[str]:
|
1588 | 1671 | self._assert_gca_resource_is_available()
|
1589 | 1672 | return getattr(self._gca_resource.trial_job_spec, "network")
|
1590 | 1673 |
|
| 1674 | +def _get_web_access_uris(self) -> Dict[str, Dict[str, str]]: |
| 1675 | +"""Helper method to get the web access uris of the hyperparameter job |
| 1676 | +
|
| 1677 | +Returns: |
| 1678 | +(Dict[str, Dict[str, str]]): |
| 1679 | +Web access uris of the hyperparameter job. |
| 1680 | +""" |
| 1681 | +web_access_uris = dict() |
| 1682 | +for trial in self.trials: |
| 1683 | +web_access_uris[trial.id] = web_access_uris.get(trial.id, dict()) |
| 1684 | +for worker, uri in trial.web_access_uris.items(): |
| 1685 | +web_access_uris[trial.id][worker] = uri |
| 1686 | +return web_access_uris |
| 1687 | + |
| 1688 | +def _log_web_access_uris(self): |
| 1689 | +"""Helper method to log the web access uris of the hyperparameter job""" |
| 1690 | + |
| 1691 | +for trial_id, trial_web_access_uris in self._get_web_access_uris().items(): |
| 1692 | +for worker, uri in trial_web_access_uris.items(): |
| 1693 | +if uri not in self._logged_web_access_uris: |
| 1694 | +_LOGGER.info( |
| 1695 | +"%s %s access the interactive shell terminals for trial - %s:\n%s:\n%s" |
| 1696 | +% ( |
| 1697 | +self.__class__.__name__, |
| 1698 | +self._gca_resource.name, |
| 1699 | +trial_id, |
| 1700 | +worker, |
| 1701 | +uri, |
| 1702 | +), |
| 1703 | +) |
| 1704 | +self._logged_web_access_uris.add(uri) |
| 1705 | + |
1591 | 1706 | @base.optional_sync()
|
1592 | 1707 | def run(
|
1593 | 1708 | self,
|
1594 | 1709 | service_account: Optional[str] = None,
|
1595 | 1710 | network: Optional[str] = None,
|
1596 | 1711 | timeout: Optional[int] = None, # seconds
|
1597 | 1712 | restart_job_on_worker_restart: bool = False,
|
| 1713 | +enable_web_access: bool = False, |
1598 | 1714 | tensorboard: Optional[str] = None,
|
1599 | 1715 | sync: bool = True,
|
1600 | 1716 | ) -> None:
|
@@ -1616,6 +1732,10 @@ def run(
|
1616 | 1732 | gets restarted. This feature can be used by
|
1617 | 1733 | distributed training jobs that are not resilient
|
1618 | 1734 | to workers leaving and joining a job.
|
| 1735 | +enable_web_access (bool): |
| 1736 | +Whether you want Vertex AI to enable interactive shell access |
| 1737 | +to training containers. |
| 1738 | +https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell |
1619 | 1739 | tensorboard (str):
|
1620 | 1740 | Optional. The name of a Vertex AI
|
1621 | 1741 | [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
|
@@ -1649,6 +1769,9 @@ def run(
|
1649 | 1769 | restart_job_on_worker_restart=restart_job_on_worker_restart,
|
1650 | 1770 | )
|
1651 | 1771 |
|
| 1772 | +if enable_web_access: |
| 1773 | +self._gca_resource.trial_job_spec.enable_web_access = enable_web_access |
| 1774 | + |
1652 | 1775 | if tensorboard:
|
1653 | 1776 | v1beta1_gca_resource = (
|
1654 | 1777 | gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()
|
|
0 commit comments