Open
Show file tree
Hide file tree
Changes from all commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Failed to load files.
Original file line numberDiff line numberDiff line change
Expand Up@@ -15,6 +15,7 @@
"""Internal utilities common to all modules."""

import json
from typing import Callable, Optional

import google.auth
import requests
Expand DownExpand Up@@ -76,7 +77,7 @@
}


def _get_initialized_app(app):
def _get_initialized_app(app: Optional[firebase_admin.App]):
"""Returns a reference to an initialized App instance."""
if app is None:
return firebase_admin.get_app()
Expand All@@ -92,10 +93,9 @@ def _get_initialized_app(app):
' firebase_admin.App, but given "{0}".'.format(type(app)))



def get_app_service(app, name, initializer):
def get_app_service(app: Optional[firebase_admin.App], name: str, initializer: Callable):
app = _get_initialized_app(app)
return app._get_service(name, initializer) # pylint: disable=protected-access
return app._get_service(name, initializer) # pylint: disable=protected-access


def handle_platform_error_from_requests(error, handle_func=None):
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -95,6 +95,7 @@
def _get_messaging_service(app):
return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService)


def send(message, dry_run=False, app=None):
"""Sends the given message via Firebase Cloud Messaging (FCM).

Expand All@@ -115,6 +116,7 @@ def send(message, dry_run=False, app=None):
"""
return _get_messaging_service(app).send(message, dry_run)


def send_all(messages, dry_run=False, app=None):
"""Sends the given list of messages via Firebase Cloud Messaging as a single batch.

Expand All@@ -135,6 +137,7 @@ def send_all(messages, dry_run=False, app=None):
"""
return _get_messaging_service(app).send_all(messages, dry_run)


def send_multicast(multicast_message, dry_run=False, app=None):
"""Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM).

Expand DownExpand Up@@ -166,6 +169,7 @@ def send_multicast(multicast_message, dry_run=False, app=None):
) for token in multicast_message.tokens]
return _get_messaging_service(app).send_all(messages, dry_run)


def subscribe_to_topic(tokens, topic, app=None):
"""Subscribes a list of registration tokens to an FCM topic.

Expand All@@ -185,6 +189,7 @@ def subscribe_to_topic(tokens, topic, app=None):
return _get_messaging_service(app).make_topic_management_request(
tokens, topic, 'iid/v1:batchAdd')


def unsubscribe_from_topic(tokens, topic, app=None):
"""Unsubscribes a list of registration tokens from an FCM topic.

Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -211,13 +211,13 @@ def from_dict(cls, data, app=None):
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
model = Model(model_format=tflite_format)
model._data = data_copy # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
return model

def _update_from_dict(self, data):
copy = Model.from_dict(data)
self.model_format = copy.model_format
self._data = copy._data # pylint: disable=protected-access
self._data = copy._data # pylint: disable=protected-access

def __eq__(self, other):
if isinstance(other, self.__class__):
Expand DownExpand Up@@ -334,7 +334,7 @@ def model_format(self):
def model_format(self, model_format):
if model_format is not None:
_validate_model_format(model_format)
self._model_format = model_format #Can be None
self._model_format = model_format # Can be None
return self

def as_dict(self, for_upload=False):
Expand DownExpand Up@@ -370,7 +370,7 @@ def from_dict(cls, data):
"""Create an instance of the object from a dict."""
data_copy = dict(data)
tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy))
tflite_format._data = data_copy # pylint: disable=protected-access
tflite_format._data = data_copy # pylint: disable=protected-access
return tflite_format

def __eq__(self, other):
Expand DownExpand Up@@ -405,7 +405,7 @@ def model_source(self, model_source):
if model_source is not None:
if not isinstance(model_source, TFLiteModelSource):
raise TypeError('Model source must be a TFLiteModelSource object.')
self._model_source = model_source # Can be None
self._model_source = model_source # Can be None

@property
def size_bytes(self):
Expand DownExpand Up@@ -485,7 +485,7 @@ def __init__(self, gcs_tflite_uri, app=None):

def __eq__(self, other):
if isinstance(other, self.__class__):
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
return False

def __ne__(self, other):
Expand DownExpand Up@@ -775,7 +775,7 @@ def _validate_display_name(display_name):

def _validate_tags(tags):
if not isinstance(tags, list) or not \
all(isinstance(tag, str) for tag in tags):
all(isinstance(tag, str) for tag in tags):
raise TypeError('Tags must be a list of strings.')
if not all(_TAG_PATTERN.match(tag) for tag in tags):
raise ValueError('Tag format is invalid.')
Expand All@@ -789,6 +789,7 @@ def _validate_gcs_tflite_uri(uri):
raise ValueError('GCS TFLite URI format is invalid.')
return uri


def _validate_auto_ml_model(model):
if not _AUTO_ML_MODEL_PATTERN.match(model):
raise ValueError('Model resource name format is invalid.')
Expand All@@ -809,7 +810,7 @@ def _validate_list_filter(list_filter):

def _validate_page_size(page_size):
if page_size is not None:
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
# Specifically type() to disallow boolean which is a subtype of int
raise TypeError('Page size must be a number or None.')
if page_size < 1 or page_size > _MAX_PAGE_SIZE:
Expand DownExpand Up@@ -864,7 +865,7 @@ def _exponential_backoff(self, current_attempt, stop_time):

if stop_time is not None:
max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds()
if max_seconds_left < 1: # allow a bit of time for rpc
if max_seconds_left < 1: # allow a bit of time for rpc
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
time.sleep(wait_time_seconds)
Expand DownExpand Up@@ -925,7 +926,6 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()


def create_model(self, model):
_validate_model(model)
try:
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -25,12 +25,14 @@
raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
'to install the "google-cloud-storage" module.')

from firebase_admin import _utils
from firebase_admin import _utils, App
from typing import Optional


_STORAGE_ATTRIBUTE = '_storage'

def bucket(name=None, app=None) -> storage.Bucket:

def bucket(name: Optional[str] = None, app: Optional[App] = None) -> storage.Bucket:
"""Returns a handle to a Google Cloud Storage bucket.

If the name argument is not provided, uses the 'storageBucket' option specified when
Expand DownExpand Up@@ -67,7 +69,7 @@ def from_app(cls, app):
# significantly speeds up the initialization of the storage client.
return _StorageClient(credentials, app.project_id, default_bucket)

def bucket(self, name=None):
def bucket(self, name: Optional[str] = None):
"""Returns a handle to the specified Cloud Storage Bucket."""
bucket_name = name if name is not None else self._default_bucket
if bucket_name is None:
Expand Down
Original file line numberDiff line numberDiff line change
Expand Up@@ -183,6 +183,7 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non
FirebaseError: If an error occurs while retrieving the user accounts.
"""
tenant_mgt_service = _get_tenant_mgt_service(app)

def download(page_token, max_results):
return tenant_mgt_service.list_tenants(page_token, max_results)
return ListTenantsPage(download, page_token, max_results)
Expand All@@ -206,7 +207,7 @@ class Tenant:
def __init__(self, data):
if not isinstance(data, dict):
raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data))
if not 'name' in data:
if 'name' not in data:
raise ValueError('Tenant response missing required keys.')

self._data = data
Expand DownExpand Up@@ -256,7 +257,7 @@ def auth_for_tenant(self, tenant_id):

client = auth.Client(self.app, tenant_id=tenant_id)
self.tenant_clients[tenant_id] = client
return client
return client

def get_tenant(self, tenant_id):
"""Gets the tenant corresponding to the given ``tenant_id``."""
Expand Down