|
15 | 15 | # limitations under the License.
|
16 | 16 | #
|
17 | 17 |
|
| 18 | +import importlib |
18 | 19 | import os
|
19 | 20 | import pickle
|
20 | 21 | import tempfile
|
|
45 | 46 | "save_method": "_save_sklearn_model",
|
46 | 47 | "load_method": "_load_sklearn_model",
|
47 | 48 | "model_file": "model.pkl",
|
48 |
| -} |
| 49 | +}, |
| 50 | +"xgboost": { |
| 51 | +"save_method": "_save_xgboost_model", |
| 52 | +"load_method": "_load_xgboost_model", |
| 53 | +"model_file": "model.bst", |
| 54 | +}, |
49 | 55 | }
|
50 | 56 |
|
51 | 57 |
|
52 | 58 | def save_model(
|
53 |
| -model: "sklearn.base.BaseEstimator", # noqa: F821 |
| 59 | +model: Union["sklearn.base.BaseEstimator", "xgb.Booster"], # noqa: F821 |
54 | 60 | artifact_id: Optional[str] = None,
|
55 | 61 | *,
|
56 | 62 | uri: Optional[str] = None,
|
@@ -63,7 +69,7 @@ def save_model(
|
63 | 69 | ) -> google_artifact_schema.ExperimentModel:
|
64 | 70 | """Saves a ML model into a MLMD artifact.
|
65 | 71 |
|
66 |
| -Supported model frameworks: sklearn. |
| 72 | +Supported model frameworks: sklearn, xgboost. |
67 | 73 |
|
68 | 74 | Example usage:
|
69 | 75 | aiplatform.init(project="my-project", location="my-location", staging_bucket="gs://my-bucket")
|
@@ -72,7 +78,7 @@ def save_model(
|
72 | 78 | aiplatform.save_model(model, "my-sklearn-model")
|
73 | 79 |
|
74 | 80 | Args:
|
75 |
| -model (sklearn.base.BaseEstimator): |
| 81 | +model (Union["sklearn.base.BaseEstimator", "xgb.Booster"]): |
76 | 82 | Required. A machine learning model.
|
77 | 83 | artifact_id (str):
|
78 | 84 | Optional. The resource id of the artifact. This id must be globally unique
|
@@ -116,10 +122,23 @@ def save_model(
|
116 | 122 | except ImportError:
|
117 | 123 | pass
|
118 | 124 | else:
|
119 |
| -if isinstance(model, sklearn.base.BaseEstimator): |
| 125 | +# An instance of sklearn.base.BaseEstimator might be a sklearn model |
| 126 | +# or a xgboost/lightgbm model implemented on top of sklearn. |
| 127 | +if isinstance( |
| 128 | +model, sklearn.base.BaseEstimator |
| 129 | +) and model.__class__.__module__.startswith("sklearn"): |
120 | 130 | framework_name = "sklearn"
|
121 | 131 | framework_version = sklearn.__version__
|
122 | 132 |
|
| 133 | +try: |
| 134 | +import xgboost as xgb |
| 135 | +except ImportError: |
| 136 | +pass |
| 137 | +else: |
| 138 | +if isinstance(model, (xgb.Booster, xgb.XGBModel)): |
| 139 | +framework_name = "xgboost" |
| 140 | +framework_version = xgb.__version__ |
| 141 | + |
123 | 142 | if framework_name not in _FRAMEWORK_SPECS:
|
124 | 143 | raise ValueError(
|
125 | 144 | f"Model type {model.__class__.__module__}.{model.__class__.__name__} not supported."
|
@@ -305,9 +324,24 @@ def _save_sklearn_model(
|
305 | 324 | pickle.dump(model, f, protocol=_PICKLE_PROTOCOL)
|
306 | 325 |
|
307 | 326 |
|
| 327 | +def _save_xgboost_model( |
| 328 | +model: Union["xgb.Booster", "xgb.XGBModel"], # noqa: F821 |
| 329 | +path: str, |
| 330 | +): |
| 331 | +"""Saves a xgboost model. |
| 332 | +
|
| 333 | +Args: |
| 334 | +model (Union[xgb.Booster, xgb.XGBModel]): |
| 335 | +Requred. A xgboost model. |
| 336 | +path (str): |
| 337 | +Required. The local path to save the model. |
| 338 | +""" |
| 339 | +model.save_model(path) |
| 340 | + |
| 341 | + |
308 | 342 | def load_model(
|
309 | 343 | model: Union[str, google_artifact_schema.ExperimentModel]
|
310 |
| -) -> "sklearn.base.BaseEstimator": # noqa: F821 |
| 344 | +) -> Union["sklearn.base.BaseEstimator", "xgb.Booster"]: # noqa: F821 |
311 | 345 | """Retrieves the original ML model from an ExperimentModel resource.
|
312 | 346 |
|
313 | 347 | Args:
|
@@ -375,7 +409,44 @@ def _load_sklearn_model(
|
375 | 409 | return sk_model
|
376 | 410 |
|
377 | 411 |
|
378 |
| -# TODO(b/264893283) |
| 412 | +def _load_xgboost_model( |
| 413 | +model_file: str, |
| 414 | +model_artifact: google_artifact_schema.ExperimentModel, |
| 415 | +) -> Union["xgb.Booster", "xgb.XGBModel"]: # noqa: F821 |
| 416 | +"""Loads a xgboost model from local path. |
| 417 | +
|
| 418 | +Args: |
| 419 | +model_file (str): |
| 420 | +Required. A local model file to load. |
| 421 | +model_artifact (google_artifact_schema.ExperimentModel): |
| 422 | +Required. The artifact that saved the model. |
| 423 | +Returns: |
| 424 | +The xgboost model instance. |
| 425 | +
|
| 426 | +Raises: |
| 427 | +ImportError: if xgboost is not installed. |
| 428 | +""" |
| 429 | +try: |
| 430 | +import xgboost as xgb |
| 431 | +except ImportError: |
| 432 | +raise ImportError( |
| 433 | +"xgboost is not installed and is required for loading models." |
| 434 | +) from None |
| 435 | + |
| 436 | +if xgb.__version__ < model_artifact.framework_version: |
| 437 | +_LOGGER.warning( |
| 438 | +f"The original model was saved via xgboost {model_artifact.framework_version}. " |
| 439 | +f"You are using xgboost {xgb.__version__}." |
| 440 | +"Attempting to load model..." |
| 441 | +) |
| 442 | + |
| 443 | +module, class_name = model_artifact.model_class.rsplit(".", maxsplit=1) |
| 444 | +xgb_model = getattr(importlib.import_module(module), class_name)() |
| 445 | +xgb_model.load_model(model_file) |
| 446 | + |
| 447 | +return xgb_model |
| 448 | + |
| 449 | + |
379 | 450 | def register_model(
|
380 | 451 | model: Union[str, google_artifact_schema.ExperimentModel],
|
381 | 452 | *,
|
|
0 commit comments