diff --git a/channels/binding/base.py b/channels/binding/base.py index 1ec92a3..902b56e 100644 --- a/channels/binding/base.py +++ b/channels/binding/base.py @@ -3,12 +3,17 @@ from __future__ import unicode_literals import six from django.apps import apps -from django.db.models.signals import post_save, post_delete +from django.db.models.signals import post_save, post_delete, pre_save, pre_delete from ..channel import Group from ..auth import channel_session, channel_session_user +CREATE = 'create' +UPDATE = 'update' +DELETE = 'delete' + + class BindingMetaclass(type): """ Metaclass that tracks instantiations of its type. @@ -72,10 +77,22 @@ class Binding(object): """ Resolves models. """ + # Connect signals + for model in cls.get_registered_models(): + pre_save.connect(cls.pre_save_receiver, sender=model) + post_save.connect(cls.post_save_receiver, sender=model) + pre_delete.connect(cls.pre_delete_receiver, sender=model) + post_delete.connect(cls.post_delete_receiver, sender=model) + + @classmethod + def get_registered_models(cls): + """ + Resolves the class model attribute if it's a string and returns it. + """ # If model is None directly on the class, assume it's abstract. if cls.model is None: if "model" in cls.__dict__: - return + return [] else: raise ValueError("You must set the model attribute on Binding %r!" % cls) # If fields is not defined, raise an error @@ -88,26 +105,10 @@ class Binding(object): cls.model._meta.app_label.lower(), cls.model._meta.object_name.lower(), ) - # Connect signals - post_save.connect(cls.save_receiver, sender=cls.model) - post_delete.connect(cls.delete_receiver, sender=cls.model) + return [cls.model] # Outbound binding - @classmethod - def save_receiver(cls, instance, created, **kwargs): - """ - Entry point for triggering the binding from save signals. - """ - cls.trigger_outbound(instance, "create" if created else "update") - - @classmethod - def delete_receiver(cls, instance, **kwargs): - """ - Entry point for triggering the binding from delete signals. - """ - cls.trigger_outbound(instance, "delete") - @classmethod def encode(cls, stream, payload): """ @@ -116,20 +117,70 @@ class Binding(object): raise NotImplementedError() @classmethod - def trigger_outbound(cls, instance, action): + def pre_save_receiver(cls, instance, **kwargs): + cls.pre_change_receiver(instance, CREATE if instance.pk is None else UPDATE) + + @classmethod + def post_save_receiver(cls, instance, created, **kwargs): + cls.post_change_receiver(instance, CREATE if created else UPDATE) + + @classmethod + def pre_delete_receiver(cls, instance, **kwargs): + cls.pre_change_receiver(instance, DELETE) + + @classmethod + def post_delete_receiver(cls, instance, **kwargs): + cls.post_change_receiver(instance, DELETE) + + @classmethod + def pre_change_receiver(cls, instance, action): + """ + Entry point for triggering the binding from save signals. + """ + if action == CREATE: + group_names = set() + else: + group_names = set(cls.group_names(instance)) + + if not hasattr(instance, '_binding_group_names'): + instance._binding_group_names = {} + instance._binding_group_names[cls] = group_names + + @classmethod + def post_change_receiver(cls, instance, action): """ Triggers the binding to possibly send to its group. """ + old_group_names = instance._binding_group_names[cls] + if action == DELETE: + new_group_names = set() + else: + new_group_names = set(cls.group_names(instance)) + + # if post delete, new_group_names should be [] self = cls() self.instance = instance - # Check to see if we're covered + + # Django DDP had used the ordering of DELETE, UPDATE then CREATE for good reasons. + self.send_messages(instance, old_group_names - new_group_names, DELETE) + self.send_messages(instance, old_group_names & new_group_names, UPDATE) + self.send_messages(instance, new_group_names - old_group_names, CREATE) + + def send_messages(self, instance, group_names, action): + """ + Serializes the instance and sends it to all provided group names. + """ + if not group_names: + return # no need to serialize, bail. payload = self.serialize(instance, action) - if payload != {}: - assert self.stream is not None - message = cls.encode(self.stream, payload) - for group_name in self.group_names(instance, action): - group = Group(group_name) - group.send(message) + if payload == {}: + return # nothing to send, bail. + + assert self.stream is not None + message = self.encode(self.stream, payload) + for group_name in group_names: + group = Group(group_name) + group.send(message) def group_names(self, instance, action): """ diff --git a/channels/tests/test_binding.py b/channels/tests/test_binding.py index dd85208..dca1092 100644 --- a/channels/tests/test_binding.py +++ b/channels/tests/test_binding.py @@ -17,7 +17,8 @@ class TestsBinding(ChannelTestCase): stream = 'test' fields = ['username', 'email', 'password', 'last_name'] - def group_names(self, instance, action): + @classmethod + def group_names(cls, instance): return ["users"] def has_permission(self, user, action, pk): @@ -58,7 +59,8 @@ class TestsBinding(ChannelTestCase): stream = 'test' fields = ['__all__'] - def group_names(self, instance, action): + @classmethod + def group_names(cls, instance): return ["users2"] def has_permission(self, user, action, pk): @@ -102,7 +104,8 @@ class TestsBinding(ChannelTestCase): stream = 'test' fields = ['username'] - def group_names(self, instance, action): + @classmethod + def group_names(cls, instance): return ["users3"] def has_permission(self, user, action, pk):