diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index 1942fa8593d..28b15ee070e 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -324,6 +324,79 @@ def filter(self, message): group = _Group() + class user(BaseFilter): + """Filters messages to allow only those which are from specified user ID. + + Notes: + Only one of chat_id or username must be used here. + + Args: + user_id(Optional[int|list]): which user ID(s) to allow through. + username(Optional[str|list]): which username(s) to allow through. If username starts + with '@' symbol, it will be ignored. + + Raises: + ValueError + """ + + def __init__(self, user_id=None, username=None): + if not (bool(user_id) ^ bool(username)): + raise ValueError('One and only one of user_id or username must be used') + if user_id is not None and isinstance(user_id, int): + self.user_ids = [user_id] + else: + self.user_ids = user_id + if username is None: + self.usernames = username + elif isinstance(username, str_type): + self.usernames = [username.replace('@', '')] + else: + self.usernames = [user.replace('@', '') for user in username] + + def filter(self, message): + if self.user_ids is not None: + return bool(message.from_user and message.from_user.id in self.user_ids) + else: + # self.usernames is not None + return bool(message.from_user and message.from_user.username and + message.from_user.username in self.usernames) + + class chat(BaseFilter): + """Filters messages to allow only those which are from specified chat ID. + + Notes: + Only one of chat_id or username must be used here. + + Args: + chat_id(Optional[int|list]): which chat ID(s) to allow through. + username(Optional[str|list]): which username(s) to allow through. If username starts + with '@' symbol, it will be ignored. + + Raises: + ValueError + """ + + def __init__(self, chat_id=None, username=None): + if not (bool(chat_id) ^ bool(username)): + raise ValueError('One and only one of chat_id or username must be used') + if chat_id is not None and isinstance(chat_id, int): + self.chat_ids = [chat_id] + else: + self.chat_ids = chat_id + if username is None: + self.usernames = username + elif isinstance(username, str_type): + self.usernames = [username.replace('@', '')] + else: + self.usernames = [chat.replace('@', '') for chat in username] + + def filter(self, message): + if self.chat_ids is not None: + return bool(message.chat_id in self.chat_ids) + else: + # self.usernames is not None + return bool(message.chat.username and message.chat.username in self.usernames) + class _Invoice(BaseFilter): def filter(self, message): diff --git a/tests/test_filters.py b/tests/test_filters.py index e8c4c6637d2..2fcded4bc9b 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -213,6 +213,51 @@ def test_group_fileter(self): self.message.chat.type = "supergroup" self.assertTrue(Filters.group(self.message)) + def test_filters_chat(self): + with self.assertRaisesRegexp(ValueError, 'chat_id or username'): + Filters.chat(chat_id=-1, username='chat') + with self.assertRaisesRegexp(ValueError, 'chat_id or username'): + Filters.chat() + + def test_filters_chat_id(self): + self.assertFalse(Filters.chat(chat_id=-1)(self.message)) + self.message.chat.id = -1 + self.assertTrue(Filters.chat(chat_id=-1)(self.message)) + self.message.chat.id = -2 + self.assertTrue(Filters.chat(chat_id=[-1, -2])(self.message)) + self.assertFalse(Filters.chat(chat_id=-1)(self.message)) + + def test_filters_chat_username(self): + self.assertFalse(Filters.chat(username='chat')(self.message)) + self.message.chat.username = 'chat' + self.assertTrue(Filters.chat(username='@chat')(self.message)) + self.assertTrue(Filters.chat(username='chat')(self.message)) + self.assertTrue(Filters.chat(username=['chat1', 'chat', 'chat2'])(self.message)) + self.assertFalse(Filters.chat(username=['@chat1', 'chat_2'])(self.message)) + + def test_filters_user(self): + with self.assertRaisesRegexp(ValueError, 'user_id or username'): + Filters.user(user_id=1, username='user') + with self.assertRaisesRegexp(ValueError, 'user_id or username'): + Filters.user() + + def test_filters_user_id(self): + self.assertFalse(Filters.user(user_id=1)(self.message)) + self.message.from_user.id = 1 + self.assertTrue(Filters.user(user_id=1)(self.message)) + self.message.from_user.id = 2 + self.assertTrue(Filters.user(user_id=[1, 2])(self.message)) + self.assertFalse(Filters.user(user_id=1)(self.message)) + + def test_filters_username(self): + self.assertFalse(Filters.user(username='user')(self.message)) + self.assertFalse(Filters.user(username='Testuser')(self.message)) + self.message.from_user.username = 'user' + self.assertTrue(Filters.user(username='@user')(self.message)) + self.assertTrue(Filters.user(username='user')(self.message)) + self.assertTrue(Filters.user(username=['user1', 'user', 'user2'])(self.message)) + self.assertFalse(Filters.user(username=['@username', '@user_2'])(self.message)) + def test_and_filters(self): self.message.text = 'test' self.message.forward_date = True