actions: Move most of the privileged action code to main directory

Tests

- Run unit tests.

Signed-off-by: Sunil Mohan Adapa <sunil@medhas.org>
Reviewed-by: James Valleroy <jvalleroy@mailbox.org>
This commit is contained in:
Sunil Mohan Adapa 2024-03-10 11:38:20 -07:00 committed by James Valleroy
parent 88c12df7e0
commit ac7ef9e5c4
No known key found for this signature in database
GPG Key ID: 77C0C75E7B650808
3 changed files with 206 additions and 216 deletions

View File

@ -1,210 +1,7 @@
#!/usr/bin/python3
# SPDX-License-Identifier: AGPL-3.0-or-later
import argparse
import importlib
import inspect
import json
import logging
import os
import sys
import traceback
import types
import typing
import plinth.log
from plinth import cfg, module_loader
EXIT_SYNTAX = 10
EXIT_PERM = 20
logger = logging.getLogger(__name__)
def main():
"""Parse arguments."""
plinth.log.action_init()
parser = argparse.ArgumentParser()
parser.add_argument('module', help='Module to trigger action in')
parser.add_argument('action', help='Action to trigger in module')
parser.add_argument('--write-fd', type=int, default=1,
help='File descriptor to write output to')
parser.add_argument('--no-args', default=False, action='store_true',
help='Do not read arguments from stdin')
args = parser.parse_args()
try:
try:
arguments = {'args': [], 'kwargs': {}}
if not args.no_args:
input_ = sys.stdin.read()
if input_:
arguments = json.loads(input_)
except json.JSONDecodeError as exception:
raise SyntaxError('Arguments on stdin not JSON.') from exception
return_value = _call(args.module, args.action, arguments)
with os.fdopen(args.write_fd, 'w') as write_file_handle:
write_file_handle.write(json.dumps(return_value))
except PermissionError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_PERM)
except SyntaxError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_SYNTAX)
except TypeError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_SYNTAX)
except Exception as exception:
logger.exception(exception)
sys.exit(1)
def _call(module_name, action_name, arguments):
"""Import the module and run action as superuser"""
if '.' in module_name:
raise SyntaxError('Invalid module name')
cfg.read()
if module_name == 'plinth':
import_path = 'plinth'
else:
import_path = module_loader.get_module_import_path(module_name)
try:
module = importlib.import_module(import_path + '.privileged')
except ModuleNotFoundError as exception:
raise SyntaxError('Specified module not found') from exception
try:
action = getattr(module, action_name)
except AttributeError as exception:
raise SyntaxError('Specified action not found') from exception
if not getattr(action, '_privileged', None):
raise SyntaxError('Specified action is not privileged action')
func = getattr(action, '__wrapped__')
_assert_valid_arguments(func, arguments)
try:
return_values = func(*arguments['args'], **arguments['kwargs'])
return_value = {'result': 'success', 'return': return_values}
except Exception as exception:
logger.exception('Error executing action: %s', exception)
return_value = {
'result': 'exception',
'exception': {
'module': type(exception).__module__,
'name': type(exception).__name__,
'args': exception.args,
'traceback': traceback.format_tb(exception.__traceback__)
}
}
return return_value
def _assert_valid_arguments(func, arguments):
"""Check the names, types and completeness of the arguments passed."""
# Check if arguments match types
if not isinstance(arguments, dict):
raise SyntaxError('Invalid arguments format')
if 'args' not in arguments or 'kwargs' not in arguments:
raise SyntaxError('Invalid arguments format')
args = arguments['args']
kwargs = arguments['kwargs']
if not isinstance(args, list) or not isinstance(kwargs, dict):
raise SyntaxError('Invalid arguments format')
argspec = inspect.getfullargspec(func)
if len(args) + len(kwargs) > len(argspec.args):
raise SyntaxError('Too many arguments')
no_defaults = len(argspec.args)
if argspec.defaults:
no_defaults -= len(argspec.defaults)
for key in argspec.args[len(args):no_defaults]:
if key not in kwargs:
raise SyntaxError(f'Argument not provided: {key}')
for key, value in kwargs.items():
if key not in argspec.args:
raise SyntaxError(f'Unknown argument: {key}')
if argspec.args.index(key) < len(args):
raise SyntaxError(f'Duplicate argument: {key}')
_assert_valid_type(f'arg {key}', value, argspec.annotations[key])
for index, arg in enumerate(args):
annotation = argspec.annotations[argspec.args[index]]
_assert_valid_type(f'arg #{index}', arg, annotation)
def _assert_valid_type(arg_name, value, annotation):
"""Assert that the type of argument value matches the annotation."""
if annotation == typing.Any:
return
NoneType = type(None)
if annotation == NoneType:
if value is not None:
raise TypeError('Expected None for {arg_name}')
return
basic_types = {bool, int, str, float}
if annotation in basic_types:
if not isinstance(value, annotation):
raise TypeError(
f'Expected type {annotation.__name__} for {arg_name}')
return
# 'int | str' or 'typing.Union[int, str]'
if (isinstance(annotation, types.UnionType)
or getattr(annotation, '__origin__', None) == typing.Union):
for arg in annotation.__args__:
try:
_assert_valid_type(arg_name, value, arg)
return
except TypeError:
pass
raise TypeError(f'Expected one of unioned types for {arg_name}')
# 'list[int]' or 'typing.List[int]'
if getattr(annotation, '__origin__', None) == list:
if not isinstance(value, list):
raise TypeError(f'Expected type list for {arg_name}')
for index, inner_item in enumerate(value):
_assert_valid_type(f'{arg_name}[{index}]', inner_item,
annotation.__args__[0])
return
# 'list[dict]' or 'typing.List[dict]'
if getattr(annotation, '__origin__', None) == dict:
if not isinstance(value, dict):
raise TypeError(f'Expected type dict for {arg_name}')
for inner_key, inner_value in value.items():
_assert_valid_type(f'{arg_name}[{inner_key}]', inner_key,
annotation.__args__[0])
_assert_valid_type(f'{arg_name}[{inner_value}]', inner_value,
annotation.__args__[1])
return
raise TypeError('Unsupported annotation type')
from plinth.actions import privileged_main
if __name__ == '__main__':
main()
privileged_main()

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
"""Framework to run specified actions with elevated privileges."""
import argparse
import functools
import importlib
import inspect
@ -8,9 +9,16 @@ import json
import logging
import os
import subprocess
import sys
import threading
import traceback
import types
import typing
from plinth import cfg
from plinth import cfg, log, module_loader
EXIT_SYNTAX = 10
EXIT_PERM = 20
logger = logging.getLogger(__name__)
@ -235,3 +243,189 @@ def _log_action(module_name, action_name, run_as_user, run_in_background):
prompt = f'({run_as_user})$' if run_as_user else '#'
suffix = '&' if run_in_background else ''
logger.info('%s %s..%s(…) %s', prompt, module_name, action_name, suffix)
def privileged_main():
"""Parse arguments for the program spawned as a privileged action."""
log.action_init()
parser = argparse.ArgumentParser()
parser.add_argument('module', help='Module to trigger action in')
parser.add_argument('action', help='Action to trigger in module')
parser.add_argument('--write-fd', type=int, default=1,
help='File descriptor to write output to')
parser.add_argument('--no-args', default=False, action='store_true',
help='Do not read arguments from stdin')
args = parser.parse_args()
try:
try:
arguments = {'args': [], 'kwargs': {}}
if not args.no_args:
input_ = sys.stdin.read()
if input_:
arguments = json.loads(input_)
except json.JSONDecodeError as exception:
raise SyntaxError('Arguments on stdin not JSON.') from exception
return_value = _privileged_call(args.module, args.action, arguments)
with os.fdopen(args.write_fd, 'w') as write_file_handle:
write_file_handle.write(json.dumps(return_value))
except PermissionError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_PERM)
except SyntaxError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_SYNTAX)
except TypeError as exception:
logger.error(exception.args[0])
sys.exit(EXIT_SYNTAX)
except Exception as exception:
logger.exception(exception)
sys.exit(1)
def _privileged_call(module_name, action_name, arguments):
"""Import the module and run action as superuser"""
if '.' in module_name:
raise SyntaxError('Invalid module name')
cfg.read()
if module_name == 'plinth':
import_path = 'plinth'
else:
import_path = module_loader.get_module_import_path(module_name)
try:
module = importlib.import_module(import_path + '.privileged')
except ModuleNotFoundError as exception:
raise SyntaxError('Specified module not found') from exception
try:
action = getattr(module, action_name)
except AttributeError as exception:
raise SyntaxError('Specified action not found') from exception
if not getattr(action, '_privileged', None):
raise SyntaxError('Specified action is not privileged action')
func = getattr(action, '__wrapped__')
_privileged_assert_valid_arguments(func, arguments)
try:
return_values = func(*arguments['args'], **arguments['kwargs'])
return_value = {'result': 'success', 'return': return_values}
except Exception as exception:
logger.exception('Error executing action: %s', exception)
return_value = {
'result': 'exception',
'exception': {
'module': type(exception).__module__,
'name': type(exception).__name__,
'args': exception.args,
'traceback': traceback.format_tb(exception.__traceback__)
}
}
return return_value
def _privileged_assert_valid_arguments(func, arguments):
"""Check the names, types and completeness of the arguments passed."""
# Check if arguments match types
if not isinstance(arguments, dict):
raise SyntaxError('Invalid arguments format')
if 'args' not in arguments or 'kwargs' not in arguments:
raise SyntaxError('Invalid arguments format')
args = arguments['args']
kwargs = arguments['kwargs']
if not isinstance(args, list) or not isinstance(kwargs, dict):
raise SyntaxError('Invalid arguments format')
argspec = inspect.getfullargspec(func)
if len(args) + len(kwargs) > len(argspec.args):
raise SyntaxError('Too many arguments')
no_defaults = len(argspec.args)
if argspec.defaults:
no_defaults -= len(argspec.defaults)
for key in argspec.args[len(args):no_defaults]:
if key not in kwargs:
raise SyntaxError(f'Argument not provided: {key}')
for key, value in kwargs.items():
if key not in argspec.args:
raise SyntaxError(f'Unknown argument: {key}')
if argspec.args.index(key) < len(args):
raise SyntaxError(f'Duplicate argument: {key}')
_privileged_assert_valid_type(f'arg {key}', value,
argspec.annotations[key])
for index, arg in enumerate(args):
annotation = argspec.annotations[argspec.args[index]]
_privileged_assert_valid_type(f'arg #{index}', arg, annotation)
def _privileged_assert_valid_type(arg_name, value, annotation):
"""Assert that the type of argument value matches the annotation."""
if annotation == typing.Any:
return
NoneType = type(None)
if annotation == NoneType:
if value is not None:
raise TypeError('Expected None for {arg_name}')
return
basic_types = {bool, int, str, float}
if annotation in basic_types:
if not isinstance(value, annotation):
raise TypeError(
f'Expected type {annotation.__name__} for {arg_name}')
return
# 'int | str' or 'typing.Union[int, str]'
if (isinstance(annotation, types.UnionType)
or getattr(annotation, '__origin__', None) == typing.Union):
for arg in annotation.__args__:
try:
_privileged_assert_valid_type(arg_name, value, arg)
return
except TypeError:
pass
raise TypeError(f'Expected one of unioned types for {arg_name}')
# 'list[int]' or 'typing.List[int]'
if getattr(annotation, '__origin__', None) == list:
if not isinstance(value, list):
raise TypeError(f'Expected type list for {arg_name}')
for index, inner_item in enumerate(value):
_privileged_assert_valid_type(f'{arg_name}[{index}]', inner_item,
annotation.__args__[0])
return
# 'list[dict]' or 'typing.List[dict]'
if getattr(annotation, '__origin__', None) == dict:
if not isinstance(value, dict):
raise TypeError(f'Expected type dict for {arg_name}')
for inner_key, inner_value in value.items():
_privileged_assert_valid_type(f'{arg_name}[{inner_key}]',
inner_key, annotation.__args__[0])
_privileged_assert_valid_type(f'{arg_name}[{inner_value}]',
inner_value, annotation.__args__[1])
return
raise TypeError('Unsupported annotation type')

View File

@ -9,7 +9,7 @@ from unittest.mock import patch
import pytest
from plinth.actions import privileged
from plinth import actions
actions_name = 'actions'
@ -17,10 +17,9 @@ actions_name = 'actions'
@patch('importlib.import_module')
@patch('plinth.module_loader.get_module_import_path')
@patch('os.getuid')
def test_call_syntax_checks(getuid, get_module_import_path, import_module,
actions_module):
def test_call_syntax_checks(getuid, get_module_import_path, import_module):
"""Test that calling a method results in proper syntax checks."""
call = actions_module._call
call = actions._privileged_call
# Module name validation
getuid.return_value = 0
@ -53,7 +52,7 @@ def test_call_syntax_checks(getuid, get_module_import_path, import_module,
call('test-module', 'func', {})
# Argument validation
@privileged
@actions.privileged
def func():
return 'foo'
@ -66,7 +65,7 @@ def test_call_syntax_checks(getuid, get_module_import_path, import_module,
assert return_value == {'result': 'success', 'return': 'foo'}
# Exception call
@privileged
@actions.privileged
def exception_func():
raise RuntimeError('foo exception')
@ -81,9 +80,9 @@ def test_call_syntax_checks(getuid, get_module_import_path, import_module,
assert isinstance(line, str)
def test_assert_valid_arguments(actions_module):
def test_assert_valid_arguments():
"""Test that checking valid arguments works."""
assert_valid = actions_module._assert_valid_arguments
assert_valid = actions._privileged_assert_valid_arguments
values = [
None, [], 10, {}, {
@ -139,9 +138,9 @@ def test_assert_valid_arguments(actions_module):
assert_valid(func, {'args': [1, '2'], 'kwargs': {'c': '3'}})
def test_assert_valid_type(actions_module):
def test_assert_valid_type():
"""Test that type validation works as expected."""
assert_valid = actions_module._assert_valid_type
assert_valid = actions._privileged_assert_valid_type
assert_valid(None, None, typing.Any)