Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions firebase_admin/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class Message(object):
apns: An instance of ``messaging.ApnsConfig`` (optional).
token: The registration token of the device to which the message should be sent (optional).
topic: Name of the FCM topic to which the message should be sent (optional). Topic name
must not contain the ``/topics/`` prefix.
may contain the ``/topics/`` prefix.
condition: The FCM condition to which the message should be sent (optional).
"""

Expand Down Expand Up @@ -671,6 +671,18 @@ def encode_notification(cls, notification):
}
return cls.remove_null_values(result)

@classmethod
def sanitize_topic_name(cls, topic):
if not topic:
return None
prefix = '/topics/'
if topic.startswith(prefix):
topic = topic[len(prefix):]
# Checks for illegal characters and empty string.
if not re.match(r'^[a-zA-Z0-9-_\.~%]+$', topic):
raise ValueError('Malformed topic name.')
return topic

def default(self, obj): # pylint: disable=method-hidden
if not isinstance(obj, Message):
return json.JSONEncoder.default(self, obj)
Expand All @@ -685,16 +697,11 @@ def default(self, obj): # pylint: disable=method-hidden
'topic': _Validators.check_string('Message.topic', obj.topic, non_empty=True),
'webpush': _MessageEncoder.encode_webpush(obj.webpush),
}
result['topic'] = _MessageEncoder.sanitize_topic_name(result.get('topic'))
result = _MessageEncoder.remove_null_values(result)
target_count = sum([t in result for t in ['token', 'topic', 'condition']])
if target_count != 1:
raise ValueError('Exactly one of token, topic or condition must be specified.')
topic = result.get('topic')
if topic:
if topic.startswith('/topics/'):
raise ValueError('Topic name must not contain the /topics/ prefix.')
if not re.match(r'^[a-zA-Z0-9-_\.~%]+$', topic):
raise ValueError('Illegal characters in topic name.')
return result


Expand Down
7 changes: 5 additions & 2 deletions tests/test_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def test_invalid_condition(self, target):
check_encoding(messaging.Message(condition=target))
assert str(excinfo.value) == 'Message.condition must be a non-empty string.'

@pytest.mark.parametrize('topic', ['/topics/foo', '/foo/bar', 'foo bar'])
def test_topic_name_prefix(self, topic):
@pytest.mark.parametrize('topic', ['/topics/', '/foo/bar', 'foo bar'])
def test_malformed_topic_name(self, topic):
with pytest.raises(ValueError):
check_encoding(messaging.Message(topic=topic))

Expand All @@ -92,6 +92,9 @@ def test_data_message(self):
messaging.Message(topic='topic', data={'k1': 'v1', 'k2': 'v2'}),
{'topic': 'topic', 'data': {'k1': 'v1', 'k2': 'v2'}})

def test_prefixed_topic(self):
check_encoding(messaging.Message(topic='/topics/topic'), {'topic': 'topic'})


class TestNotificationEncoder(object):

Expand Down