#!/usr/bin/python -O
# vim: noet :

from __future__ import print_function

import errno
import re
import signal
import stat
import sys
import textwrap
import time
from optparse import OptionParser, OptionValueError

try:
	import portage
except ImportError:
	from os import path as osp
	sys.path.insert(0, osp.join(osp.dirname(osp.dirname(osp.realpath(__file__))), "pym"))
	import portage

from portage import os
from portage.util import writemsg

# Python 3 has no separate "long" type; alias the name to int so that
# long() calls in this file work on both major versions.
if sys.version_info[0] >= 3:
	long = int

class WorldHandler(object):
	"""Validate the entries of the "selected" package set (the world file).

	Each entry is sorted into one of the result lists ``okay``,
	``invalid``, ``not_installed`` or ``invalid_category``.
	"""

	short_desc = "Fix problems in the world file"

	@staticmethod
	def name():
		"""Return the command name of this module."""
		return "world"

	def __init__(self):
		"""Create the empty result lists and load the configured package sets."""
		from portage._sets import load_default_config
		self.invalid = []
		self.not_installed = []
		self.invalid_category = []
		self.okay = []
		eroot = portage.settings['EROOT']
		self._sets = load_default_config(
			portage.settings, portage.db[eroot]).getSets()

	def _check_world(self, onProgress):
		"""Classify every entry of the "selected" set.

		Results are appended to the instance's result lists. Also records
		the world file path in ``self.world_file`` and whether it is
		readable in ``self.found``.

		onProgress, when true, is called as onProgress(maxval, curval)
		once before the scan and once after each entry.
		"""
		from portage._sets import SETPREFIX
		known_categories = set(portage.settings.categories)
		eroot = portage.settings['EROOT']
		self.world_file = os.path.join(eroot, portage.const.WORLD_FILE)
		self.found = os.access(self.world_file, os.R_OK)
		installed_db = portage.db[eroot]["vartree"].dbapi

		all_sets = self._sets
		entries = list(all_sets["selected"])
		total = len(entries)
		if onProgress:
			onProgress(total, 0)
		for done, entry in enumerate(entries, 1):
			if isinstance(entry, portage.dep.Atom):
				# A package atom may fail both tests, in which case it is
				# recorded in both problem lists.
				problems = 0
				if not installed_db.match(entry):
					self.not_installed.append(entry)
					problems += 1
				if portage.catsplit(entry.cp)[0] not in known_categories:
					self.invalid_category.append(entry)
					problems += 1
				if not problems:
					self.okay.append(entry)
			elif not entry.startswith(SETPREFIX):
				# Neither an atom nor a set reference.
				self.invalid.append(entry)
			elif entry[len(SETPREFIX):] in all_sets:
				# Reference to a set that exists.
				self.okay.append(entry)
			else:
				# Reference to a set that is not configured.
				self.not_installed.append(entry)
			if onProgress:
				onProgress(total, done)

	def check(self, onProgress=None):
		"""Scan the world file and return a list of problem descriptions."""
		self._check_world(onProgress)
		if not self.found:
			return [self.world_file + " could not be opened for reading"]
		report = []
		report.extend("'%s' is not a valid atom" % x for x in self.invalid)
		report.extend("'%s' is not installed" % x for x in self.not_installed)
		report.extend(
			"'%s' has a category that is not listed in /etc/portage/categories" % x
			for x in self.invalid_category)
		return report

	def fix(self, onProgress=None):
		"""Replace the "selected" set with only the entries classified as okay.

		The set is locked for the whole operation and reloaded first, in
		case it changed on disk. Returns a list of error messages, which
		is empty on success.
		"""
		selected = self._sets["selected"]
		selected.lock()
		try:
			selected.load()
			previous = set(selected)
			self._check_world(onProgress)
			errors = []
			# Only rewrite the set when the scan actually changed it.
			if previous != set(self.okay):
				try:
					selected.replace(self.okay)
				except portage.exception.PortageException:
					errors.append(
						"%s could not be opened for writing" % self.world_file)
			return errors
		finally:
			selected.unlock()

class BinhostHandler(object):
	"""Check and regenerate the metadata index ("Packages") of binary packages."""

	short_desc = "Generate a metadata index for binary packages"

	@staticmethod
	def name():
		"""Return the command name of this module."""
		return "binhost"

	def __init__(self):
		"""Populate the binary tree of EROOT and load its package index."""
		bintree = portage.db[portage.settings['EROOT']]["bintree"]
		self._bintree = bintree
		bintree.populate()
		self._pkgindex_file = bintree._pkgindex_file
		self._pkgindex = bintree._load_pkgindex()

	def _need_update(self, cpv, data):
		"""Return True when the index entry `data` for `cpv` must be rebuilt.

		An entry is outdated when it lacks MD5, SIZE or MTIME, or when
		its SIZE/MTIME are unparsable or differ from the package file.
		Returns False when the package file no longer exists, because
		the entry cannot be regenerated in that case.
		"""
		if "MD5" not in data:
			return True
		size = data.get("SIZE")
		mtime = data.get("MTIME")
		if size is None or mtime is None:
			return True

		pkg_path = self._bintree.getname(cpv)
		try:
			st = os.lstat(pkg_path)
		except OSError as e:
			if e.errno in (errno.ENOENT, errno.ESTALE):
				# The package disappeared, so there is nothing to index.
				return False
			raise

		try:
			return long(mtime) != st[stat.ST_MTIME] or \
				long(size) != long(st.st_size)
		except ValueError:
			# Non-numeric SIZE/MTIME in the index.
			return True

	def _index_by_cpv(self, pkgindex):
		"""Map the CPV of every entry in pkgindex.packages to that entry."""
		return dict((entry["CPV"], entry) for entry in pkgindex.packages)

	def _scan_missing(self, cpv_list, index, onProgress=None, maxval=0):
		"""Return the cpvs that have no entry in `index` or an outdated one.

		When onProgress is true it is called as onProgress(maxval, n)
		after the n-th package has been examined.
		"""
		outdated = []
		for pos, cpv in enumerate(cpv_list):
			entry = index.get(cpv)
			if not entry or self._need_update(cpv, entry):
				outdated.append(cpv)
			if onProgress:
				onProgress(maxval, pos + 1)
		return outdated

	def check(self, onProgress=None):
		"""Return messages for missing/outdated and stale index entries."""
		cpv_all = self._bintree.dbapi.cpv_all()
		cpv_all.sort()
		maxval = len(cpv_all)
		if onProgress:
			onProgress(maxval, 0)
		metadata = self._index_by_cpv(self._pkgindex)
		missing = self._scan_missing(cpv_all, metadata, onProgress, maxval)
		errors = ["'%s' is not in Packages" % cpv for cpv in missing]
		# Index entries whose package is no longer in the binary tree.
		errors.extend("'%s' is not in the repository" % cpv
			for cpv in set(metadata).difference(cpv_all))
		return errors

	def _regenerate_index(self, onProgress):
		"""Rebuild and write the package index while holding its lock.

		Returns the number of entries that were regenerated, which is
		also the maxval passed to onProgress during regeneration.
		"""
		from portage import locks
		from portage.util import atomic_ofstream
		bintree = self._bintree
		pkgindex_lock = locks.lockfile(
			self._pkgindex_file, wantnewlockfile=1)
		try:
			# Repopulate with lock held.
			bintree._populate()
			cpv_all = self._bintree.dbapi.cpv_all()
			cpv_all.sort()

			pkgindex = bintree._load_pkgindex()
			self._pkgindex = pkgindex
			metadata = self._index_by_cpv(pkgindex)

			# Recount missing packages, with lock held.
			missing = self._scan_missing(cpv_all, metadata)

			maxval = len(missing)
			for done, cpv in enumerate(missing, 1):
				try:
					metadata[cpv] = bintree._pkgindex_entry(cpv)
				except portage.exception.InvalidDependString:
					writemsg("!!! Invalid binary package: '%s'\n" % \
						bintree.getname(cpv), noiselevel=-1)
				if onProgress:
					onProgress(maxval, done)

			# Drop entries for packages that are no longer present.
			for cpv in set(metadata).difference(
				self._bintree.dbapi.cpv_all()):
				del metadata[cpv]

			# We've updated the pkgindex, so set it to
			# repopulate when necessary.
			bintree.populated = False

			del pkgindex.packages[:]
			pkgindex.packages.extend(metadata.values())
			f = atomic_ofstream(self._pkgindex_file)
			try:
				self._pkgindex.write(f)
			finally:
				f.close()
		finally:
			locks.unlockfile(pkgindex_lock)
		return maxval

	def fix(self, onProgress=None):
		"""Regenerate the package index when entries are missing or stale.

		A first scan without the lock decides whether any work is
		needed; the real work is then redone with the index lock held.
		Always returns None.
		"""
		cpv_all = self._bintree.dbapi.cpv_all()
		cpv_all.sort()
		maxval = 0
		if onProgress:
			onProgress(maxval, 0)
		metadata = self._index_by_cpv(self._pkgindex)
		missing = self._scan_missing(cpv_all, metadata)
		stale = set(metadata).difference(cpv_all)

		if missing or stale:
			maxval = self._regenerate_index(onProgress)

		if onProgress:
			# Report completion with a non-zero maximum even when
			# nothing had to be regenerated.
			if maxval == 0:
				maxval = 1
			onProgress(maxval, maxval)
		return None

class MoveHandler(object):
	"""Base class that reports and applies package move/slotmove updates.

	The updates are read from the "profiles/updates" directory of every
	repository known to the port tree; subclasses choose the tree
	(`tree`) whose packages are checked and updated.
	"""

	def __init__(self, tree, porttree):
		"""Remember the tree to operate on and look up the master repository.

		tree -- tree object whose ``dbapi`` is checked/updated
		porttree -- port tree whose ``dbapi`` provides repository info
		"""
		self._tree = tree
		self._portdb = porttree.dbapi
		self._update_keys = ["DEPEND", "RDEPEND", "PDEPEND", "PROVIDE"]
		self._master_repo = \
			self._portdb.getRepositoryName(self._portdb.porttree_root)

	def _grab_global_updates(self):
		"""Collect the update commands of all repositories.

		Returns a tuple (retupdates, errors): retupdates maps each
		repository name to its list of update commands, with the master
		repository's list also available under the key 'DEFAULT';
		errors is the combined list of parse errors from all files.
		"""
		from portage.update import grab_updates, parse_updates
		retupdates = {}
		errors = []

		for repo_name in self._portdb.getRepositories():
			repo = self._portdb.getRepositoryPath(repo_name)
			updpath = os.path.join(repo, "profiles", "updates")
			if not os.path.isdir(updpath):
				continue

			try:
				rawupdates = grab_updates(updpath)
			except portage.exception.DirectoryNotFound:
				rawupdates = []
			upd_commands = []
			for mykey, mystat, mycontent in rawupdates:
				# Keep the per-file errors under their own name: binding
				# them to `errors` would discard what was accumulated so
				# far and then extend the list with itself.
				commands, parse_errors = parse_updates(mycontent)
				upd_commands.extend(commands)
				errors.extend(parse_errors)
			retupdates[repo_name] = upd_commands

		if self._master_repo in retupdates:
			retupdates['DEFAULT'] = retupdates[self._master_repo]

		return retupdates, errors

	def _repo_matcher(self, repo, allupdates):
		"""Return a predicate telling whether `repo`'s updates apply to a package.

		The updates of `repo` apply to packages from `repo` itself and,
		when `repo` is the master repository, also to packages whose
		repository has no updates of its own.
		"""
		master_repo = self._master_repo

		def repo_match(repository):
			return repository == repo or \
				(repo == master_repo and \
				repository not in allupdates)
		return repo_match

	def check(self, onProgress=None):
		"""Return messages describing pending updates, plus any parse errors.

		onProgress, when true, is called as onProgress(maxval, curval);
		(0, 0) is used while the progress is indeterminate.
		"""
		allupdates, errors = self._grab_global_updates()
		# Matching packages and moving them is relatively fast, so the
		# progress bar is updated in indeterminate mode.
		match = self._tree.dbapi.match
		aux_get = self._tree.dbapi.aux_get
		if onProgress:
			onProgress(0, 0)
		for repo, updates in allupdates.items():
			# 'DEFAULT' is only an alias of the master repository's list.
			if repo == 'DEFAULT' or not updates:
				continue

			repo_match = self._repo_matcher(repo, allupdates)

			for update_cmd in updates:
				if update_cmd[0] == "move":
					origcp, newcp = update_cmd[1:]
					for cpv in match(origcp):
						if repo_match(aux_get(cpv, ["repository"])[0]):
							errors.append("'%s' moved to '%s'" % (cpv, newcp))
				elif update_cmd[0] == "slotmove":
					pkg, origslot, newslot = update_cmd[1:]
					for cpv in match(pkg):
						slot, prepo = aux_get(cpv, ["SLOT", "repository"])
						if slot == origslot and repo_match(prepo):
							errors.append("'%s' slot moved from '%s' to '%s'" % \
								(cpv, origslot, newslot))
				if onProgress:
					onProgress(0, 0)

		# Searching for updates in all the metadata is relatively slow, so this
		# is where the progress bar comes out of indeterminate mode.
		cpv_all = self._tree.dbapi.cpv_all()
		cpv_all.sort()
		maxval = len(cpv_all)
		meta_keys = self._update_keys + ['repository']
		from portage.update import update_dbentries
		if onProgress:
			onProgress(maxval, 0)
		for i, cpv in enumerate(cpv_all):
			metadata = dict(zip(meta_keys, aux_get(cpv, meta_keys)))
			repository = metadata.pop('repository')
			# Use the package's own repository updates when it has an
			# entry (even an empty one), else fall back to 'DEFAULT'.
			updates = allupdates.get(repository)
			if updates is None:
				updates = allupdates.get('DEFAULT')
			if updates and update_dbentries(updates, metadata):
				errors.append("'%s' has outdated metadata" % cpv)
			# Report progress for every package, including those without
			# applicable updates, so the bar can reach its maximum.
			if onProgress:
				onProgress(maxval, i+1)
		return errors

	def fix(self, onProgress=None):
		"""Apply all move/slotmove updates and refresh outdated metadata.

		Returns the list of errors found while parsing the update files.
		"""
		allupdates, errors = self._grab_global_updates()
		# Matching packages and moving them is relatively fast, so the
		# progress bar is updated in indeterminate mode.
		move = self._tree.dbapi.move_ent
		slotmove = self._tree.dbapi.move_slot_ent
		if onProgress:
			onProgress(0, 0)
		for repo, updates in allupdates.items():
			# 'DEFAULT' is only an alias of the master repository's list.
			if repo == 'DEFAULT' or not updates:
				continue

			repo_match = self._repo_matcher(repo, allupdates)

			for update_cmd in updates:
				if update_cmd[0] == "move":
					move(update_cmd, repo_match=repo_match)
				elif update_cmd[0] == "slotmove":
					slotmove(update_cmd, repo_match=repo_match)
				if onProgress:
					onProgress(0, 0)

		# Searching for updates in all the metadata is relatively slow, so this
		# is where the progress bar comes out of indeterminate mode.
		self._tree.dbapi.update_ents(allupdates, onProgress=onProgress)
		return errors

class MoveInstalled(MoveHandler):
	"""Package move handler bound to the "vartree" of EROOT."""

	short_desc = "Perform package move updates for installed packages"

	@staticmethod
	def name():
		"""Return the command name of this module."""
		return "moveinst"

	def __init__(self):
		"""Initialize the base handler with the "vartree" and "porttree" of EROOT."""
		trees = portage.db[portage.settings['EROOT']]
		MoveHandler.__init__(self, trees["vartree"], trees["porttree"])

class MoveBinary(MoveHandler):
	"""Package move handler bound to the "bintree" of EROOT."""

	short_desc = "Perform package move updates for binary packages"

	@staticmethod
	def name():
		"""Return the command name of this module."""
		return "movebin"

	def __init__(self):
		"""Initialize the base handler with the "bintree" and "porttree" of EROOT."""
		trees = portage.db[portage.settings['EROOT']]
		MoveHandler.__init__(self, trees["bintree"], trees['porttree'])

class VdbKeyHandler(object):
	"""Find installed packages without metadata key files and recreate them.

	A package is recorded in ``self.missing`` when none of the files
	named in ``self.keys`` exists in its directory below VDB_PATH. The
	values are recovered from the package's ``environment.bz2``.
	"""

	@staticmethod
	def name():
		"""Return the command name of this module."""
		return "vdbkeys"

	def __init__(self):
		"""Scan all installed packages and record those missing every key file."""
		self.list = portage.db[portage.settings["EROOT"]]["vartree"].dbapi.cpv_all()
		self.missing = []
		self.keys = ["HOMEPAGE", "SRC_URI", "KEYWORDS", "DESCRIPTION"]

		for p in self.list:
			mydir = self._vdb_dir(p)
			if not any(os.path.exists(mydir+k) for k in self.keys):
				self.missing.append(p)

	def _vdb_dir(self, cpv):
		"""Return the directory of `cpv` below VDB_PATH, with a trailing separator."""
		return os.path.join(portage.settings["EROOT"], portage.const.VDB_PATH, cpv)+os.sep

	def check(self, onProgress=None):
		"""Return one message per package recorded as missing its key files.

		onProgress is accepted so that this handler can be called like
		the other handlers in this file; the scan itself already ran in
		__init__, so only completion is reported.
		"""
		messages = ["%s has missing keys" % x for x in self.missing]
		if onProgress:
			maxval = len(self.missing)
			onProgress(maxval, maxval)
		return messages

	def _restore_keys(self, mydir, envfile):
		"""Write the key files of one package from its saved environment.

		mydir -- package directory, ending with a path separator
		envfile -- path of the bzip2-compressed environment in mydir

		Returns a list of error messages.
		"""
		errors = []
		# NOTE(review): the path is interpolated unquoted into a shell
		# command, so a path containing whitespace or shell
		# metacharacters would break this call -- consider quoting it.
		env = os.popen("bzip2 -dcq "+envfile, "r")
		try:
			envlines = env.read().split("\n")
		finally:
			# Close the pipe even if reading fails.
			env.close()
		for k in self.keys:
			matches = [l for l in envlines if l.startswith(k+"=")]
			if len(matches) > 1:
				errors.append("multiple matches for %s found in %senvironment.bz2" % (k, mydir))
				continue
			if not matches:
				continue
			value = matches[0].split("=",1)[1]
			# Strip a leading "$" and the surrounding quotes, then turn
			# literal \n, \r and \t escapes and whitespace runs into
			# single spaces.
			value = value.lstrip("$").strip("\'\"")
			value = re.sub("(\\\\[nrt])+", " ", value)
			value = " ".join(value.split()).strip()
			if value == "":
				continue
			try:
				# mydir already ends with a separator; the with-block
				# closes the file even when the write fails.
				with open(mydir+k, "w") as keyfile:
					keyfile.write(value+"\n")
			except (IOError, OSError) as e:
				errors.append("Could not write %s, reason was: %s" % (mydir+k, e))
		return errors

	def fix(self, onProgress=None):
		"""Recreate the key files of every package in ``self.missing``.

		onProgress, when true, is called as onProgress(maxval, curval)
		once before the first package and once after each package.
		Returns a list of error messages.
		"""
		errors = []
		maxval = len(self.missing)
		if onProgress:
			onProgress(maxval, 0)
		for i, p in enumerate(self.missing):
			mydir = self._vdb_dir(p)
			envfile = mydir+"environment.bz2"
			if not os.access(envfile, os.R_OK):
				errors.append("Can't access %s" % envfile)
			elif not os.access(mydir, os.W_OK):
				errors.append("Can't create files in %s" % mydir)
			else:
				errors.extend(self._restore_keys(mydir, envfile))
			if onProgress:
				onProgress(maxval, i+1)
		return errors

class ProgressHandler(object):
	"""Record progress values and rate-limit calls to display().

	display() raises NotImplementedError here and has to be replaced
	by whoever uses the handler.
	"""

	def __init__(self):
		"""Start at zero progress with a 0.2 second minimum display interval."""
		self.curval = self.maxval = self.last_update = 0
		self.min_display_latency = 0.2

	def onProgress(self, maxval, curval):
		"""Store the new values and redraw unless the last redraw was too recent."""
		self.maxval, self.curval = maxval, curval
		now = time.time()
		if now - self.last_update < self.min_display_latency:
			return
		self.last_update = now
		self.display()

	def display(self):
		"""Show the current progress; not implemented in this class."""
		raise NotImplementedError(self)

class CleanResume(object):
	"""Report and discard the resume lists stored in portage.mtimedb."""

	short_desc = "Discard emerge --resume merge lists"

	# mtimedb keys under which resume lists are looked up.
	_resume_keys = ("resume", "resume_backup")

	@staticmethod
	def name():
		"""Return the command name of this module."""
		return "cleanresume"

	def check(self, onProgress=None):
		"""Return one message per resume list found in mtimedb.

		Entries that are not dicts, or whose "mergelist" is absent or
		has no length, are reported as unrecognized.
		"""
		messages = []
		mtimedb = portage.mtimedb
		maxval = len(self._resume_keys)
		if onProgress:
			onProgress(maxval, 0)
		for done, key in enumerate(self._resume_keys, 1):
			try:
				entry = mtimedb.get(key)
				if entry is None:
					continue
				mergelist = None
				if isinstance(entry, dict):
					mergelist = entry.get("mergelist")
				if mergelist is not None and hasattr(mergelist, "__len__"):
					messages.append("resume list '%s' contains %d packages" % \
						(key, len(mergelist)))
				else:
					messages.append("unrecognized resume list: '%s'" % key)
			finally:
				# Progress is reported for every key, even on early exit.
				if onProgress:
					onProgress(maxval, done)
		return messages

	def fix(self, onProgress=None):
		"""Remove the resume lists from mtimedb and commit if any were removed."""
		mtimedb = portage.mtimedb
		maxval = len(self._resume_keys)
		removed = 0
		if onProgress:
			onProgress(maxval, 0)
		for done, key in enumerate(self._resume_keys, 1):
			try:
				if mtimedb.pop(key, None) is not None:
					removed += 1
			finally:
				if onProgress:
					onProgress(maxval, done)
		if removed:
			mtimedb.commit()

def emaint_main(myargv):
	"""Parse the emaint command line and run the selected module(s).

	myargv is the list of command line arguments without the program
	name. Exactly one COMMAND is required: a module name or "all". The
	action is -c/--check or -f/--fix, and --check is used when neither
	is given. Messages returned by a module are printed after it ran.
	"""

	# Like emerge, emaint wants a default umask so that the files it
	# creates (the world file, for instance) get sane permissions.
	os.umask(0o22)

	# TODO: Create a system that allows external modules to be added without
	#       the need for hard coding.
	modules = {
		"world" : WorldHandler,
		"binhost":BinhostHandler,
		"moveinst":MoveInstalled,
		"movebin":MoveBinary,
		"cleanresume":CleanResume
	}

	module_names = ["all"] + sorted(modules)

	def exclusive(option, *args, **kw):
		# optparse callback: remember the chosen option on the parser
		# attribute named by "var" and refuse a second, conflicting one.
		attr = kw.get("var")
		if attr is None:
			raise ValueError("var not specified to exclusive()")
		previous = getattr(parser, attr, "")
		if previous:
			raise OptionValueError("%s and %s are exclusive options" % (previous, option))
		setattr(parser, attr, str(option))

	desc = ("The emaint program provides an interface to system health "
		"checks and maintenance. See the emaint(1) man page "
		"for additional information about the following commands:")

	# Usage text: synopsis, wrapped description, then one line per command.
	usage_parts = ["usage: emaint [options] COMMAND", "\n\n"]
	usage_parts.extend("%s\n" % line for line in textwrap.wrap(desc, 65))
	usage_parts.append("\n")
	usage_parts.append("  %s%s\n" % \
		("all".ljust(15), "Perform all supported commands"))
	usage_parts.extend("  %s%s\n" % (m.ljust(15), modules[m].short_desc)
		for m in module_names[1:])
	usage = "".join(usage_parts)

	parser = OptionParser(usage=usage, version=portage.VERSION)
	option_specs = (
		("-c", "--check", "check for problems"),
		("-f", "--fix", "attempt to fix problems"),
	)
	for short_flag, long_flag, help_text in option_specs:
		parser.add_option(short_flag, long_flag, help=help_text,
			action="callback", callback=exclusive,
			callback_kwargs={"var":"action"})
	parser.action = None

	options, args = parser.parse_args(args=myargv)
	if len(args) != 1:
		parser.error("Incorrect number of arguments")
	target = args[0]
	if target not in module_names:
		parser.error("%s target is not a known target" % target)

	action = parser.action
	if not action:
		print("Defaulting to --check")
		action = "-c/--check"

	if target == "all":
		tasks = modules.values()
	else:
		tasks = [modules[target]]

	if action == "-c/--check":
		status, func = "Checking %s for problems", "check"
	else:
		status, func = "Attempting to fix %s", "fix"

	# Only draw a progress bar on a real, non-dumb terminal.
	show_progress = os.environ.get('TERM') != 'dumb' and sys.stdout.isatty()
	for task in tasks:
		print(status % task.name())
		handler = task()
		onProgress = None
		if show_progress:
			bar = portage.output.TermProgressBar()
			tracker = ProgressHandler()
			onProgress = tracker.onProgress

			def display():
				bar.set(tracker.curval, tracker.maxval)
			tracker.display = display

			def sigwinch_handler(signum, frame):
				# Adopt the new terminal width after a resize.
				lines, bar.term_columns = \
					portage.output.get_term_size()
			signal.signal(signal.SIGWINCH, sigwinch_handler)
		result = getattr(handler, func)(onProgress=onProgress)
		if show_progress:
			# make sure the final progress is displayed
			tracker.display()
			print()
			signal.signal(signal.SIGWINCH, signal.SIG_DFL)
		if result:
			print()
			print("\n".join(result))
			print("\n")

	print("Finished")

# Run the command line interface only when this file is executed as a
# script; the program name is not passed on to emaint_main().
if __name__ == "__main__":
	emaint_main(sys.argv[1:])
