#!/usr/bin/python3
# pylint: disable-msg=C0103
#
# SPDX-FileCopyrightText: 2004-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

"""Check if users are member of their primary group."""

import sys
from argparse import ArgumentParser

import ldap
from ldap.filter import filter_format

import univention.uldap
from univention.config_registry import ConfigRegistry


def info(msg, *args, **kwargs):
    """Write a %-formatted informational message to standard output.

    The positional arguments are used as the format operand when any are
    given; otherwise the keyword arguments are used as a mapping.
    """
    values = args if args else kwargs
    text = msg % values
    print(text)


def warn(msg, *args, **kwargs):
    """Write a %-formatted warning to standard error and count it.

    The positional arguments are used as the format operand when any are
    given; otherwise the keyword arguments are used as a mapping.  Each
    successfully printed warning increments ``warn.warnings``.
    """
    values = args if args else kwargs
    text = msg % values
    print('Warning: ' + text, file=sys.stderr)
    warn.warnings = warn.warnings + 1


# Number of warnings emitted so far; kept as an attribute of the function.
warn.warnings = 0


def fatal(msg, *args, **kwargs):
    """Write a %-formatted error to standard error and terminate.

    The positional arguments are used as the format operand when any are
    given; otherwise the keyword arguments are used as a mapping.

    Raises:
        SystemExit: always, with exit status 1.
    """
    values = args if args else kwargs
    text = msg % values
    print('Error: ' + text, file=sys.stderr)
    raise SystemExit(1)


def check_primary(conn, basedn):
    """Check if users are member of their primary group.

    Searches all ``posixAccount`` entries below *basedn*, looks up the
    ``univentionGroup`` entries carrying the same ``gidNumber`` and adds the
    account's DN (``uniqueMember``) and uid (``memberUid``) to each of them
    where missing.

    Accounts lacking a ``gidNumber`` or ``uid`` value are reported with a
    warning and skipped instead of aborting the whole run.

    Args:
        conn: connection object providing ``search_s`` and ``modify_s``.
        basedn: base DN used for the account and the group searches.

    The process is terminated through ``fatal`` if *basedn* does not exist.
    """
    info("Checking if users are member of their primary group...")
    try:
        # GIDs will only be found in posixAccount
        user_result = conn.search_s(basedn, ldap.SCOPE_SUBTREE, '(objectClass=posixAccount)', ['gidNumber', 'uid'])
    except ldap.NO_SUCH_OBJECT:
        fatal("ldap search in %s failed (no such base dn)", basedn)
    count_changes = 0
    for user_dn, account in user_result:
        # Fall back to an empty value so a missing attribute is reported
        # below instead of raising IndexError/KeyError.
        user_gid = (account.get('gidNumber') or [b''])[0].decode('utf-8')
        if not user_gid:
            warn("posixAccount without gidNumber: %s", user_dn)
            # without a gid there is no primary group to look up
            continue
        user_uid = (account.get('uid') or [b''])[0].decode('utf-8')
        if not user_uid:
            warn("posixAccount without uid: %s", user_dn)
            continue

        # search corresponding group
        group_result = conn.search_s(
            basedn, ldap.SCOPE_SUBTREE,
            filter_format('(&(objectClass=univentionGroup)(gidNumber=%s))', (user_gid,)),
            ['uniqueMember', 'memberUid'],
        )

        # there must be exactly one group with this gid
        if len(group_result) > 1:
            warn("found more than one univentionGroup for gidNumber=%s!", user_gid)
        elif len(group_result) < 1 and user_gid != "0":
            # gidNumber 0 is exempt from the "no group found" warning
            warn("found no univentionGroup for gidNumber=%s!", user_gid)
        # we change them all -- the user needs to delete all but one of them
        for group_dn, group in group_result:
            # look for the needed entries; comparison is case-insensitive
            group_member_dns = [dn.decode('utf-8').lower() for dn in group.get('uniqueMember', [])]
            group_member_uids = [uid.decode('utf-8').lower() for uid in group.get('memberUid', [])]
            modlist = []
            if user_dn.lower() not in group_member_dns:
                modlist.append((ldap.MOD_ADD, 'uniqueMember', user_dn.encode('utf-8')))
            if user_uid.lower() not in group_member_uids:
                modlist.append((ldap.MOD_ADD, 'memberUid', user_uid.encode('utf-8')))
            # at least one entry is missing, going to add it
            if modlist:
                info("Adding uniqueMember and memberUid entry for '%s' in '%s'", user_dn, group_dn)
                try:
                    conn.modify_s(group_dn, modlist)
                    count_changes += 1
                except ldap.LDAPError:
                    warn("failed to modify group %s", group_dn)
    info("Checked %d posixAccounts, fixed %d issues.", len(user_result), count_changes)


def check_groups(conn, basedn):
    """Verify that the members listed in every posixGroup exist.

    All ``posixGroup`` entries below *basedn* are fetched; for each one the
    ``uniqueMember`` check runs first, followed by the ``memberUid`` check.
    The number of inspected groups and of applied fixes is printed at the
    end.  The process is terminated through ``fatal`` if *basedn* does not
    exist.
    """
    info("Checking if group-members exist...")
    wanted_attrs = ['uniqueMember', 'memberUid']
    try:
        groups = conn.search_s(basedn, ldap.SCOPE_SUBTREE, '(objectClass=posixGroup)', wanted_attrs)
    except ldap.NO_SUCH_OBJECT:
        fatal("ldap search in %s failed (no such base dn)", basedn)

    # Per group: DN based check first, then uid based check.
    fixed = sum(
        check_groups_by_dn(conn, dn, attrs) + check_groups_by_uid(conn, basedn, dn, attrs)
        for dn, attrs in groups
    )

    info("Checked %d posixGroups, fixed %d issues.", len(groups), fixed)


def check_groups_by_dn(conn, group_dn, group):
    """Check by 'uniqueMember'.

    For every ``uniqueMember`` value of *group* the referenced entry is
    looked up with a one-level search below the member's parent DN, using a
    filter built from the member's first RDN.  Members for which the search
    returns nothing are removed from the group; failed or ambiguous searches
    are only reported and need manual attention.

    The DN is parsed with ``ldap.dn.str2dn`` and the filter is built with
    ``filter_format`` so that escaped commas, multi-valued RDNs and filter
    metacharacters such as ``(``, ``)`` or ``*`` in the RDN value are
    handled correctly.

    Args:
        conn: connection object providing ``search_s`` and ``modify_s``.
        group_dn: DN of the group being checked.
        group: attribute mapping of the group (values are bytes).

    Returns:
        Number of successfully removed ``uniqueMember`` values.
    """
    group_member_dns = [dn.decode('utf-8') for dn in group.get('uniqueMember', [])]
    count_changes = 0
    remmembers = set()
    for member_dn in group_member_dns:
        # Split e.g. "uid=USER,cn=users,dc=FQDN" into its RDN components
        try:
            rdns = ldap.dn.str2dn(member_dn)
        except ldap.LDAPError:
            # do not remove what cannot even be parsed; leave it to the admin
            warn("Manual: Member DN '%s' of group '%s' is not a valid DN", member_dn, group_dn)
            continue
        if len(rdns) < 2:
            # no parent to search below: such a value is removed
            remmembers.add(member_dn)
            continue

        base = ldap.dn.dn2str(rdns[1:])
        # one assertion per attribute/value pair of the (possibly multi-valued) RDN
        assertions = [filter_format('(%s=%s)', (attr, value)) for attr, value, _flags in rdns[0]]
        if len(assertions) == 1:
            member_filter = assertions[0]
        else:
            member_filter = '(&%s)' % ''.join(assertions)

        try:
            member_result = conn.search_s(base, ldap.SCOPE_ONELEVEL, member_filter, ['objectClass'])
        except ldap.LDAPError:
            warn("Manual: Search for member DN '%s' of group '%s' failed", member_dn, group_dn)
        else:
            if len(member_result) > 1:
                warn("Manual: Multiple members for DN '%s' of group '%s'", member_dn, group_dn)
            elif len(member_result) < 1:
                warn("No member for DN '%s', will be removed", member_dn)
                remmembers.add(member_dn)
    for member_dn in remmembers:
        info("Removing member DN '%s' from '%s'", member_dn, group_dn)
        modlist = [(ldap.MOD_DELETE, 'uniqueMember', member_dn.encode('utf-8'))]
        try:
            conn.modify_s(group_dn, modlist)
            count_changes += 1
        except ldap.LDAPError:
            warn("failed to remove DN '%s' from group '%s'", member_dn, group_dn)
    return count_changes


def check_groups_by_uid(conn, basedn, group_dn, group):
    """Drop 'memberUid' values of a group for which no entry exists.

    Each ``memberUid`` value of *group* is searched as ``uid`` in the whole
    subtree below *basedn*.  Values without any match are removed from the
    group; failed or ambiguous searches are only reported.

    Returns the number of successfully removed ``memberUid`` values.
    """
    member_uids = [value.decode('utf-8') for value in group.get('memberUid', [])]

    stale_uids = set()
    for uid in member_uids:
        try:
            hits = conn.search_s(basedn, ldap.SCOPE_SUBTREE, filter_format('(uid=%s)', (uid,)), ['objectClass'])
        except ldap.LDAPError:
            warn("Manual: Search for member UID '%s' of group '%s' failed", uid, group_dn)
            continue
        if len(hits) > 1:
            warn("Manual: Multiple members for UID '%s' of group '%s'", uid, group_dn)
        elif not hits:
            warn("No member for UID '%s', will be removed", uid)
            stale_uids.add(uid)

    removed = 0
    for uid in stale_uids:
        info("Removing member UID '%s' from '%s'", uid, group_dn)
        deletion = [(ldap.MOD_DELETE, 'memberUid', uid.encode('utf-8'))]
        try:
            conn.modify_s(group_dn, deletion)
        except ldap.LDAPError:
            warn("Failed to remove UID '%s' from group '%s'", uid, group_dn)
        else:
            removed += 1
    return removed


def main():
    # type: () -> None
    """Parse the command line, run both checks and exit.

    Exit status is 2 if any warning was emitted, 0 otherwise.
    """
    parser = ArgumentParser(description=__doc__)
    parser.add_argument("-b", "--base-dn", dest="basedn", action="store", help="ldap base DN for user search")
    parser.add_argument("-c", "--check", dest="check", action="store_true", help="Only check, do not modify")
    options = parser.parse_args()

    ucr = ConfigRegistry()
    ucr.load()
    default_basedn = ucr['ldap/base']

    conn = univention.uldap.getAdminConnection().lo

    # A base DN given on the command line overrides the configured one.
    basedn = options.basedn or default_basedn

    if options.check:
        # Check-only mode: replace the modify operation by a no-op.
        def _skip_modify(dn, modlist):
            return None

        conn.modify_s = _skip_modify

    check_primary(conn, basedn)
    check_groups(conn, basedn)

    if not warn.warnings:
        sys.exit(0)
    info("There were %d warning(s)!", warn.warnings)
    sys.exit(2)


# Run the checks only when this file is executed as a script.
if __name__ == '__main__':
    main()
