/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.protocol.saml;

import java.security.PrivateKey;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.xml.security.encryption.EncryptedData;
import org.apache.xml.security.encryption.EncryptedKey;
import org.apache.xml.security.encryption.EncryptionMethod;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.keys.KeyInfo;
import org.apache.xml.security.keys.content.KeyName;
import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.protocol.saml.SAMLEncryptionAlgorithms;
import org.keycloak.saml.processing.core.util.XMLEncryptionUtil;

/**
 * Selects the realm keys that are candidates for decrypting a given
 * {@link EncryptedData} element.
 *
 * <p>Candidates are the realm's enabled keys whose use is {@link KeyUse#ENC}.
 * That set is then narrowed, in order, by:
 * <ol>
 *   <li>the algorithm passed to the constructor, when it is non-blank;</li>
 *   <li>the {@code KeyName} entries of the element's {@code KeyInfo}, when
 *       present (matched against the key id);</li>
 *   <li>the encryption method algorithm of the first {@code EncryptedKey}
 *       in the {@code KeyInfo}, when present.</li>
 * </ol>
 */
public class SAMLDecryptionKeysLocator implements XMLEncryptionUtil.DecryptionKeyLocator {
    private final KeycloakSession session;
    private final RealmModel realm;
    private final String requestedAlgorithm;

    /**
     * @param session            session used to look up the realm keys
     * @param realm              realm whose keys are considered
     * @param requestedAlgorithm key algorithm to restrict the lookup to; {@code null}
     *                           or blank disables this restriction
     */
    public SAMLDecryptionKeysLocator(KeycloakSession session, RealmModel realm, String requestedAlgorithm) {
        this.session = session;
        this.realm = realm;
        this.requestedAlgorithm = requestedAlgorithm;
    }

    /**
     * Collects the values of all {@code KeyName} items of the given {@code KeyInfo}.
     *
     * @param keyInfo key info to read from
     * @return the key names; a set because it is only used for membership checks
     * @throws IllegalStateException if a key name cannot be read from the document
     */
    private Set<String> getKeyNames(KeyInfo keyInfo) {
        Set<String> keyNames = new HashSet<>();
        try {
            int count = keyInfo.lengthKeyName();
            for (int i = 0; i < count; i++) {
                KeyName keyName = keyInfo.itemKeyName(i);
                if (keyName != null) {
                    keyNames.add(keyName.getKeyName());
                }
            }
        } catch (XMLSecurityException e) {
            throw new IllegalStateException("Cannot load keyNames from document", e);
        }
        return keyNames;
    }

    /**
     * Builds a predicate accepting keys whose algorithm equals the Keycloak
     * identifier mapped to the given XML Encryption algorithm identifier.
     *
     * @param algorithm XML Encryption algorithm identifier
     * @return predicate over candidate keys
     * @throws IllegalStateException if no mapping exists for {@code algorithm}
     */
    private Predicate<KeyWrapper> hasMatchingAlgorithm(String algorithm) {
        SAMLEncryptionAlgorithms usedAlgorithm = SAMLEncryptionAlgorithms.forXMLEncIdentifier(algorithm);
        if (usedAlgorithm == null) {
            throw new IllegalStateException("Keycloak does not support encryption keys for given algorithm: " + algorithm);
        }
        return keyWrapper -> Objects.equals(keyWrapper.getAlgorithmOrDefault(), usedAlgorithm.getKeycloakIdentifier());
    }

    /**
     * Reads the encryption method algorithm of the first {@code EncryptedKey}
     * in the given {@code KeyInfo}. Only the item at index 0 is inspected.
     *
     * @param keyInfo key info to read from
     * @return the algorithm identifier, or {@code null} when the key info has
     *         no {@code EncryptedKey} at index 0
     * @throws IllegalArgumentException if the {@code EncryptedKey} cannot be read,
     *                                  has no encryption method, or the method has
     *                                  no algorithm
     */
    private String getEncryptedKeyAlgorithm(KeyInfo keyInfo) {
        try {
            EncryptedKey encryptedKey = keyInfo.itemEncryptedKey(0);
            if (encryptedKey == null) {
                return null;
            }
            EncryptionMethod encryptionMethod = encryptedKey.getEncryptionMethod();
            if (encryptionMethod == null) {
                throw new IllegalArgumentException("KeyInfo does not contain encryption method");
            }
            String algorithm = encryptionMethod.getAlgorithm();
            if (algorithm == null) {
                throw new IllegalArgumentException("Not able to find algorithm for given encryption method");
            }
            return algorithm;
        } catch (XMLSecurityException e) {
            // KeyInfo itself is known to be present here; the failure is in reading the EncryptedKey.
            throw new IllegalArgumentException("Cannot load EncryptedKey from KeyInfo", e);
        }
    }

    /**
     * Returns the private keys that may be able to decrypt {@code encryptedData}.
     *
     * @param encryptedData encrypted element whose {@code KeyInfo} drives the selection
     * @return matching private keys, possibly empty
     * @throws IllegalStateException    if the element has no {@code KeyInfo}, its key
     *                                  names cannot be read, or the {@code EncryptedKey}
     *                                  algorithm is not supported
     * @throws IllegalArgumentException if the {@code EncryptedKey} cannot be read or
     *                                  lacks an encryption method or algorithm
     */
    public List<PrivateKey> getKeys(EncryptedData encryptedData) {
        KeyInfo keyInfo = encryptedData.getKeyInfo();
        if (keyInfo == null) {
            throw new IllegalStateException("EncryptedData does not contain KeyInfo");
        }

        Stream<KeyWrapper> candidates = session.keys().getKeysStream(realm)
                .filter(key -> key.getStatus().isEnabled() && KeyUse.ENC.equals(key.getUse()));

        if (requestedAlgorithm != null && !requestedAlgorithm.trim().isEmpty()) {
            candidates = candidates.filter(key -> Objects.equals(key.getAlgorithmOrDefault(), requestedAlgorithm));
        }

        if (keyInfo.containsKeyName()) {
            Set<String> keyNames = getKeyNames(keyInfo);
            candidates = candidates.filter(key -> keyNames.contains(key.getKid()));
        }

        String encryptedKeyAlgorithm = getEncryptedKeyAlgorithm(keyInfo);
        if (encryptedKeyAlgorithm != null) {
            candidates = candidates.filter(hasMatchingAlgorithm(encryptedKeyAlgorithm));
        }

        // isInstance rejects null as well as any key object that is not a PrivateKey,
        // so the cast below cannot fail and one unusable key does not abort the lookup.
        return candidates
                .map(KeyWrapper::getPrivateKey)
                .filter(PrivateKey.class::isInstance)
                .map(PrivateKey.class::cast)
                .collect(Collectors.toList());
    }
}

