From ac7ef9e5c40e5026625b0b68c48e440dfb67f261 Mon Sep 17 00:00:00 2001 From: Sunil Mohan Adapa Date: Sun, 10 Mar 2024 11:38:20 -0700 Subject: [PATCH] actions: Move most of the privileged action code to main directory Tests - Run unit tests. Signed-off-by: Sunil Mohan Adapa Reviewed-by: James Valleroy --- actions/actions | 207 +-------------------------- plinth/actions.py | 196 ++++++++++++++++++++++++- plinth/tests/test_actions_actions.py | 19 ++- 3 files changed, 206 insertions(+), 216 deletions(-) diff --git a/actions/actions b/actions/actions index 7b7e5459b..8761c73eb 100755 --- a/actions/actions +++ b/actions/actions @@ -1,210 +1,7 @@ #!/usr/bin/python3 # SPDX-License-Identifier: AGPL-3.0-or-later -import argparse -import importlib -import inspect -import json -import logging -import os -import sys -import traceback -import types -import typing - -import plinth.log -from plinth import cfg, module_loader - -EXIT_SYNTAX = 10 -EXIT_PERM = 20 - -logger = logging.getLogger(__name__) - - -def main(): - """Parse arguments.""" - plinth.log.action_init() - - parser = argparse.ArgumentParser() - parser.add_argument('module', help='Module to trigger action in') - parser.add_argument('action', help='Action to trigger in module') - parser.add_argument('--write-fd', type=int, default=1, - help='File descriptor to write output to') - parser.add_argument('--no-args', default=False, action='store_true', - help='Do not read arguments from stdin') - args = parser.parse_args() - - try: - try: - arguments = {'args': [], 'kwargs': {}} - if not args.no_args: - input_ = sys.stdin.read() - if input_: - arguments = json.loads(input_) - except json.JSONDecodeError as exception: - raise SyntaxError('Arguments on stdin not JSON.') from exception - - return_value = _call(args.module, args.action, arguments) - with os.fdopen(args.write_fd, 'w') as write_file_handle: - write_file_handle.write(json.dumps(return_value)) - except PermissionError as exception: - logger.error(exception.args[0]) - sys.exit(EXIT_PERM) - except SyntaxError as exception: - logger.error(exception.args[0]) - sys.exit(EXIT_SYNTAX) - except TypeError as exception: - logger.error(exception.args[0]) - sys.exit(EXIT_SYNTAX) - except Exception as exception: - logger.exception(exception) - sys.exit(1) - - -def _call(module_name, action_name, arguments): - """Import the module and run action as superuser""" - if '.' in module_name: - raise SyntaxError('Invalid module name') - - cfg.read() - if module_name == 'plinth': - import_path = 'plinth' - else: - import_path = module_loader.get_module_import_path(module_name) - - try: - module = importlib.import_module(import_path + '.privileged') - except ModuleNotFoundError as exception: - raise SyntaxError('Specified module not found') from exception - - try: - action = getattr(module, action_name) - except AttributeError as exception: - raise SyntaxError('Specified action not found') from exception - - if not getattr(action, '_privileged', None): - raise SyntaxError('Specified action is not privileged action') - - func = getattr(action, '__wrapped__') - - _assert_valid_arguments(func, arguments) - - try: - return_values = func(*arguments['args'], **arguments['kwargs']) - return_value = {'result': 'success', 'return': return_values} - except Exception as exception: - logger.exception('Error executing action: %s', exception) - return_value = { - 'result': 'exception', - 'exception': { - 'module': type(exception).__module__, - 'name': type(exception).__name__, - 'args': exception.args, - 'traceback': traceback.format_tb(exception.__traceback__) - } - } - - return return_value - - -def _assert_valid_arguments(func, arguments): - """Check the names, types and completeness of the arguments passed.""" - # Check if arguments match types - if not isinstance(arguments, dict): - raise SyntaxError('Invalid arguments format') - - if 'args' not in arguments or 'kwargs' not in arguments: - raise SyntaxError('Invalid arguments format') - - args = arguments['args'] - kwargs = arguments['kwargs'] - if not isinstance(args, list) or not isinstance(kwargs, dict): - raise SyntaxError('Invalid arguments format') - - argspec = inspect.getfullargspec(func) - if len(args) + len(kwargs) > len(argspec.args): - raise SyntaxError('Too many arguments') - - no_defaults = len(argspec.args) - if argspec.defaults: - no_defaults -= len(argspec.defaults) - - for key in argspec.args[len(args):no_defaults]: - if key not in kwargs: - raise SyntaxError(f'Argument not provided: {key}') - - for key, value in kwargs.items(): - if key not in argspec.args: - raise SyntaxError(f'Unknown argument: {key}') - - if argspec.args.index(key) < len(args): - raise SyntaxError(f'Duplicate argument: {key}') - - _assert_valid_type(f'arg {key}', value, argspec.annotations[key]) - - for index, arg in enumerate(args): - annotation = argspec.annotations[argspec.args[index]] - _assert_valid_type(f'arg #{index}', arg, annotation) - - -def _assert_valid_type(arg_name, value, annotation): - """Assert that the type of argument value matches the annotation.""" - if annotation == typing.Any: - return - - NoneType = type(None) - if annotation == NoneType: - if value is not None: - raise TypeError('Expected None for {arg_name}') - - return - - basic_types = {bool, int, str, float} - if annotation in basic_types: - if not isinstance(value, annotation): - raise TypeError( - f'Expected type {annotation.__name__} for {arg_name}') - - return - - # 'int | str' or 'typing.Union[int, str]' - if (isinstance(annotation, types.UnionType) - or getattr(annotation, '__origin__', None) == typing.Union): - for arg in annotation.__args__: - try: - _assert_valid_type(arg_name, value, arg) - return - except TypeError: - pass - - raise TypeError(f'Expected one of unioned types for {arg_name}') - - # 'list[int]' or 'typing.List[int]' - if getattr(annotation, '__origin__', None) == list: - if not isinstance(value, list): - raise TypeError(f'Expected type list for {arg_name}') - - for index, inner_item in enumerate(value): - _assert_valid_type(f'{arg_name}[{index}]', inner_item, - annotation.__args__[0]) - - return - - # 'list[dict]' or 'typing.List[dict]' - if getattr(annotation, '__origin__', None) == dict: - if not isinstance(value, dict): - raise TypeError(f'Expected type dict for {arg_name}') - - for inner_key, inner_value in value.items(): - _assert_valid_type(f'{arg_name}[{inner_key}]', inner_key, - annotation.__args__[0]) - _assert_valid_type(f'{arg_name}[{inner_value}]', inner_value, - annotation.__args__[1]) - - return - - raise TypeError('Unsupported annotation type') - +from plinth.actions import privileged_main if __name__ == '__main__': - main() + privileged_main() diff --git a/plinth/actions.py b/plinth/actions.py index 187a0bd0b..8f9f1e802 100644 --- a/plinth/actions.py +++ b/plinth/actions.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later """Framework to run specified actions with elevated privileges.""" +import argparse import functools import importlib import inspect @@ -8,9 +9,16 @@ import json import logging import os import subprocess +import sys import threading +import traceback +import types +import typing -from plinth import cfg +from plinth import cfg, log, module_loader + +EXIT_SYNTAX = 10 +EXIT_PERM = 20 logger = logging.getLogger(__name__) @@ -235,3 +243,189 @@ def _log_action(module_name, action_name, run_as_user, run_in_background): prompt = f'({run_as_user})$' if run_as_user else '#' suffix = '&' if run_in_background else '' logger.info('%s %s..%s(…) %s', prompt, module_name, action_name, suffix) + + +def privileged_main(): + """Parse arguments for the program spawned as a privileged action.""" + log.action_init() + + parser = argparse.ArgumentParser() + parser.add_argument('module', help='Module to trigger action in') + parser.add_argument('action', help='Action to trigger in module') + parser.add_argument('--write-fd', type=int, default=1, + help='File descriptor to write output to') + parser.add_argument('--no-args', default=False, action='store_true', + help='Do not read arguments from stdin') + args = parser.parse_args() + + try: + try: + arguments = {'args': [], 'kwargs': {}} + if not args.no_args: + input_ = sys.stdin.read() + if input_: + arguments = json.loads(input_) + except json.JSONDecodeError as exception: + raise SyntaxError('Arguments on stdin not JSON.') from exception + + return_value = _privileged_call(args.module, args.action, arguments) + with os.fdopen(args.write_fd, 'w') as write_file_handle: + write_file_handle.write(json.dumps(return_value)) + except PermissionError as exception: + logger.error(exception.args[0]) + sys.exit(EXIT_PERM) + except SyntaxError as exception: + logger.error(exception.args[0]) + sys.exit(EXIT_SYNTAX) + except TypeError as exception: + logger.error(exception.args[0]) + sys.exit(EXIT_SYNTAX) + except Exception as exception: + logger.exception(exception) + sys.exit(1) + + +def _privileged_call(module_name, action_name, arguments): + """Import the module and run action as superuser""" + if '.' in module_name: + raise SyntaxError('Invalid module name') + + cfg.read() + if module_name == 'plinth': + import_path = 'plinth' + else: + import_path = module_loader.get_module_import_path(module_name) + + try: + module = importlib.import_module(import_path + '.privileged') + except ModuleNotFoundError as exception: + raise SyntaxError('Specified module not found') from exception + + try: + action = getattr(module, action_name) + except AttributeError as exception: + raise SyntaxError('Specified action not found') from exception + + if not getattr(action, '_privileged', None): + raise SyntaxError('Specified action is not privileged action') + + func = getattr(action, '__wrapped__') + + _privileged_assert_valid_arguments(func, arguments) + + try: + return_values = func(*arguments['args'], **arguments['kwargs']) + return_value = {'result': 'success', 'return': return_values} + except Exception as exception: + logger.exception('Error executing action: %s', exception) + return_value = { + 'result': 'exception', + 'exception': { + 'module': type(exception).__module__, + 'name': type(exception).__name__, + 'args': exception.args, + 'traceback': traceback.format_tb(exception.__traceback__) + } + } + + return return_value + + +def _privileged_assert_valid_arguments(func, arguments): + """Check the names, types and completeness of the arguments passed.""" + # Check if arguments match types + if not isinstance(arguments, dict): + raise SyntaxError('Invalid arguments format') + + if 'args' not in arguments or 'kwargs' not in arguments: + raise SyntaxError('Invalid arguments format') + + args = arguments['args'] + kwargs = arguments['kwargs'] + if not isinstance(args, list) or not isinstance(kwargs, dict): + raise SyntaxError('Invalid arguments format') + + argspec = inspect.getfullargspec(func) + if len(args) + len(kwargs) > len(argspec.args): + raise SyntaxError('Too many arguments') + + no_defaults = len(argspec.args) + if argspec.defaults: + no_defaults -= len(argspec.defaults) + + for key in argspec.args[len(args):no_defaults]: + if key not in kwargs: + raise SyntaxError(f'Argument not provided: {key}') + + for key, value in kwargs.items(): + if key not in argspec.args: + raise SyntaxError(f'Unknown argument: {key}') + + if argspec.args.index(key) < len(args): + raise SyntaxError(f'Duplicate argument: {key}') + + _privileged_assert_valid_type(f'arg {key}', value, + argspec.annotations[key]) + + for index, arg in enumerate(args): + annotation = argspec.annotations[argspec.args[index]] + _privileged_assert_valid_type(f'arg #{index}', arg, annotation) + + +def _privileged_assert_valid_type(arg_name, value, annotation): + """Assert that the type of argument value matches the annotation.""" + if annotation == typing.Any: + return + + NoneType = type(None) + if annotation == NoneType: + if value is not None: + raise TypeError('Expected None for {arg_name}') + + return + + basic_types = {bool, int, str, float} + if annotation in basic_types: + if not isinstance(value, annotation): + raise TypeError( + f'Expected type {annotation.__name__} for {arg_name}') + + return + + # 'int | str' or 'typing.Union[int, str]' + if (isinstance(annotation, types.UnionType) + or getattr(annotation, '__origin__', None) == typing.Union): + for arg in annotation.__args__: + try: + _privileged_assert_valid_type(arg_name, value, arg) + return + except TypeError: + pass + + raise TypeError(f'Expected one of unioned types for {arg_name}') + + # 'list[int]' or 'typing.List[int]' + if getattr(annotation, '__origin__', None) == list: + if not isinstance(value, list): + raise TypeError(f'Expected type list for {arg_name}') + + for index, inner_item in enumerate(value): + _privileged_assert_valid_type(f'{arg_name}[{index}]', inner_item, + annotation.__args__[0]) + + return + + # 'list[dict]' or 'typing.List[dict]' + if getattr(annotation, '__origin__', None) == dict: + if not isinstance(value, dict): + raise TypeError(f'Expected type dict for {arg_name}') + + for inner_key, inner_value in value.items(): + _privileged_assert_valid_type(f'{arg_name}[{inner_key}]', + inner_key, annotation.__args__[0]) + _privileged_assert_valid_type(f'{arg_name}[{inner_value}]', + inner_value, annotation.__args__[1]) + + return + + raise TypeError('Unsupported annotation type') diff --git a/plinth/tests/test_actions_actions.py b/plinth/tests/test_actions_actions.py index 088fa6fa4..5ae5fa68a 100644 --- a/plinth/tests/test_actions_actions.py +++ b/plinth/tests/test_actions_actions.py @@ -9,7 +9,7 @@ from unittest.mock import patch import pytest -from plinth.actions import privileged +from plinth import actions actions_name = 'actions' @@ -17,10 +17,9 @@ actions_name = 'actions' @patch('importlib.import_module') @patch('plinth.module_loader.get_module_import_path') @patch('os.getuid') -def test_call_syntax_checks(getuid, get_module_import_path, import_module, - actions_module): +def test_call_syntax_checks(getuid, get_module_import_path, import_module): """Test that calling a method results in proper syntax checks.""" - call = actions_module._call + call = actions._privileged_call # Module name validation getuid.return_value = 0 @@ -53,7 +52,7 @@ def test_call_syntax_checks(getuid, get_module_import_path, import_module, call('test-module', 'func', {}) # Argument validation - @privileged + @actions.privileged def func(): return 'foo' @@ -66,7 +65,7 @@ def test_call_syntax_checks(getuid, get_module_import_path, import_module, assert return_value == {'result': 'success', 'return': 'foo'} # Exception call - @privileged + @actions.privileged def exception_func(): raise RuntimeError('foo exception') @@ -81,9 +80,9 @@ def test_call_syntax_checks(getuid, get_module_import_path, import_module, assert isinstance(line, str) -def test_assert_valid_arguments(actions_module): +def test_assert_valid_arguments(): """Test that checking valid arguments works.""" - assert_valid = actions_module._assert_valid_arguments + assert_valid = actions._privileged_assert_valid_arguments values = [ None, [], 10, {}, { @@ -139,9 +138,9 @@ def test_assert_valid_arguments(actions_module): assert_valid(func, {'args': [1, '2'], 'kwargs': {'c': '3'}}) -def test_assert_valid_type(actions_module): +def test_assert_valid_type(): """Test that type validation works as expected.""" - assert_valid = actions_module._assert_valid_type + assert_valid = actions._privileged_assert_valid_type assert_valid(None, None, typing.Any)