File tree

6 files changed

+364
-29
lines changed

6 files changed

+364
-29
lines changed
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18+
import importlib
1819
import os
1920
import pickle
2021
import tempfile
@@ -45,12 +46,17 @@
4546
"save_method": "_save_sklearn_model",
4647
"load_method": "_load_sklearn_model",
4748
"model_file": "model.pkl",
48-
}
49+
},
50+
"xgboost": {
51+
"save_method": "_save_xgboost_model",
52+
"load_method": "_load_xgboost_model",
53+
"model_file": "model.bst",
54+
},
4955
}
5056

5157

5258
def save_model(
53-
model: "sklearn.base.BaseEstimator", # noqa: F821
59+
model: Union["sklearn.base.BaseEstimator", "xgb.Booster"], # noqa: F821
5460
artifact_id: Optional[str] = None,
5561
*,
5662
uri: Optional[str] = None,
@@ -63,7 +69,7 @@ def save_model(
6369
) -> google_artifact_schema.ExperimentModel:
6470
"""Saves a ML model into a MLMD artifact.
6571
66-
Supported model frameworks: sklearn.
72+
Supported model frameworks: sklearn, xgboost.
6773
6874
Example usage:
6975
aiplatform.init(project="my-project", location="my-location", staging_bucket="gs://my-bucket")
@@ -72,7 +78,7 @@ def save_model(
7278
aiplatform.save_model(model, "my-sklearn-model")
7379
7480
Args:
75-
model (sklearn.base.BaseEstimator):
81+
model (Union["sklearn.base.BaseEstimator", "xgb.Booster"]):
7682
Required. A machine learning model.
7783
artifact_id (str):
7884
Optional. The resource id of the artifact. This id must be globally unique
@@ -116,10 +122,23 @@ def save_model(
116122
except ImportError:
117123
pass
118124
else:
119-
if isinstance(model, sklearn.base.BaseEstimator):
125+
# An instance of sklearn.base.BaseEstimator might be a sklearn model
126+
# or a xgboost/lightgbm model implemented on top of sklearn.
127+
if isinstance(
128+
model, sklearn.base.BaseEstimator
129+
) and model.__class__.__module__.startswith("sklearn"):
120130
framework_name = "sklearn"
121131
framework_version = sklearn.__version__
122132

133+
try:
134+
import xgboost as xgb
135+
except ImportError:
136+
pass
137+
else:
138+
if isinstance(model, (xgb.Booster, xgb.XGBModel)):
139+
framework_name = "xgboost"
140+
framework_version = xgb.__version__
141+
123142
if framework_name not in _FRAMEWORK_SPECS:
124143
raise ValueError(
125144
f"Model type {model.__class__.__module__}.{model.__class__.__name__} not supported."
@@ -305,9 +324,24 @@ def _save_sklearn_model(
305324
pickle.dump(model, f, protocol=_PICKLE_PROTOCOL)
306325

307326

327+
def _save_xgboost_model(
    model: Union["xgb.Booster", "xgb.XGBModel"],  # noqa: F821
    path: str,
):
    """Writes a xgboost model to a local file.

    Args:
        model (Union[xgb.Booster, xgb.XGBModel]):
            Required. The xgboost model to persist.
        path (str):
            Required. The local file path the model is written to.
    """
    # Serialization is delegated entirely to the model's own ``save_model``.
    persist = model.save_model
    persist(path)
340+
341+
308342
def load_model(
309343
model: Union[str, google_artifact_schema.ExperimentModel]
310-
) -> "sklearn.base.BaseEstimator": # noqa: F821
344+
) -> Union["sklearn.base.BaseEstimator", "xgb.Booster"]: # noqa: F821
311345
"""Retrieves the original ML model from an ExperimentModel resource.
312346
313347
Args:
@@ -375,7 +409,44 @@ def _load_sklearn_model(
375409
return sk_model
376410

377411

378-
# TODO(b/264893283)
412+
def _xgboost_release_tuple(version: str) -> tuple:
    """Converts a version string into a numerically comparable tuple.

    Only the leading dotted numeric part is considered, so suffixes such as
    "rc1" or "-dev" are ignored. Trailing zero components are dropped so that
    "1.7" and "1.7.0" compare equal.

    Args:
        version (str):
            Required. A version string, e.g. "1.10.0" or "2.0.0rc1".
    Returns:
        A tuple of ints, e.g. (1, 10) for "1.10.0". An empty tuple when the
        string does not start with a number.
    """
    import re

    match = re.match(r"\d+(?:\.\d+)*", version.strip())
    if match is None:
        return ()

    parts = [int(part) for part in match.group().split(".")]
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def _load_xgboost_model(
    model_file: str,
    model_artifact: google_artifact_schema.ExperimentModel,
) -> Union["xgb.Booster", "xgb.XGBModel"]:  # noqa: F821
    """Loads a xgboost model from local path.

    Args:
        model_file (str):
            Required. A local model file to load.
        model_artifact (google_artifact_schema.ExperimentModel):
            Required. The artifact that saved the model. Its
            ``framework_version`` is used for the version check and its
            ``model_class`` determines which class is instantiated.
    Returns:
        The xgboost model instance.

    Raises:
        ImportError: if xgboost is not installed.
    """
    try:
        import xgboost as xgb
    except ImportError:
        raise ImportError(
            "xgboost is not installed and is required for loading models."
        ) from None

    # Compare versions numerically: a plain string comparison orders
    # "1.10.0" before "1.9.0". The check is skipped when the artifact carries
    # no framework version or one that cannot be parsed.
    saved_version = model_artifact.framework_version
    if saved_version and _xgboost_release_tuple(
        xgb.__version__
    ) < _xgboost_release_tuple(saved_version):
        _LOGGER.warning(
            f"The original model was saved via xgboost {saved_version}. "
            f"You are using xgboost {xgb.__version__}. "
            "Attempting to load model..."
        )

    # NOTE(review): the class to instantiate is taken from artifact metadata
    # and imported dynamically; only load artifacts from trusted sources.
    module_name, class_name = model_artifact.model_class.rsplit(".", maxsplit=1)
    model_cls = getattr(importlib.import_module(module_name), class_name)

    # The class is constructed with no arguments, then populated from the file.
    xgb_model = model_cls()
    xgb_model.load_model(model_file)

    return xgb_model
448+
449+
379450
def register_model(
380451
model: Union[str, google_artifact_schema.ExperimentModel],
381452
*,
Original file line numberDiff line numberDiff line change
@@ -1106,7 +1106,7 @@ def log_classification_metrics(
11061106
@_v1_not_supported
11071107
def log_model(
11081108
self,
1109-
model: "sklearn.base.BaseEstimator", # noqa: F821
1109+
model: Union["sklearn.base.BaseEstimator", "xgb.Booster"], # noqa: F821
11101110
artifact_id: Optional[str] = None,
11111111
*,
11121112
uri: Optional[str] = None,
@@ -1121,7 +1121,7 @@ def log_model(
11211121
) -> google_artifact_schema.ExperimentModel:
11221122
"""Saves a ML model into a MLMD artifact and log it to this ExperimentRun.
11231123
1124-
Supported model frameworks: sklearn.
1124+
Supported model frameworks: sklearn, xgboost.
11251125
11261126
Example usage:
11271127
model = LinearRegression()
@@ -1136,7 +1136,7 @@ def log_model(
11361136
aiplatform.log_model(model, "my-sklearn-model")
11371137
11381138
Args:
1139-
model (sklearn.base.BaseEstimator):
1139+
model (Union["sklearn.base.BaseEstimator", "xgb.Booster"]):
11401140
Required. A machine learning model.
11411141
artifact_id (str):
11421142
Optional. The resource id of the artifact. This id must be globally unique
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def log_classification_metrics(
474474

475475
def log_model(
476476
self,
477-
model: "sklearn.base.BaseEstimator", # noqa: F821
477+
model: Union["sklearn.base.BaseEstimator", "xgb.Booster"], # noqa: F821
478478
artifact_id: Optional[str] = None,
479479
*,
480480
uri: Optional[str] = None,
@@ -489,7 +489,7 @@ def log_model(
489489
) -> google_artifact_schema.ExperimentModel:
490490
"""Saves a ML model into a MLMD artifact and log it to this ExperimentRun.
491491
492-
Supported model frameworks: sklearn.
492+
Supported model frameworks: sklearn, xgboost.
493493
494494
Example usage:
495495
model = LinearRegression()
@@ -504,7 +504,7 @@ def log_model(
504504
aiplatform.log_model(model, "my-sklearn-model")
505505
506506
Args:
507-
model (sklearn.base.BaseEstimator):
507+
model (Union["sklearn.base.BaseEstimator", "xgb.Booster"]):
508508
Required. A machine learning model.
509509
artifact_id (str):
510510
Optional. The resource id of the artifact. This id must be globally unique
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616

1717
import copy
18-
from typing import Optional, Dict, List, Sequence
18+
from typing import Optional, Dict, List, Sequence, Union
1919

2020
from google.auth import credentials as auth_credentials
2121
from google.cloud.aiplatform import explain
@@ -742,7 +742,9 @@ def framework_version(self) -> Optional[str]:
742742
def model_class(self) -> Optional[str]:
743743
return self.metadata.get("modelClass")
744744

745-
def load_model(self) -> "sklearn.base.BaseEstimator": # noqa: F821
745+
def load_model(
746+
self,
747+
) -> Union["sklearn.base.BaseEstimator", "xgb.Booster"]: # noqa: F821
746748
"""Retrieves the original ML model from an ExperimentModel.
747749
748750
Example usage:
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
testing_extra_require = (
9494
full_extra_require
9595
+ profiler_extra_require
96-
+ ["grpcio-testing", "pytest-asyncio", "pytest-xdist", "ipython", "kfp"]
96+
+ ["grpcio-testing", "pytest-asyncio", "pytest-xdist", "ipython", "kfp", "xgboost"]
9797
)
9898

9999

0 commit comments

Comments
 (0)