#!/usr/bin/env python
# Copyright 2018-2019 Gentoo Authors
# Distributed under the terms of the GNU General Public License v2

import errno
import functools
import os
import signal
import sys


# Signals for which main() installs a handler that relays them to the
# main child process.
KILL_SIGNALS = (signal.SIGINT, signal.SIGTERM, signal.SIGHUP)

def forward_kill_signal(main_child_pid, signum, frame):
	"""
	Signal handler that relays a received signal to the main child.

	@param main_child_pid: pid of the process that receives the signal
	@param signum: number of the signal that was received
	@param frame: interrupted stack frame supplied by the signal
		module (not used)
	"""
	# The frame argument exists only to satisfy the handler signature.
	del frame
	os.kill(main_child_pid, signum)


def main(argv):
	"""
	Supervise a main child process, forwarding kill signals to it and
	propagating its exit status.

	Two invocation modes are supported:

	- argv == [prog, main_child_pid]: the main child already exists and
	  only needs to be supervised.
	- argv == [prog, binary, argv0, arg...]: the main child is created
	  here with fork and exec of binary, using argv[2:] as its complete
	  argument vector.

	@param argv: command line arguments, including the program name
	@return: a usage message (str) if too few arguments were given,
		otherwise the exit status of the main child; 127 if the main
		child was killed by a signal and re-raising that signal did
		not terminate the current process
	"""
	if len(argv) < 2:
		return 'Usage: {} <main-child-pid> or <binary> <argv0> [arg]..'.format(argv[0])

	if len(argv) == 2:
		# The child process is init (pid 1) in a child pid namespace, and
		# the current process supervises from within the global pid namespace
		# (forwarding signals to init and forwarding exit status to the parent
		# process).
		main_child_pid = int(argv[1])
	else:
		# The current process is init (pid 1) in a child pid namespace.
		binary = argv[1]
		args = argv[2:]

		main_child_pid = os.fork()
		if main_child_pid == 0:
			# Forked child: replace the process image. If execv fails,
			# the exception propagates out of main() in the child.
			os.execv(binary, args)

	sig_handler = functools.partial(forward_kill_signal, main_child_pid)
	for signum in KILL_SIGNALS:
		signal.signal(signum, sig_handler)

	# wait for child processes
	while True:
		try:
			pid, status = os.wait()
		except OSError as e:
			# A signal handler may interrupt the wait (interpreters that
			# do not retry interrupted system calls report EINTR); just
			# wait again. This relies on the errno module being imported
			# at file level.
			if e.errno == errno.EINTR:
				continue
			raise
		if pid == main_child_pid:
			if os.WIFEXITED(status):
				return os.WEXITSTATUS(status)
			elif os.WIFSIGNALED(status):
				# Terminate the same way the main child did: restore the
				# default disposition and send the signal to ourselves,
				# so the parent process observes the same signal.
				termsig = os.WTERMSIG(status)
				signal.signal(termsig, signal.SIG_DFL)
				os.kill(os.getpid(), termsig)
			# go to the unreachable place
			break
		# Any other pid is some other child that was reaped; keep waiting
		# for the main child.

	# this should never be reached
	return 127


if __name__ == '__main__':
	# sys.exit() uses an integer result as the process exit status, and
	# prints a string result (the usage message) to stderr before exiting
	# with status 1.
	exit_value = main(sys.argv)
	sys.exit(exit_value)
