diff --git a/channels/binding/base.py b/channels/binding/base.py index a1f1977..5788988 100644 --- a/channels/binding/base.py +++ b/channels/binding/base.py @@ -67,6 +67,7 @@ class Binding(object): # if you want to really send all fields, use fields = ['__all__'] fields = None + exclude = None # Decorators channel_session_user = True @@ -95,9 +96,9 @@ class Binding(object): return [] else: raise ValueError("You must set the model attribute on Binding %r!" % cls) - # If fields is not defined, raise an error - if cls.fields is None: - raise ValueError("You must set the fields attribute on Binding %r!" % cls) + # If neither fields nor exclude are not defined, raise an error + if cls.fields is None and cls.exclude is None: + raise ValueError("You must set the fields or exclude attribute on Binding %r!" % cls) # Optionally resolve model strings if isinstance(cls.model, six.string_types): cls.model = apps.get_model(cls.model) diff --git a/channels/binding/websockets.py b/channels/binding/websockets.py index e458b6b..4238b2d 100644 --- a/channels/binding/websockets.py +++ b/channels/binding/websockets.py @@ -55,10 +55,13 @@ class WebsocketBinding(Binding): """ Serializes model data into JSON-compatible types. """ - if list(self.fields) == ['__all__']: - fields = None + if self.fields is not None: + if list(self.fields) == ['__all__']: + fields = None + else: + fields = self.fields else: - fields = self.fields + fields = [f.name for f in instance._meta.get_fields() if f.name not in self.exclude] data = serializers.serialize('json', [instance], fields=fields) return json.loads(data)[0]['fields'] @@ -109,9 +112,15 @@ class WebsocketBinding(Binding): def update(self, pk, data): instance = self.model.objects.get(pk=pk) hydrated = self._hydrate(pk, data) - for name in data.keys(): - if name in self.fields or self.fields == ['__all__']: - setattr(instance, name, getattr(hydrated.object, name)) + + if self.fields is not None: + for name in data.keys(): + if name in self.fields or self.fields == ['__all__']: + setattr(instance, name, getattr(hydrated.object, name)) + else: + for name in data.keys(): + if name not in self.exclude: + setattr(instance, name, getattr(hydrated.object, name)) instance.save() diff --git a/channels/tests/test_binding.py b/channels/tests/test_binding.py index c3dca66..63fa4e3 100644 --- a/channels/tests/test_binding.py +++ b/channels/tests/test_binding.py @@ -58,6 +58,65 @@ class TestsBinding(ChannelTestCase): received = client.receive() self.assertIsNone(received) + def test_trigger_outbound_create_exclude(self): + class TestBinding(WebsocketBinding): + model = User + stream = 'test' + exclude = ['first_name', 'last_name'] + + @classmethod + def group_names(cls, instance, action): + return ["users_exclude"] + + def has_permission(self, user, action, pk): + return True + + with apply_routes([route('test', TestBinding.consumer)]): + client = HttpClient() + client.join_group('users_exclude') + + user = User.objects.create(username='test', email='test@test.com') + consumer_finished.send(sender=None) + consumer_finished.send(sender=None) + received = client.receive() + + self.assertTrue('payload' in received) + self.assertTrue('action' in received['payload']) + self.assertTrue('data' in received['payload']) + self.assertTrue('username' in received['payload']['data']) + self.assertTrue('email' in received['payload']['data']) + self.assertTrue('password' in received['payload']['data']) + self.assertTrue('model' in received['payload']) + self.assertTrue('pk' in received['payload']) + + self.assertFalse('last_name' in received['payload']['data']) + self.assertFalse('first_name' in received['payload']['data']) + + self.assertEqual(received['payload']['action'], 'create') + self.assertEqual(received['payload']['model'], 'auth.user') + self.assertEqual(received['payload']['pk'], user.pk) + + self.assertEqual(received['payload']['data']['email'], 'test@test.com') + self.assertEqual(received['payload']['data']['username'], 'test') + self.assertEqual(received['payload']['data']['password'], '') + + received = client.receive() + self.assertIsNone(received) + + def test_omit_fields_and_exclude(self): + def _declare_class(): + class TestBinding(WebsocketBinding): + model = User + stream = 'test' + + @classmethod + def group_names(cls, instance, action): + return ["users_omit"] + + def has_permission(self, user, action, pk): + return True + self.assertRaises(ValueError, _declare_class) + def test_trigger_outbound_update(self): class TestBinding(WebsocketBinding): model = User diff --git a/docs/binding.rst b/docs/binding.rst index c294d5f..e8ce3a8 100644 --- a/docs/binding.rst +++ b/docs/binding.rst @@ -81,7 +81,8 @@ always provide: * ``fields`` is a whitelist of fields to return in the serialized request. Channels does not default to all fields for security concerns; if you want - this, set it to the value ``["__all__"]``. + this, set it to the value ``["__all__"]``. As an alternative, ``exclude`` + acts as a blacklist of fields. * ``group_names`` returns a list of groups to send outbound updates to based on the model and action. For example, you could dispatch posts on different