diff --git a/plinth/modules/wireguard/forms.py b/plinth/modules/wireguard/forms.py index e81c954c1..f7bfd7d8c 100644 --- a/plinth/modules/wireguard/forms.py +++ b/plinth/modules/wireguard/forms.py @@ -18,15 +18,65 @@ Forms for wireguard module. """ +import base64 +import binascii + from django import forms +from django.core.exceptions import ValidationError +from django.core.validators import validate_ipv4_address from django.utils.translation import ugettext_lazy as _ +KEY_LENGTH = 32 + + +def validate_key(key): + """Validate a WireGuard public/private/pre-shared key.""" + valid = False + if isinstance(key, str): + key = key.encode() + + try: + decoded_key = base64.b64decode(key) + if len(decoded_key) == KEY_LENGTH and base64.b64encode( + decoded_key) == key: + valid = True + except binascii.Error: + pass + + if not valid: + raise ValidationError(_('Invalid key.')) + + +def validate_endpoint(endpoint): + """Validate an endpoint of the form: demo.wireguard.com:12912. + + Implemented similar to nm-utils.c::_parse_endpoint(). + + """ + valid = False + try: + destination, port = endpoint.rsplit(':', maxsplit=1) + port = int(port) + if 1 <= port < ((1 << 16) - 1) and destination: + valid = True + + if destination[0] == '[' and (destination[-1] != ']' + or len(destination) < 3): + valid = False + except ValueError: + pass + + if not valid: + raise ValidationError('Invalid endpoint.') + class AddClientForm(forms.Form): """Form to add client.""" public_key = forms.CharField( label=_('Public Key'), strip=True, - help_text=_('Public key of the peer.')) + help_text=_('Public key of the peer. Example: ' + 'MConEJFIg6+DFHg2J1nn9SNLOSE9KR0ysdPgmPjibEs= .'), + validators=[validate_key]) class AddServerForm(forms.Form): @@ -34,29 +84,37 @@ class AddServerForm(forms.Form): peer_endpoint = forms.CharField( label=_('Endpoint of the server'), strip=True, help_text=_('Domain name and port in the form "ip:port". Example: ' - 'demo.wireguard.com:12912 .')) + 'demo.wireguard.com:12912 .'), + validators=[validate_endpoint]) peer_public_key = forms.CharField( label=_('Public key of the server'), strip=True, help_text=_( - 'Provided by the server operator, a long string of characters.')) + 'Provided by the server operator, a long string of characters. ' + 'Example: MConEJFIg6+DFHg2J1nn9SNLOSE9KR0ysdPgmPjibEs= .'), + validators=[validate_key]) ip_address = forms.CharField( label=_('Client IP address provided by server'), strip=True, help_text=_('IP address assigned to this machine on the VPN after ' 'connecting to the endpoint. This value is usually ' - 'provided by the server operator. Example: 192.168.0.10.')) + 'provided by the server operator. Example: 192.168.0.10.'), + validators=[validate_ipv4_address]) private_key = forms.CharField( label=_('Private key of this machine'), strip=True, help_text=_( 'Optional. New public/private keys are generated if left blank. ' 'Public key can then be provided to the server. This is the ' 'recommended way. However, some server operators insist on ' - 'providing this.'), required=False) + 'providing this. Example: ' + 'MConEJFIg6+DFHg2J1nn9SNLOSE9KR0ysdPgmPjibEs= .'), required=False, + validators=[validate_key]) preshared_key = forms.CharField( label=_('Pre-shared key'), strip=True, required=False, help_text=_( 'Optional. A shared secret key provided by the server to add an ' - 'additional layer of security. Fill in only if provided.')) + 'additional layer of security. Fill in only if provided. Example: ' + 'MConEJFIg6+DFHg2J1nn9SNLOSE9KR0ysdPgmPjibEs=.'), + validators=[validate_key]) default_route = forms.BooleanField( label=_('Use this connection to send all outgoing traffic'), diff --git a/plinth/modules/wireguard/tests/__init__.py b/plinth/modules/wireguard/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plinth/modules/wireguard/tests/test_forms.py b/plinth/modules/wireguard/tests/test_forms.py new file mode 100644 index 000000000..acf3b0ca9 --- /dev/null +++ b/plinth/modules/wireguard/tests/test_forms.py @@ -0,0 +1,79 @@ +# +# This file is part of FreedomBox. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +""" +Tests for wireguard module forms. +""" + +import pytest +from django.core.exceptions import ValidationError + +from plinth.modules.wireguard.forms import validate_endpoint, validate_key + + +@pytest.mark.parametrize('key', [ + 'gKQhVGla4UtdqeY1dQ21G5lqrnX5NFcSEAqzM5iSdl0=', + 'uHWSYIjPnS9fYFhZ0mf22IkOMyrWXDlfpXs6ve4QGHk=', +]) +def test_validate_key_valid_patterns(key): + """Test that valid wireguard key patterns as accepted.""" + validate_key(key) + + +@pytest.mark.parametrize( + 'key', + [ + # Invalid padding + 'gKQhVGla4UtdqeY1dQ21G5lqrnX5NFcSEAqzM5iSdl0', + 'invalid-base64', + '', + 'aW52YWxpZC1sZW5ndGg=', # Incorrect length + ]) +def test_validate_key_invalid_patterns(key): + """Test that invalid wireguard key patterns are rejected.""" + with pytest.raises(ValidationError): + validate_key(key) + + +@pytest.mark.parametrize('endpoint', [ + '[1::2]:1234', + '1.2.3.4:1234', + 'example.com:1234', +]) +def test_validate_endpoint_valid_patterns(endpoint): + """Test that valid wireguard endpoint patterns are accepted.""" + validate_endpoint(endpoint) + + +@pytest.mark.parametrize( + 'endpoint', + [ + '', + # Invalid port + '1.2.3.4', + '1.2.3.4:', + '1.2.3.4:0', + '1.2.3.4:65536', + '1.2.3.4:1234invalid', + '1.2.3.4:invalid', + # Invalid IPv6 + '[]:1234', + '[:1234', + ]) +def test_validate_endpoint_invalid_patterns(endpoint): + """Test that invalid wireguard endpoint patterns are rejected.""" + with pytest.raises(ValidationError): + validate_endpoint(endpoint)