#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Univention Printer Assignment
# Copyright 2007-2025 Univention GmbH
#
# http://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <http://www.gnu.org/licenses/>.

import argparse
import os
import shutil
import sys
import tempfile

import ldap
import ldap.filter

import univention.admin.modules
import univention.admin.objects
import univention.admin.uexceptions
import univention.admin.uldap
from univention.config_registry import ucr

# Settings read once from the config registry when the module is loaded.
baseDN = ucr.get('ldap/base', None)
vbs_filename = ucr.get(
	'printer/assignment/vbs/template',
	'/usr/share/univention-printer-assignment/printer-assignment-template.vbs',
)
debug = ucr.is_true('printer/assignment/update/debug', False)

# Caches of already loaded LDAP objects, each keyed by the object's DN.
cache_printers = dict()
cache_groups = dict()
cache_hosts = dict()


def log_debug(msg):
	"""Print *msg* with a ``DEBUG: `` prefix, but only while the module-level ``debug`` flag is set."""
	if not debug:
		return
	print('DEBUG: %s' % msg)


def log_error(msg):
	"""Print *msg* to standard error."""
	stream = sys.stderr
	print(msg, file=stream)


def main():
	"""Command line entry point.

	Parses the options, makes sure the target directories exist, opens an
	LDAP connection with the machine credentials, loads every object named
	in the backlog file (or on the command line) and finally processes each
	host that ended up in the host cache.
	Exits with status 1 if neither a backlog file nor a DN was given.
	"""
	global debug

	usage_text = '''
%(prog)s updates the printer assignment logon scripts.
The given list of DNs (either as direct argument or by a file) defines
which logon script has to be updated. The backlogfile has to contain one
DN per line.
'''

	cli = argparse.ArgumentParser(description=usage_text)
	cli.add_argument('-f', '--file', action='store', dest='backlogfile', default=None, help='path to the backlog file')
	cli.add_argument('-d', '--debug', action='store_true', dest='debug', default=False, help='enable debugging output')
	# everything the parser does not recognise is treated as a DN
	options, dn_args = cli.parse_known_args()

	if not (options.backlogfile or dn_args):
		log_error('At least one DN has to be specified as argument or via backlog file')
		sys.exit(1)

	# the switch can only enable debugging; it never disables the configured value
	if options.debug:
		debug = True

	# create missing directories
	target_dirs = prepare_target_directories()

	log_debug('trying to get LDAP connection with machine credentials')
	connection = univention.admin.uldap.getMachineConnection(ldap_master=False)
	lo = connection[0]

	# the backlog file takes precedence over DNs given as arguments
	backlog_set = load_backlog(options.backlogfile) if options.backlogfile else dn_args

	# load specified objects
	load_objects(lo, backlog_set)

	# process all hosts that were loaded directly or via groups/printers
	for host_dn in cache_hosts:
		process_host(lo, host_dn, target_dirs)


def prepare_target_directories():
	"""Determine the directories for the logon scripts and create the missing ones.

	An absolute value of ``printer/assignment/netlogon/path`` is used as is.
	Otherwise the value is appended to ``samba/share/netlogon/path`` if that
	is set, or else to both ``/var/lib/samba/netlogon/`` and
	``/var/lib/samba/sysvol/<lowercased kerberos/realm>/scripts/``.

	Returns the set of target directory paths.
	"""
	custom_path = ucr.get('printer/assignment/netlogon/path', '').strip().rstrip('/')
	netlogon_share = ucr.get('samba/share/netlogon/path', '').strip().rstrip('/')

	if custom_path.startswith('/'):
		target_dirs = {custom_path}
	elif netlogon_share:
		target_dirs = {os.path.join(netlogon_share, custom_path)}
	else:
		realm = ucr.get('kerberos/realm', '').lower()
		target_dirs = {
			'/var/lib/samba/netlogon/%s' % custom_path,
			'/var/lib/samba/sysvol/%s/scripts/%s' % (realm, custom_path),
		}

	for path in target_dirs:
		if os.path.isdir(path):
			continue
		os.makedirs(path)

	return target_dirs


def load_backlog(filename):
	"""Load the backlog file and return the set of DNs to be processed.

	The file contains one DN per line. Lines that are empty or consist only
	of whitespace are skipped, so that a trailing newline or a stray blank
	line does not turn into an LDAP lookup for an empty DN.

	:param filename: path of the backlog file
	:returns: set of DN strings
	:raises OSError: if the file cannot be opened or read
	"""
	log_debug('loading backlog file %s' % filename)
	# the context manager closes the file even if reading fails
	with open(filename, 'r') as backlog:
		dn_set = {line for line in backlog.read().splitlines() if line.strip()}
	log_debug('done')
	return dn_set


def fetch_object(lo, dn):
	"""Return the result of ``lo.get(dn)``.

	Any exception raised by the lookup is reported via log_error() and
	turned into a return value of None.
	"""
	try:
		attrs = lo.get(dn)
	except Exception as exc:
		log_error('ERROR: failed to load object %s: %s' % (dn, exc))
		attrs = None
	return attrs


def _search_printer_references(lo, object_class_filter, printer_dn):
	"""Search below baseDN for objects matching *object_class_filter* that reference *printer_dn*.

	An object references the printer if the DN is listed in its
	univentionAssignedPrinter or univentionAssignedPrinterDefault attribute.
	Returns the result of ``lo.search()``, or an empty list if the search
	raises ``noObject``.
	"""
	# filter_format() escapes the DN; object_class_filter is a fixed string without '%'
	ldap_filter = ldap.filter.filter_format(
		'(&' + object_class_filter + '(|(univentionAssignedPrinter=%s)(univentionAssignedPrinterDefault=%s)))',
		[printer_dn, printer_dn],
	)
	try:
		return lo.search(filter=ldap_filter, base=baseDN)
	except univention.admin.uexceptions.noObject:
		return []


def load_printer(lo, dn, obj=None, recursive=True, reload=False, ignore_missing=False):
	"""Load a printer into the printer cache and optionally everything referring to it.

	:param lo: LDAP connection used for lookups and searches
	:param dn: DN of the printer
	:param obj: already fetched printer object; fetched via fetch_object() if not given
	:param recursive: also load all groups and Windows hosts that refer to the printer
	:param reload: process the printer even if it is already cached
	:param ignore_missing: treat *dn* as a printer that may no longer exist:
		nothing is fetched or cached, but the groups and hosts still
		referring to the DN are loaded (regardless of *recursive*)
	"""
	# skip if already loaded and cache is active
	if dn in cache_printers and not reload:
		return

	if not ignore_missing:
		# the object is only needed for validation and caching, so it is not
		# fetched at all when ignore_missing is set
		if not obj:
			obj = fetch_object(lo, dn)
		if not obj:
			return
		if b'univentionPrinter' not in obj.get('objectClass', []):
			return
		# add to cache
		cache_printers[dn] = obj
		log_debug('loaded printer %s' % dn)

	if not (recursive or ignore_missing):
		return

	# fetch all groups that refer to this printer (loading a group also loads its member hosts)
	for (grpdn, grpobj) in _search_printer_references(lo, '(objectClass=univentionGroup)', dn):
		load_group(lo, grpdn, obj=grpobj)

	# fetch all hosts that refer to this printer
	for (hostdn, hostobj) in _search_printer_references(lo, '(objectClass=univentionHost)(objectClass=univentionWindows)', dn):
		load_host(lo, hostdn, obj=hostobj)


def load_group(lo, dn, obj=None, recursive=True, reload=False):
	"""Load a group into the group cache and, if *recursive*, its member hosts.

	*obj* may carry the already fetched group object; otherwise it is
	fetched via fetch_object(). Objects that cannot be loaded or that lack
	the univentionGroup object class are ignored. A cached group is only
	processed again if *reload* is set.
	"""
	if dn in cache_groups and not reload:
		return

	group = obj or fetch_object(lo, dn)
	if not group or b'univentionGroup' not in group.get('objectClass', []):
		return

	cache_groups[dn] = group
	log_debug('loaded group %s' % dn)

	if not recursive:
		return

	# every uniqueMember is handed to load_host(), which drops anything that is not a Windows host
	for raw_member in group.get('uniqueMember', []):
		load_host(lo, raw_member.decode("utf-8"), obj=None)


def load_host(lo, dn, obj=None, reload=False):
	"""Load a host into the host cache.

	*obj* may carry the already fetched host object; otherwise it is fetched
	via fetch_object(). Only objects having both the univentionHost and the
	univentionWindows object class are cached. A cached host is only
	processed again if *reload* is set.
	"""
	if dn in cache_hosts and not reload:
		return

	host = obj if obj else fetch_object(lo, dn)
	if not host:
		return

	object_classes = host.get('objectClass', [])
	is_windows_host = b'univentionHost' in object_classes and b'univentionWindows' in object_classes
	if not is_windows_host:
		return

	cache_hosts[dn] = host
	log_debug('loaded host %s' % dn)


def load_objects(lo, dnlist):
	"""Fetch every DN of *dnlist* once and load it according to its object classes.

	Windows hosts, groups and printers are passed to their respective
	loader; other object types are only reported in the debug output.
	"""
	seen = set()
	for dn in dnlist:
		# each DN is handled only once
		if dn in seen:
			continue
		seen.add(dn)

		obj = fetch_object(lo, dn)
		if not obj:
			# Special case: the missing object may have been a printer. Then all
			# groups and hosts still referring to that DN have to be loaded.
			# Side note: doing this for every missing DN is more than necessary;
			#            an optimisation may be implemented in the future.
			load_printer(lo, dn, ignore_missing=True)
			continue

		object_classes = obj.get('objectClass', [])
		if b'univentionHost' in object_classes and b'univentionWindows' in object_classes:
			load_host(lo, dn, obj=obj)
		elif b'univentionGroup' in object_classes:
			load_group(lo, dn, obj=obj)
		elif b'univentionPrinter' in object_classes:
			load_printer(lo, dn, obj=obj)
		else:
			log_debug('found unknown object type\nDN: %s\nobjectClass: %r\n' % (dn, obj.get('objectClass')))

	log_debug('list of affected hosts:\n  %s' % '\n  '.join(cache_hosts))


def get_printer_unc(lo, printer_dn):
	"""Build the ``\\\\host\\queue`` path of a printer.

	The queue name is taken from univentionPrinterSambaName and falls back
	to cn. If the printer has a univentionAssignedPrinterSettingsFile value,
	it is appended after a colon.
	Returns None if the printer cannot be loaded or no queue name is found.
	"""
	log_debug('fetching printer path for %s' % printer_dn)
	# make sure that the cache contains the printer object - NOOP if already in cache
	load_printer(lo, printer_dn, recursive=False)
	if printer_dn not in cache_printers:
		return None
	printer = cache_printers[printer_dn]

	def first_value(attribute):
		# first value of the attribute as text, '' if the attribute is not set
		return printer.get(attribute, [b''])[0].decode("utf-8")

	host = first_value('univentionPrinterSpoolHost')  # TODO FIXME uses only the first spool host
	queuename = first_value('univentionPrinterSambaName')
	settings_file = first_value('univentionAssignedPrinterSettingsFile')
	if not queuename:
		queuename = first_value('cn')
	if not queuename:
		log_error('ERROR: cannot determine queuename for %s' % printer_dn)
		return None

	unc = '\\\\%s\\%s' % (host, queuename)
	return '%s:%s' % (unc, settings_file) if settings_file else unc


def get_assigned_printers_from_obj(lo, obj):
	"""Resolve the printers assigned to *obj* via get_printer_unc().

	Returns a tuple ``(printer_list, default_printer)``: the resolved values
	of univentionAssignedPrinter and the last resolvable value of
	univentionAssignedPrinterDefault (None if there is none). Printers that
	cannot be resolved are reported via log_error() and left out.
	"""
	def resolve(raw_dn):
		# returns the printer path, or a false value after logging the failure
		printer_dn = raw_dn.decode("utf-8")
		printer_unc = get_printer_unc(lo, printer_dn)
		if not printer_unc:
			log_error('ERROR: cannot load printer %s (reference found at cn=%s)' % (printer_dn, obj.get('cn', [b''])[0].decode("utf-8")))
		return printer_unc

	printer_list = [unc for unc in map(resolve, obj.get('univentionAssignedPrinter', [])) if unc]

	default_printer = None
	for raw_dn in obj.get('univentionAssignedPrinterDefault', []):
		unc = resolve(raw_dn)
		if unc:
			default_printer = unc

	return (printer_list, default_printer)


def process_host(lo, dn, target_dirs):
	"""Generate the printer assignment logon script for a single host.

	The printers are collected from all groups that list the host as
	uniqueMember and from the host object itself; a default printer set at
	the host overrides a default printer coming from a group. The values are
	inserted into the VBS template and the result is written as
	``<cn of the host>.vbs`` (lowercased) into each target directory.

	:param lo: LDAP connection used for searches and lookups
	:param dn: DN of the host; it has to be present in the host cache,
		otherwise an error is logged and nothing is written
	:param target_dirs: iterable of existing directories to write the script to
	"""
	log_debug('processing host %s' % dn)

	if dn not in cache_hosts:
		# without the host object neither its printers nor its name are known
		log_error('ERROR: host is not in cache: %s' % dn)
		return
	host = cache_hosts[dn]

	assign_list = set()
	default_printer = None

	ldap_filter = ldap.filter.filter_format(
		'(&'
		'(objectClass=univentionGroup)'
		'(uniqueMember=%s))',
		[dn],
	)
	try:
		search_result = lo.searchDn(filter=ldap_filter, base=baseDN)
	except univention.admin.uexceptions.noObject:
		search_result = []
	for grpdn in search_result:
		log_debug('checking group %s' % grpdn)
		load_group(lo, grpdn, recursive=False)  # make sure that the cache contains the group object - NOOP if already in cache
		if grpdn not in cache_groups:
			log_error('WARNING: cannot load group %s' % grpdn)
			continue

		printer_list, new_default_printer = get_assigned_printers_from_obj(lo, cache_groups[grpdn])
		assign_list.update(printer_list)
		if new_default_printer:
			assign_list.add(new_default_printer)
			default_printer = new_default_printer
		log_debug('group %s returned %s' % (grpdn, printer_list))

	# the host is evaluated last so that its default printer wins over the groups
	printer_list, new_default_printer = get_assigned_printers_from_obj(lo, host)
	assign_list.update(printer_list)
	if new_default_printer:
		assign_list.add(new_default_printer)
		default_printer = new_default_printer

	# sorted to get the same script content on every run for the same assignment
	printers = sorted(assign_list)

	log_debug('host %s' % dn)
	log_debug('     default: %s' % default_printer)
	log_debug('     printers: %s' % '\t'.join(printers))

	settings = {
		'flagRemoveAllPrinters': '0',
		'flagShowDebug': '0',
		'printerList': ' '.join(printers),
		'defaultPrinter': '',
	}
	if ucr.is_true('printer/assignment/vbs/removeall', False):
		settings['flagRemoveAllPrinters'] = '1'

	if ucr.is_true("printer/assignment/vbs/debug", False):
		settings['flagShowDebug'] = '1'

	if ucr.is_true("printer/assignment/vbs/setdefaultprinter", True) and default_printer:
		# strip the ":<settings file>" suffix that get_printer_unc() may append
		settings['defaultPrinter'] = default_printer.split(":", 1)[0]

	settings['printUIEntryOptions'] = ucr.get("printer/assignment/printuientry/options", "")

	# open template, read content and replace variables
	with open(vbs_filename, 'r') as template:
		content = template.read() % settings

	# only the file name is lowercased; the directory is used exactly as given
	script_name = ('%s.vbs' % host.get('cn', [b'UNKNOWN'])[0].decode("utf-8")).lower()
	for target_dir in target_dirs:
		target_fn = os.path.join(target_dir, script_name)
		# write to a temporary file in the same directory first, then move it over the target
		temp_fd, temp_fn = tempfile.mkstemp(dir=target_dir)
		log_debug('Writing temporary file %s' % temp_fn)
		try:
			# fdopen() takes over the descriptor returned by mkstemp() and closes it
			with os.fdopen(temp_fd, 'w') as fd:
				fd.write(content)
			os.chmod(temp_fn, 0o755)
			# replace old file
			log_debug('Replacing file %s' % target_fn)
			shutil.move(temp_fn, target_fn)
		finally:
			# after a successful move the temporary name is gone; otherwise remove the leftover
			if os.path.exists(temp_fn):
				os.unlink(temp_fn)


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
	main()
