418 lines
11 KiB
Python
418 lines
11 KiB
Python
"""This script contains the actual auditing tests.
|
|
|
|
It should not be imported directly, but should be run by the test_audit
|
|
module with arguments identifying each test.
|
|
|
|
"""
|
|
|
|
import contextlib
|
|
import os
|
|
import sys
|
|
|
|
|
|
class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact
    after the test completes.  Leaving the block does not unregister the
    hook; it only marks it closed so that later events are ignored.

    Arguments:
        raise_on_events: event names that make the hook raise.  A single
            string is treated as one event name (not as a set of substrings).
        exc_type: exception type raised when one of those events is seen.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        events = raise_on_events or ()
        # ``event in "sys.addaudithook"`` would be a substring test and also
        # match unrelated names such as "sys"; wrap a bare string in a tuple.
        if isinstance(events, str):
            events = (events,)
        self.raise_on_events = events
        self.exc_type = exc_type
        self.seen = []  # (event, args) pairs in the order they were received
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording and raising; the hook itself stays installed."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of all recorded events, in order."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        """Audit hook entry point: record the event, then raise if requested."""
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)
|
|
|
|
|
|
# Simple helpers, since we are not in unittest here
|
|
def assertEqual(x, y):
    """Raise AssertionError when *x* and *y* compare unequal (via ``!=``)."""
    mismatch = x != y
    if mismatch:
        raise AssertionError(f"{x!r} should equal {y!r}")
|
|
|
|
|
|
def assertIn(el, series):
    """Raise AssertionError unless *el* is a member of *series*."""
    if el in series:
        return
    raise AssertionError(f"{el!r} should be in {series!r}")
|
|
|
|
|
|
def assertNotIn(el, series):
    """Raise AssertionError if *el* is a member of *series*."""
    present = el in series
    if not present:
        return
    raise AssertionError(f"{el!r} should not be in {series!r}")
|
|
|
|
|
|
def assertSequenceEqual(x, y):
    """Raise AssertionError unless *x* and *y* have equal length and items.

    Items are compared pairwise with ``!=``; the container types may differ.
    """
    same_length = len(x) == len(y)
    if not same_length or any(left != right for left, right in zip(x, y)):
        raise AssertionError(f"{x!r} should equal {y!r}")
|
|
|
|
|
|
@contextlib.contextmanager
def assertRaises(ex_type):
    """Context manager asserting that its body raises exactly *ex_type*.

    The raised exception must be of type *ex_type* itself; subclasses do
    not count.  An AssertionError from the body is always re-raised, so a
    failed check inside the block is not mistaken for the expected error.

    Explicit ``raise`` statements are used instead of ``assert`` so the
    check still works when Python runs with ``-O``.

    Raises:
        AssertionError: if the body raises nothing, or raises another type.
    """
    try:
        yield
    except AssertionError:
        raise
    except BaseException as ex:
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}") from ex
        # Expected exception: returning normally suppresses it.
    else:
        raise AssertionError(f"expected {ex_type}")
|
|
|
|
|
|
def test_basic():
    """A plain sys.audit() call reaches the hook with its name and arguments."""
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        event, args = hook.seen[0]
        assertEqual(event, "test_event")
        assertEqual(args, (1, 2, 3))
|
|
|
|
|
|
def test_block_add_hook():
    # A hook that raises on "sys.addaudithook" must stop the next hook from
    # being registered, without the exception escaping to the caller.
    with TestHook(raise_on_events="sys.addaudithook") as outer, TestHook() as inner:
        sys.audit("test_event")
        assertIn("test_event", outer.seen_events)
        assertNotIn("test_event", inner.seen_events)
|
|
|
|
|
|
def test_block_add_hook_baseexception():
    # Unlike an ordinary Exception, a BaseException raised while a hook is
    # being added does propagate out to the caller.
    blocker = TestHook(raise_on_events="sys.addaudithook", exc_type=BaseException)
    with assertRaises(BaseException):
        with blocker:
            # Registering this second hook is what triggers the BaseException.
            with TestHook():
                pass
|
|
|
|
|
|
def test_marshal():
    import marshal

    obj = ("a", "b", "c", 1, 2, 3)
    expected_payload = marshal.dumps(obj)
    path = "test-marshal.bin"

    with TestHook() as hook:
        # In-memory round trip.
        assertEqual(obj, marshal.loads(marshal.dumps(obj)))

        # File round trip through a scratch file in the working directory.
        try:
            with open(path, "wb") as fh:
                marshal.dump(obj, fh)
            with open(path, "rb") as fh:
                assertEqual(obj, marshal.load(fh))
        finally:
            os.unlink(path)

    def args_of(name):
        # Argument tuples of every recorded event called *name*.
        return [args for event, args in hook.seen if event == name]

    # dumps() and dump() should each have produced a "marshal.dumps" event.
    assertSequenceEqual(
        [(args[0], args[1]) for args in args_of("marshal.dumps")],
        [(obj, marshal.version)] * 2,
    )
    assertSequenceEqual(
        [args[0] for args in args_of("marshal.loads")], [expected_payload]
    )
    assertSequenceEqual(
        ["marshal.load"] * len(args_of("marshal.load")), ["marshal.load"]
    )
|
|
|
|
|
|
def test_pickle():
    import pickle

    class PicklePrint:
        # Reduces to a call of the global ``str``, so loading this pickle
        # needs a global lookup.
        def __reduce_ex__(self, p):
            return str, ("Pwned!",)

    with_global = pickle.dumps(PicklePrint())
    without_global = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Sanity check: with no hook installed the "malicious" pickle loads.
    assertEqual("Pwned!", pickle.loads(with_global))

    with TestHook(raise_on_events="pickle.find_class"):
        # Once the hook refuses find_class, loading globals must fail ...
        with assertRaises(RuntimeError):
            pickle.loads(with_global)
        # ... while pickles without globals still load.
        pickle.loads(without_global)
|
|
|
|
|
|
def test_monkeypatch():
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    instance = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Catch attribute replacement
        C.__init__ = B.__init__
        # Catch attribute addition
        C.new_attr = 123
        # Catch class changes
        instance.__class__ = B

    setattr_calls = [
        (args[0], args[1])
        for event, args in hook.seen
        if event == "object.__setattr__"
    ]
    expected = [
        (C, "__name__"),
        (C, "__bases__"),
        (C, "__bases__"),
        (instance, "__class__"),
    ]
    assertSequenceEqual(expected, setattr_calls)
|
|
|
|
|
|
def test_open():
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # Try a range of "open"-style calls; the hook must make each one fail.
    with TestHook(raise_on_events={"open"}) as hook:
        attempts = [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]
        for opener, *opener_args in attempts:
            # load_dh_params is None when ssl could not be imported.
            if opener:
                with assertRaises(RuntimeError):
                    opener(*opener_args)

    open_args = [args for event, args in hook.seen if event == "open"]
    # Events with a truthy second argument are grouped as (target, mode);
    # the rest as (target, third argument) -- none of the latter are expected.
    actual_mode = [(args[0], args[1]) for args in open_args if args[1]]
    actual_flag = [(args[0], args[2]) for args in open_args if not args[1]]

    expected_mode = [
        (sys.argv[2], "r"),
        (sys.executable, "r"),
        (3, "w"),
        (sys.argv[2], "w"),
    ]
    if load_dh_params:
        expected_mode.append((sys.argv[2], "rb"))
    assertSequenceEqual(expected_mode, actual_mode)
    assertSequenceEqual([], actual_flag)
|
|
|
|
|
|
def test_cantrace():
    """Check when calls into an audit hook are visible to a trace function.

    Per the expectations below, hook calls are traced only while the hook's
    ``__cantrace__`` attribute is truthy; four traced calls are expected.
    """
    traced = []

    def trace(frame, event, *args):
        # Only record events for frames executing TestHook.__call__.
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so the previous tracer must be fetched
    # with sys.gettrace() in order to restore it afterwards.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)
|
|
|
|
|
|
def test_mmap():
    import mmap

    with TestHook() as hook:
        mmap.mmap(-1, 8)
        # The first recorded event should carry the (fileno, length) pair.
        first_args = hook.seen[0][1]
        assertEqual(first_args[:2], (-1, 8))
|
|
|
|
|
|
def test_excepthook():
    def excepthook(exc_type, exc_value, exc_tb):
        # RuntimeError is the deliberate failure below; defer anything else
        # to the default hook.
        if exc_type is RuntimeError:
            return
        sys.__excepthook__(exc_type, exc_value, exc_tb)

    def hook(event, args):
        if event != "sys.excepthook":
            return
        if not isinstance(args[2], args[1]):
            raise TypeError(f"Expected isinstance({args[2]!r}, {args[1]!r})")
        if args[0] != excepthook:
            raise ValueError(f"Expected {args[0]} == {excepthook}")
        print(event, repr(args[2]))

    sys.addaudithook(hook)
    sys.excepthook = excepthook
    # Left uncaught on purpose so the interpreter invokes sys.excepthook.
    raise RuntimeError("fatal-error")
|
|
|
|
|
|
def test_unraisablehook():
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        """Swallow the unraisable exception without reporting it."""

    def hook(event, args):
        if event != "sys.unraisablehook":
            return
        installed = args[0]
        if installed != unraisablehook:
            raise ValueError(f"Expected {installed} == {unraisablehook}")
        info = args[1]
        print(event, repr(info.exc_value), info.err_msg)

    sys.addaudithook(hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
|
|
|
|
|
|
def test_winreg():
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def hook(event, args):
        # Echo only winreg events (presumably parsed by the test_audit driver).
        if event.startswith("winreg."):
            print(event, *args)

    sys.addaudithook(hook)

    key = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(key, 0)

    # Enumerating an out-of-range subkey index must fail.
    enum_failed = False
    try:
        EnumKey(key, 10000)
    except OSError:
        enum_failed = True
    if not enum_failed:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    CloseKey(key.Detach())
|
|
|
|
|
|
def test_socket():
    import socket

    def hook(event, args):
        if not event.startswith("socket."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    socket.gethostname()

    # Only the audit messages matter: a failing bind is deliberately ignored,
    # and the socket is closed either way.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        with contextlib.suppress(Exception):
            sock.bind(('127.0.0.1', 8080))
|
|
|
|
|
|
def test_gc():
    import gc

    def hook(event, args):
        # Report garbage-collector events only.
        if not event.startswith("gc."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    gc.get_objects(generation=1)

    referent = object()
    container = [referent]

    gc.get_referrers(referent)
    gc.get_referents(container)
|
|
|
|
|
|
def test_http_client():
    import http.client

    def hook(event, args):
        if not event.startswith("http.client."):
            return
        # The first argument is left out of the output (presumably the
        # connection object, whose repr is not stable across runs).
        print(event, *args[1:])

    sys.addaudithook(hook)

    with contextlib.closing(http.client.HTTPConnection('www.python.org')) as conn:
        try:
            conn.request('GET', '/')
        except OSError:
            # Emit a stand-in send line when the request cannot be made
            # (e.g. no network access).
            print('http.client.send', '[cannot send]')
|
|
|
|
|
|
def test_sqlite3():
|
|
import sqlite3
|
|
|
|
def hook(event, *args):
|
|
if event.startswith("sqlite3."):
|
|
print(event, *args)
|
|
|
|
sys.addaudithook(hook)
|
|
cx1 = sqlite3.connect(":memory:")
|
|
cx2 = sqlite3.Connection(":memory:")
|
|
|
|
# Configured without --enable-loadable-sqlite-extensions
|
|
if hasattr(sqlite3.Connection, "enable_load_extension"):
|
|
cx1.enable_load_extension(False)
|
|
try:
|
|
cx1.load_extension("test")
|
|
except sqlite3.OperationalError:
|
|
pass
|
|
else:
|
|
raise RuntimeError("Expected sqlite3.load_extension to fail")
|
|
|
|
|
|
if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    # Presumably keeps MSVC runtime assertion pop-ups from blocking a child
    # run on Windows -- see test.support for details.
    suppress_msvcrt_asserts()

    # The first command-line argument names the test function to run.
    test_func = globals()[sys.argv[1]]
    test_func()
|