#!/usr/bin/python3
# SPDX-FileCopyrightText: 2004-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

"""Check if users are member of their primary group"""

import sys
from argparse import ArgumentParser
from logging import getLogger

import univention.admin.modules
import univention.admin.objects
import univention.admin.uexceptions
import univention.config_registry
import univention.logging


# Module-wide logger (logger name 'ADMIN')
log = getLogger('ADMIN')
# LDAP connection and position; both are placeholders here and get assigned in main()
lo = None
position = None
# Cache filled by checkChilds(): group DN -> value of that group's 'nestedGroup' property
grp2childgrps = {}


class groupRecursionDetected(univention.admin.uexceptions.base):
    """Signal that a group is (directly or indirectly) nested inside itself.

    The chain of group DNs that was passed to the constructor is available
    as the ``recursionlist`` attribute.
    """

    message = 'circular group recursion detected'

    def __init__(self, recursionlist):
        # The base class is initialised without arguments; the DN chain is
        # carried as a plain attribute for the handler to read.
        super().__init__()
        self.recursionlist = recursionlist


def get_ldap_connection(baseDN):
    """Return an LDAP connection, preferring admin over machine credentials.

    ``getAdminConnection()`` is tried first; if it raises any ``Exception``
    the result of ``getMachineConnection()`` is returned instead (errors
    from that second call propagate to the caller).

    :param baseDN: accepted for the callers' sake but not evaluated here.
    :returns: whatever the ``univention.admin.uldap`` helper returns; the
        caller in this file unpacks it into two values (connection, position).
    """
    # NOTE(review): univention.admin.uldap is not imported explicitly by this
    # file; presumably it is loaded as a side effect of the other
    # univention.admin imports -- confirm.
    try:
        connection = univention.admin.uldap.getAdminConnection()
    except Exception:
        # Admin credentials not usable -> fall back to the machine account
        connection = univention.admin.uldap.getMachineConnection()
    return connection


def checkChilds(grp_module, dn, parents, verbose=False, exception=False):
    """Descend depth-first through the nested groups of ``dn`` and stop at the first cycle.

    The ``nestedGroup`` value of every group that gets loaded is remembered in
    the module-level ``grp2childgrps`` cache, so each DN is fetched at most once
    per spelling.

    :param grp_module: module object handed to ``univention.admin.objects.get`` to load a group.
    :param dn: DN of the group whose nested groups are inspected.
    :param parents: lower-cased DNs of the groups on the path leading to ``dn``
        (an empty list for a top-level call).
    :param verbose: if true, print one tree line per nested group that is visited.
    :param exception: accepted but not evaluated by this function.
    :raises groupRecursionDetected: when a nested group already occurs on the
        current path. Its ``recursionlist`` holds the lower-cased path from the
        first occurrence of that group, followed by the group's DN as stored.
    """
    if dn in grp2childgrps:
        nested = grp2childgrps[dn]
    else:
        group = univention.admin.objects.get(grp_module, None, lo, position='', dn=dn, attr=None)
        group.open()
        nested = group['nestedGroup']
        grp2childgrps[dn] = nested

    # Path from the start group down to (and including) the current group;
    # comparison with nested groups is done on lower-cased DNs.
    path = list(parents)
    path.append(dn.lower())

    for member in nested:
        if verbose:
            # One indentation unit per group on the path
            print('%s+--> %s' % ('|    ' * len(path), member))

        member_key = member.lower()
        if member_key in path:
            cycle = path[path.index(member_key):]
            cycle.append(member)
            raise groupRecursionDetected(cycle)

        checkChilds(grp_module, member, path, verbose)


def main():
    # type: () -> None
    """Look up all groups below the LDAP base and report every circular nesting.

    Parses the command line (``-v``/``--verbose``), sets up logging to
    ``/var/log/univention/check_group_recursion.log``, opens the LDAP
    connection (stored in the module globals ``lo`` and ``position``) and runs
    ``checkChilds`` for each group found. Detected cycles are logged and
    printed; the closing summary is always logged and printed only in verbose
    mode.
    """
    global lo, position

    arg_parser = ArgumentParser(description=__doc__)
    arg_parser.add_argument("-v", "--verbose", help="print debug output", dest="verbose", action="store_true")
    verbose = arg_parser.parse_args().verbose

    univention.logging.basicConfig(filename='/var/log/univention/check_group_recursion.log', univention_debug_level=1)

    ucr = univention.config_registry.ConfigRegistry()
    ucr.load()
    ldap_base = ucr['ldap/base']

    univention.admin.modules.update()
    lo, position = get_ldap_connection(ldap_base)
    grp_module = univention.admin.modules._get('groups/group')
    univention.admin.modules.init(lo, position, grp_module)

    groups = univention.admin.modules.lookup(grp_module, None, lo, scope='sub', superordinate=None, base=ldap_base, filter=None)
    group_total = len(groups)
    print('Number of groups: %d' % group_total)
    log.error('Testing %d groups...', group_total)

    cycles_found = 0
    for index, group in enumerate(groups):
        if verbose:
            print()
            print('|--> %s' % group.dn)
        else:
            # Progress indicator that overwrites itself on the same terminal line
            print('Testing group #%d\r' % index, end='', flush=True)

        try:
            checkChilds(grp_module, group.dn, [], verbose)
        except groupRecursionDetected as exc:
            chain = ['--> %s' % dn for dn in exc.recursionlist]
            log.error('Recursion detected: %s\n%s', group.dn, ''.join(line + '\n' for line in chain))
            print()
            print('Recursion detected:')
            for line in chain:
                print(line)
            cycles_found += 1

    summary = 'Tests have been finished. %d group(s) with circular recursion found.'
    log.error(summary, cycles_found)
    if verbose:
        print()
        print(summary % cycles_found)


# Run the check only when the file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()
