Open
Show file tree
Hide file tree
Changes from all commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Failed to load files.
Original file line numberDiff line numberDiff line change
Expand Up@@ -386,9 +386,6 @@ def _init_model_source(data):
gcs_tflite_uri = data.pop('gcsTfliteUri', None)
if gcs_tflite_uri:
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
auto_ml_model = data.pop('automlModel', None)
if auto_ml_model:
return TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
return None

@property
Expand DownExpand Up@@ -603,36 +600,6 @@ def as_dict(self, for_upload=False):
return {'gcsTfliteUri': self._gcs_tflite_uri}


class TFLiteAutoMlSource(TFLiteModelSource):
"""TFLite model source representing a tflite model created with AutoML."""

def __init__(self, auto_ml_model, app=None):
self._app = app
self.auto_ml_model = auto_ml_model

def __eq__(self, other):
if isinstance(other, self.__class__):
return self.auto_ml_model == other.auto_ml_model
return False

def __ne__(self, other):
return not self.__eq__(other)

@property
def auto_ml_model(self):
"""Resource name of the model, created by the AutoML API or Cloud console."""
return self._auto_ml_model

@auto_ml_model.setter
def auto_ml_model(self, auto_ml_model):
self._auto_ml_model = _validate_auto_ml_model(auto_ml_model)

def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
# Upload is irrelevant for auto_ml models
return {'automlModel': self._auto_ml_model}


class ListModelsPage:
"""Represents a page of models in a Firebase project.

Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -22,7 +22,6 @@

import pytest

import firebase_admin
from firebase_admin import exceptions
from firebase_admin import ml
from tests import testutils
Expand All@@ -35,12 +34,6 @@
except ImportError:
_TF_ENABLED = False

try:
from google.cloud import automl_v1
_AUTOML_ENABLED = True
except ImportError:
_AUTOML_ENABLED = False

def _random_identifier(prefix):
#pylint: disable=unused-variable
suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)])
Expand DownExpand Up@@ -159,14 +152,6 @@ def check_tflite_gcs_format(model, validation_error=None):
assert model.model_hash is not None


def check_tflite_automl_format(model):
assert model.validation_error is None
assert model.published is False
assert model.model_format.model_source.auto_ml_model.startswith('projects/')
# Automl models don't have validation errors since they are references
# to valid automl models.


@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
def test_create_simple_model(firebase_model):
check_model(firebase_model, NAME_AND_TAGS_ARGS)
Expand DownExpand Up@@ -388,50 +373,3 @@ def test_from_saved_model(saved_model_dir):
assert created_model.validation_error is None
finally:
_clean_up_model(created_model)


# Test AutoML functionality if AutoML is enabled.
#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the
# successful test. (Test is skipped otherwise)

@pytest.fixture
def automl_model():
assert _AUTOML_ENABLED

# It takes > 20 minutes to train a model, so we expect a predefined AutoMl
# model named 'admin_sdk_integ_test1' to exist in the project, or we skip
# the test.
automl_client = automl_v1.AutoMlClient()
project_id = firebase_admin.get_app().project_id
parent = automl_client.location_path(project_id, 'us-central1')
models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1")
# Expecting exactly one. (Ok to use last one if somehow more than 1)
automl_ref = None
for model in models:
automl_ref = model.name

# Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
if automl_ref is None:
pytest.skip("No pre-existing AutoML model found. Skipping test")

source = ml.TFLiteAutoMlSource(automl_ref)
tflite_format = ml.TFLiteFormat(model_source=source)
ml_model = ml.Model(
display_name=_random_identifier('TestModel_automl_'),
tags=['test_automl'],
model_format=tflite_format)
model = ml.create_model(model=ml_model)
yield model
_clean_up_model(model)

@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.')
def test_automl_model(automl_model):
# This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1'
automl_model.wait_for_unlocked()

check_model(automl_model, {
'display_name': automl_model.display_name,
'tags': ['test_automl'],
})
check_tflite_automl_format(automl_model)