actions: Move most of the privileged action code to main directory

Tests

- Run unit tests.

Signed-off-by: Sunil Mohan Adapa <sunil@medhas.org>
Reviewed-by: James Valleroy <jvalleroy@mailbox.org>
This commit is contained in:
Sunil Mohan Adapa 2024-03-10 11:38:20 -07:00 committed by James Valleroy
parent 88c12df7e0
commit ac7ef9e5c4
No known key found for this signature in database
GPG Key ID: 77C0C75E7B650808
3 changed files with 206 additions and 216 deletions

View File

@ -1,210 +1,7 @@
#!/usr/bin/python3 #!/usr/bin/python3
# SPDX-License-Identifier: AGPL-3.0-or-later # SPDX-License-Identifier: AGPL-3.0-or-later
import argparse from plinth.actions import privileged_main
import importlib
import inspect
import json
import logging
import os
import sys
import traceback
import types
import typing
import plinth.log
from plinth import cfg, module_loader
EXIT_SYNTAX = 10
EXIT_PERM = 20
logger = logging.getLogger(__name__)
def main():
"""Parse arguments."""
plinth.log.action_init()
parser = argparse.ArgumentParser()
parser.add_argument('module', help='Module to trigger action in')
parser.add_argument('action', help='Action to trigger in module')
parser.add_argument('--write-fd', type=int, default=1,
help='File descriptor to write output to')
parser.add_argument('--no-args', default=False, action='store_true',
help='Do not read arguments from stdin')
args = parser.parse_args()
try:
try:
arguments = {'args': [], 'kwargs': {}}
if not args.no_args:
input_ = sys.stdin.read()
if input_:
arguments = json.loads(input_)
except json.JSONDecodeError as exception:
raise SyntaxError('Arguments on stdin not JSON.') from exception
return_value = _call(args.module, args.action, arguments)
with os.fdopen(args.write_fd, 'w') as write_file_handle:
write_file_handle.write(json.dumps(return_value))
except PermissionError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_PERM)
except SyntaxError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_SYNTAX)
except TypeError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_SYNTAX)
except Exception as exception:
logger.exception(exception)
sys.exit(1)
def _call(module_name, action_name, arguments):
"""Import the module and run action as superuser"""
if '.' in module_name:
raise SyntaxError('Invalid module name')
cfg.read()
if module_name == 'plinth':
import_path = 'plinth'
else:
import_path = module_loader.get_module_import_path(module_name)
try:
module = importlib.import_module(import_path + '.privileged')
except ModuleNotFoundError as exception:
raise SyntaxError('Specified module not found') from exception
try:
action = getattr(module, action_name)
except AttributeError as exception:
raise SyntaxError('Specified action not found') from exception
if not getattr(action, '_privileged', None):
raise SyntaxError('Specified action is not privileged action')
func = getattr(action, '__wrapped__')
_assert_valid_arguments(func, arguments)
try:
return_values = func(*arguments['args'], **arguments['kwargs'])
return_value = {'result': 'success', 'return': return_values}
except Exception as exception:
logger.exception('Error executing action: %s', exception)
return_value = {
'result': 'exception',
'exception': {
'module': type(exception).__module__,
'name': type(exception).__name__,
'args': exception.args,
'traceback': traceback.format_tb(exception.__traceback__)
}
}
return return_value
def _assert_valid_arguments(func, arguments):
"""Check the names, types and completeness of the arguments passed."""
# Check if arguments match types
if not isinstance(arguments, dict):
raise SyntaxError('Invalid arguments format')
if 'args' not in arguments or 'kwargs' not in arguments:
raise SyntaxError('Invalid arguments format')
args = arguments['args']
kwargs = arguments['kwargs']
if not isinstance(args, list) or not isinstance(kwargs, dict):
raise SyntaxError('Invalid arguments format')
argspec = inspect.getfullargspec(func)
if len(args) + len(kwargs) > len(argspec.args):
raise SyntaxError('Too many arguments')
no_defaults = len(argspec.args)
if argspec.defaults:
no_defaults -= len(argspec.defaults)
for key in argspec.args[len(args):no_defaults]:
if key not in kwargs:
raise SyntaxError(f'Argument not provided: {key}')
for key, value in kwargs.items():
if key not in argspec.args:
raise SyntaxError(f'Unknown argument: {key}')
if argspec.args.index(key) < len(args):
raise SyntaxError(f'Duplicate argument: {key}')
_assert_valid_type(f'arg {key}', value, argspec.annotations[key])
for index, arg in enumerate(args):
annotation = argspec.annotations[argspec.args[index]]
_assert_valid_type(f'arg #{index}', arg, annotation)
def _assert_valid_type(arg_name, value, annotation):
"""Assert that the type of argument value matches the annotation."""
if annotation == typing.Any:
return
NoneType = type(None)
if annotation == NoneType:
if value is not None:
raise TypeError('Expected None for {arg_name}')
return
basic_types = {bool, int, str, float}
if annotation in basic_types:
if not isinstance(value, annotation):
raise TypeError(
f'Expected type {annotation.__name__} for {arg_name}')
return
# 'int | str' or 'typing.Union[int, str]'
if (isinstance(annotation, types.UnionType)
or getattr(annotation, '__origin__', None) == typing.Union):
for arg in annotation.__args__:
try:
_assert_valid_type(arg_name, value, arg)
return
except TypeError:
pass
raise TypeError(f'Expected one of unioned types for {arg_name}')
# 'list[int]' or 'typing.List[int]'
if getattr(annotation, '__origin__', None) == list:
if not isinstance(value, list):
raise TypeError(f'Expected type list for {arg_name}')
for index, inner_item in enumerate(value):
_assert_valid_type(f'{arg_name}[{index}]', inner_item,
annotation.__args__[0])
return
# 'list[dict]' or 'typing.List[dict]'
if getattr(annotation, '__origin__', None) == dict:
if not isinstance(value, dict):
raise TypeError(f'Expected type dict for {arg_name}')
for inner_key, inner_value in value.items():
_assert_valid_type(f'{arg_name}[{inner_key}]', inner_key,
annotation.__args__[0])
_assert_valid_type(f'{arg_name}[{inner_value}]', inner_value,
annotation.__args__[1])
return
raise TypeError('Unsupported annotation type')
if __name__ == '__main__': if __name__ == '__main__':
main() privileged_main()

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: AGPL-3.0-or-later # SPDX-License-Identifier: AGPL-3.0-or-later
"""Framework to run specified actions with elevated privileges.""" """Framework to run specified actions with elevated privileges."""
import argparse
import functools import functools
import importlib import importlib
import inspect import inspect
@ -8,9 +9,16 @@ import json
import logging import logging
import os import os
import subprocess import subprocess
import sys
import threading import threading
import traceback
import types
import typing
from plinth import cfg from plinth import cfg, log, module_loader
EXIT_SYNTAX = 10
EXIT_PERM = 20
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -235,3 +243,189 @@ def _log_action(module_name, action_name, run_as_user, run_in_background):
prompt = f'({run_as_user})$' if run_as_user else '#' prompt = f'({run_as_user})$' if run_as_user else '#'
suffix = '&' if run_in_background else '' suffix = '&' if run_in_background else ''
logger.info('%s %s..%s(…) %s', prompt, module_name, action_name, suffix) logger.info('%s %s..%s(…) %s', prompt, module_name, action_name, suffix)
def privileged_main():
"""Parse arguments for the program spawned as a privileged action."""
log.action_init()
parser = argparse.ArgumentParser()
parser.add_argument('module', help='Module to trigger action in')
parser.add_argument('action', help='Action to trigger in module')
parser.add_argument('--write-fd', type=int, default=1,
help='File descriptor to write output to')
parser.add_argument('--no-args', default=False, action='store_true',
help='Do not read arguments from stdin')
args = parser.parse_args()
try:
try:
arguments = {'args': [], 'kwargs': {}}
if not args.no_args:
input_ = sys.stdin.read()
if input_:
arguments = json.loads(input_)
except json.JSONDecodeError as exception:
raise SyntaxError('Arguments on stdin not JSON.') from exception
return_value = _privileged_call(args.module, args.action, arguments)
with os.fdopen(args.write_fd, 'w') as write_file_handle:
write_file_handle.write(json.dumps(return_value))
except PermissionError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_PERM)
except SyntaxError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_SYNTAX)
except TypeError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_SYNTAX)
except Exception as exception:
logger.exception(exception)
sys.exit(1)
def _privileged_call(module_name, action_name, arguments):
"""Import the module and run action as superuser"""
if '.' in module_name:
raise SyntaxError('Invalid module name')
cfg.read()
if module_name == 'plinth':
import_path = 'plinth'
else:
import_path = module_loader.get_module_import_path(module_name)
try:
module = importlib.import_module(import_path + '.privileged')
except ModuleNotFoundError as exception:
raise SyntaxError('Specified module not found') from exception
try:
action = getattr(module, action_name)
except AttributeError as exception:
raise SyntaxError('Specified action not found') from exception
if not getattr(action, '_privileged', None):
raise SyntaxError('Specified action is not privileged action')
func = getattr(action, '__wrapped__')
_privileged_assert_valid_arguments(func, arguments)
try:
return_values = func(*arguments['args'], **arguments['kwargs'])
return_value = {'result': 'success', 'return': return_values}
except Exception as exception:
logger.exception('Error executing action: %s', exception)
return_value = {
'result': 'exception',
'exception': {
'module': type(exception).__module__,
'name': type(exception).__name__,
'args': exception.args,
'traceback': traceback.format_tb(exception.__traceback__)
}
}
return return_value
def _privileged_assert_valid_arguments(func, arguments):
"""Check the names, types and completeness of the arguments passed."""
# Check if arguments match types
if not isinstance(arguments, dict):
raise SyntaxError('Invalid arguments format')
if 'args' not in arguments or 'kwargs' not in arguments:
raise SyntaxError('Invalid arguments format')
args = arguments['args']
kwargs = arguments['kwargs']
if not isinstance(args, list) or not isinstance(kwargs, dict):
raise SyntaxError('Invalid arguments format')
argspec = inspect.getfullargspec(func)
if len(args) + len(kwargs) > len(argspec.args):
raise SyntaxError('Too many arguments')
no_defaults = len(argspec.args)
if argspec.defaults:
no_defaults -= len(argspec.defaults)
for key in argspec.args[len(args):no_defaults]:
if key not in kwargs:
raise SyntaxError(f'Argument not provided: {key}')
for key, value in kwargs.items():
if key not in argspec.args:
raise SyntaxError(f'Unknown argument: {key}')
if argspec.args.index(key) < len(args):
raise SyntaxError(f'Duplicate argument: {key}')
_privileged_assert_valid_type(f'arg {key}', value,
argspec.annotations[key])
for index, arg in enumerate(args):
annotation = argspec.annotations[argspec.args[index]]
_privileged_assert_valid_type(f'arg #{index}', arg, annotation)
def _privileged_assert_valid_type(arg_name, value, annotation):
"""Assert that the type of argument value matches the annotation."""
if annotation == typing.Any:
return
NoneType = type(None)
if annotation == NoneType:
if value is not None:
raise TypeError('Expected None for {arg_name}')
return
basic_types = {bool, int, str, float}
if annotation in basic_types:
if not isinstance(value, annotation):
raise TypeError(
f'Expected type {annotation.__name__} for {arg_name}')
return
# 'int | str' or 'typing.Union[int, str]'
if (isinstance(annotation, types.UnionType)
or getattr(annotation, '__origin__', None) == typing.Union):
for arg in annotation.__args__:
try:
_privileged_assert_valid_type(arg_name, value, arg)
return
except TypeError:
pass
raise TypeError(f'Expected one of unioned types for {arg_name}')
# 'list[int]' or 'typing.List[int]'
if getattr(annotation, '__origin__', None) == list:
if not isinstance(value, list):
raise TypeError(f'Expected type list for {arg_name}')
for index, inner_item in enumerate(value):
_privileged_assert_valid_type(f'{arg_name}[{index}]', inner_item,
annotation.__args__[0])
return
# 'list[dict]' or 'typing.List[dict]'
if getattr(annotation, '__origin__', None) == dict:
if not isinstance(value, dict):
raise TypeError(f'Expected type dict for {arg_name}')
for inner_key, inner_value in value.items():
_privileged_assert_valid_type(f'{arg_name}[{inner_key}]',
inner_key, annotation.__args__[0])
_privileged_assert_valid_type(f'{arg_name}[{inner_value}]',
inner_value, annotation.__args__[1])
return
raise TypeError('Unsupported annotation type')

View File

@ -9,7 +9,7 @@ from unittest.mock import patch
import pytest import pytest
from plinth.actions import privileged from plinth import actions
actions_name = 'actions' actions_name = 'actions'
@ -17,10 +17,9 @@ actions_name = 'actions'
@patch('importlib.import_module') @patch('importlib.import_module')
@patch('plinth.module_loader.get_module_import_path') @patch('plinth.module_loader.get_module_import_path')
@patch('os.getuid') @patch('os.getuid')
def test_call_syntax_checks(getuid, get_module_import_path, import_module, def test_call_syntax_checks(getuid, get_module_import_path, import_module):
actions_module):
"""Test that calling a method results in proper syntax checks.""" """Test that calling a method results in proper syntax checks."""
call = actions_module._call call = actions._privileged_call
# Module name validation # Module name validation
getuid.return_value = 0 getuid.return_value = 0
@ -53,7 +52,7 @@ def test_call_syntax_checks(getuid, get_module_import_path, import_module,
call('test-module', 'func', {}) call('test-module', 'func', {})
# Argument validation # Argument validation
@privileged @actions.privileged
def func(): def func():
return 'foo' return 'foo'
@ -66,7 +65,7 @@ def test_call_syntax_checks(getuid, get_module_import_path, import_module,
assert return_value == {'result': 'success', 'return': 'foo'} assert return_value == {'result': 'success', 'return': 'foo'}
# Exception call # Exception call
@privileged @actions.privileged
def exception_func(): def exception_func():
raise RuntimeError('foo exception') raise RuntimeError('foo exception')
@ -81,9 +80,9 @@ def test_call_syntax_checks(getuid, get_module_import_path, import_module,
assert isinstance(line, str) assert isinstance(line, str)
def test_assert_valid_arguments(actions_module): def test_assert_valid_arguments():
"""Test that checking valid arguments works.""" """Test that checking valid arguments works."""
assert_valid = actions_module._assert_valid_arguments assert_valid = actions._privileged_assert_valid_arguments
values = [ values = [
None, [], 10, {}, { None, [], 10, {}, {
@ -139,9 +138,9 @@ def test_assert_valid_arguments(actions_module):
assert_valid(func, {'args': [1, '2'], 'kwargs': {'c': '3'}}) assert_valid(func, {'args': [1, '2'], 'kwargs': {'c': '3'}})
def test_assert_valid_type(actions_module): def test_assert_valid_type():
"""Test that type validation works as expected.""" """Test that type validation works as expected."""
assert_valid = actions_module._assert_valid_type assert_valid = actions._privileged_assert_valid_type
assert_valid(None, None, typing.Any) assert_valid(None, None, typing.Any)