package org.apache.gobblin.crypto;

import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import javax.xml.bind.DatatypeConverter;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.io.IOUtils;
import org.apache.gobblin.codec.Base64Codec;
import org.apache.gobblin.codec.StreamCodec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/gobblin/crypto/RotatingAESCodec.class */
/**
 * Stream codec that encrypts with {@code AES/CBC/PKCS5Padding} using a key picked at random from
 * a {@link CredentialStore}, and decrypts by looking up the key id recorded in the stream header.
 *
 * <p>Encoded stream layout, as produced by {@link EncodingStreamInstance} and consumed by
 * {@link DecodingStreamInstance}:
 * <ol>
 *   <li>4 ASCII characters: the numeric key id ({@code %04d});</li>
 *   <li>3 ASCII characters: the length of the base64-encoded IV ({@code %03d});</li>
 *   <li>the base64-encoded IV itself;</li>
 *   <li>the ciphertext, passed through {@link Base64Codec}.</li>
 * </ol>
 *
 * <p>Only credential-store entries whose id parses as an integer and whose key material is
 * exactly {@value #AES_KEY_LEN} bytes long are used; the rest are skipped with a debug log.
 * Keys are loaded lazily on first use and cached for the lifetime of this instance.
 */
public class RotatingAESCodec implements StreamCodec {
    private static final Logger log = LoggerFactory.getLogger(RotatingAESCodec.class);
    /** Required length, in bytes, of the raw key material (AES-128). */
    private static final int AES_KEY_LEN = 16;
    private static final String TAG = "aes_rotating";
    /** Only used to pick which key to encrypt with; IVs are generated by the {@link Cipher}. */
    private final Random random = new Random();
    private final CredentialStore credentialStore;
    /** Usable keys by numeric id; {@code null} until the first successful load. */
    private volatile Map<Integer, KeyRecord> keyRecords_cache;
    /** Same keys as {@link #keyRecords_cache}, as an array for random selection. */
    private volatile KeyRecord[] keyRecords_cache_arr;

    /**
     * Parses the header of an encoded stream and prepares a decrypting {@link Cipher}.
     * The header is consumed from the supplied stream inside the constructor.
     */
    private class DecodingStreamInstance {
        private final InputStream origStream;
        /** Scratch space for header fields; its length also caps the accepted IV length. */
        private final byte[] buffer = new byte[32];
        private final Cipher cipher;

        /**
         * Reads the key id and IV from {@code inputStream} and initialises the cipher.
         *
         * @param inputStream stream positioned at the start of the encoded header
         * @throws IOException if the header is malformed or names a key that is not available
         * @throws IllegalStateException if the cipher cannot be created or initialised
         */
        DecodingStreamInstance(InputStream inputStream) throws IOException {
            this.origStream = inputStream;
            Integer keyId = readKey();
            KeyRecord key = RotatingAESCodec.this.getKey(keyId);
            if (key == null) {
                throw new IOException("Cannot load key " + keyId + " which is specified in input stream");
            }
            try {
                byte[] iv = readIv();
                Cipher decryptor = Cipher.getInstance("AES/CBC/PKCS5Padding");
                if (iv != null) {
                    decryptor.init(Cipher.DECRYPT_MODE, key.getSecretKey(), new IvParameterSpec(iv));
                } else {
                    // NOTE(review): a zero-length IV field ends up here; CBC decryption normally
                    // needs an IV, so this init is expected to fail - confirm intended handling.
                    decryptor.init(Cipher.DECRYPT_MODE, key.getSecretKey());
                }
                this.cipher = decryptor;
            } catch (InvalidAlgorithmParameterException e) {
                throw new IllegalStateException("Failed to initialize IV", e);
            } catch (InvalidKeyException e) {
                throw new IllegalStateException("Failed to parse key from keystore", e);
            } catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
                throw new IllegalStateException("Failed to load AES which should never happen", e);
            }
        }

        /**
         * Returns a stream yielding the decrypted payload: the remainder of the original stream
         * run through {@link Base64Codec} decoding and then through the cipher.
         */
        InputStream wrapInputStream() throws IOException {
            InputStream base64Decoded = new Base64Codec().decodeInputStream(this.origStream);
            return new CipherInputStream(base64Decoded, this.cipher);
        }

        /**
         * Reads the 4-character ASCII key id from the head of the stream.
         *
         * @throws IOException if the stream is too short or the field is not an integer
         */
        private Integer readKey() throws IOException {
            IOUtils.readFully(this.origStream, this.buffer, 0, 4);
            String keyField = new String(this.buffer, 0, 4, StandardCharsets.UTF_8);
            try {
                return Integer.valueOf(keyField);
            } catch (NumberFormatException e) {
                throw new IOException("Expected to be able to parse first 4 bytes of stream as an ASCII keyId", e);
            }
        }

        /**
         * Reads the 3-character IV length followed by the base64-encoded IV.
         *
         * @return the decoded IV bytes, or {@code null} when the recorded length is zero
         * @throws IOException if the length field is not an integer, is outside
         *     {@code [0, buffer.length]}, or the stream ends early
         */
        private byte[] readIv() throws IOException {
            IOUtils.readFully(this.origStream, this.buffer, 0, 3);
            String lengthField = new String(this.buffer, 0, 3, StandardCharsets.UTF_8);
            int ivLen;
            try {
                ivLen = Integer.parseInt(lengthField);
            } catch (NumberFormatException e) {
                throw new IOException("Expected to parse next 3 bytes of stream as an IV len", e);
            }
            if (ivLen < 0 || ivLen > this.buffer.length) {
                throw new IOException("Corrupted data suspected; expected IVLen to be between 0 and " + this.buffer.length + ", read " + ivLen);
            }
            if (ivLen == 0) {
                return null;
            }
            byte[] encodedIv = new byte[ivLen];
            IOUtils.readFully(this.origStream, encodedIv, 0, encodedIv.length);
            return Base64.decodeBase64(encodedIv);
        }
    }

    /**
     * Encrypts data written to an output stream with a single {@link KeyRecord}, emitting the
     * plain-text header (key id + IV) directly to the underlying stream before any ciphertext.
     */
    static class EncodingStreamInstance {
        private final OutputStream origStream;
        private final KeyRecord secretKey;
        private Cipher cipher;
        private String base64Iv;
        private boolean headerWritten = false;

        /**
         * @param keyRecord key to encrypt with; its id is recorded in the header
         * @param outputStream destination for header and encoded ciphertext
         */
        EncodingStreamInstance(KeyRecord keyRecord, OutputStream outputStream) {
            this.secretKey = keyRecord;
            this.origStream = outputStream;
        }

        /**
         * Initialises the cipher and returns a stream that encrypts everything written to it.
         *
         * <p>The header is written lazily: on the first write, or on {@code close()} if nothing
         * was written, so that even an empty payload yields a stream the decoder can parse.
         * {@code flush()} is inherited from {@link FilterOutputStream} and therefore only flushes
         * the underlying stream, not the cipher/base64 layers.
         */
        OutputStream wrapOutputStream() throws IOException {
            initCipher();
            final CipherOutputStream cipherOutputStream = new CipherOutputStream(getBase64Stream(this.origStream), this.cipher);
            return new FilterOutputStream(this.origStream) {
                @Override
                public void write(int i) throws IOException {
                    EncodingStreamInstance.this.writeHeaderIfNecessary();
                    cipherOutputStream.write(i);
                }

                @Override
                public void write(byte[] bArr) throws IOException {
                    EncodingStreamInstance.this.writeHeaderIfNecessary();
                    cipherOutputStream.write(bArr);
                }

                @Override
                public void write(byte[] bArr, int i, int i2) throws IOException {
                    EncodingStreamInstance.this.writeHeaderIfNecessary();
                    cipherOutputStream.write(bArr, i, i2);
                }

                @Override
                public void close() throws IOException {
                    try {
                        // Closing the cipher stream emits the final padded block even when no
                        // data was written, so the header must already precede it.
                        EncodingStreamInstance.this.writeHeaderIfNecessary();
                    } finally {
                        cipherOutputStream.close();
                    }
                }
            };
        }

        /** Wraps {@code outputStream} with {@link Base64Codec} encoding. */
        private OutputStream getBase64Stream(OutputStream outputStream) throws IOException {
            return new Base64Codec().encodeOutputStream(outputStream);
        }

        /**
         * Creates the encrypting cipher, captures the IV it generated as base64 text, and resets
         * the header state.
         *
         * @throws IllegalStateException if no output stream was supplied, the key is rejected,
         *     or the AES transformation is unavailable
         */
        private void initCipher() {
            if (this.origStream == null) {
                throw new IllegalStateException("Can't initCipher stream before encodeOutputStream() has been called!");
            }
            try {
                this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
                // No IV is supplied here; the IV read back below is the one the cipher chose.
                this.cipher.init(Cipher.ENCRYPT_MODE, this.secretKey.getSecretKey());
                this.base64Iv = DatatypeConverter.printBase64Binary(this.cipher.getIV());
                this.headerWritten = false;
            } catch (InvalidKeyException e) {
                throw new IllegalStateException("Key " + this.secretKey.getKeyId() + " is illegal - please check credential store", e);
            } catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
                throw new IllegalStateException("Error creating AES algorithm? Should always exist in JRE", e);
            }
        }

        /**
         * Writes the {@code keyId / ivLength / iv} header to the underlying stream once; later
         * calls are no-ops.
         *
         * <p>NOTE(review): the decoder reads exactly 4 characters for the key id, so an id that
         * formats to more than 4 characters would produce an unreadable header - confirm ids are
         * constrained elsewhere.
         */
        public void writeHeaderIfNecessary() throws IOException {
            if (this.headerWritten) {
                return;
            }
            String header = String.format("%04d%03d%s", Integer.valueOf(this.secretKey.getKeyId()), Integer.valueOf(this.base64Iv.length()), this.base64Iv);
            this.origStream.write(header.getBytes(StandardCharsets.UTF_8));
            this.headerWritten = true;
        }
    }

    /** Immutable pairing of a numeric key id with its AES {@link SecretKey}. */
    public static class KeyRecord {
        private final int keyId;
        private final SecretKey secretKey;

        KeyRecord(int i, SecretKey secretKey) {
            this.keyId = i;
            this.secretKey = secretKey;
        }

        int getKeyId() {
            return this.keyId;
        }

        SecretKey getSecretKey() {
            return this.secretKey;
        }
    }

    /**
     * @param credentialStore source of candidate keys; it is queried lazily on first use
     */
    public RotatingAESCodec(CredentialStore credentialStore) {
        this.credentialStore = credentialStore;
    }

    /**
     * Wraps {@code outputStream} so that data written to the result is encrypted with a randomly
     * selected key from the credential store.
     *
     * @throws IllegalStateException if the store holds no usable key or the cipher cannot be set up
     */
    public OutputStream encodeOutputStream(OutputStream outputStream) throws IOException {
        return new EncodingStreamInstance(selectRandomKey(), outputStream).wrapOutputStream();
    }

    /**
     * Wraps {@code inputStream} so that reads return the decrypted payload. The header is read
     * from {@code inputStream} before this method returns.
     *
     * @throws IOException if the header is malformed or references an unknown key
     */
    public InputStream decodeInputStream(InputStream inputStream) throws IOException {
        return new DecodingStreamInstance(inputStream).wrapInputStream();
    }

    /**
     * Looks up a cached key by id, loading the cache first if needed.
     *
     * @return the matching record, or {@code null} if no usable key has that id
     */
    public synchronized KeyRecord getKey(Integer num) {
        fillKeyRecords();
        return this.keyRecords_cache.get(num);
    }

    /**
     * Picks one of the usable keys uniformly at random.
     *
     * @throws IllegalStateException if no usable key exists
     */
    private synchronized KeyRecord selectRandomKey() {
        KeyRecord[] candidates = getKeyRecords();
        if (candidates.length == 0) {
            throw new IllegalStateException("Couldn't find any valid keys in store!");
        }
        return candidates[this.random.nextInt(candidates.length)];
    }

    /** Returns all usable keys, loading the cache first if needed. */
    private synchronized KeyRecord[] getKeyRecords() {
        fillKeyRecords();
        return this.keyRecords_cache_arr;
    }

    /**
     * Loads the usable keys from the credential store on first call; later calls are no-ops.
     *
     * <p>The caches are assigned only after the whole store has been read, so a failure while
     * reading leaves them unset and the next call retries instead of seeing a half-built cache.
     */
    private synchronized void fillKeyRecords() {
        if (this.keyRecords_cache != null) {
            return;
        }
        Map<Integer, KeyRecord> loaded = new HashMap<>();
        for (Map.Entry<?, ?> entry : this.credentialStore.getAllEncodedKeys().entrySet()) {
            byte[] encodedKey = (byte[]) entry.getValue();
            if (encodedKey.length != AES_KEY_LEN) {
                log.debug("Skipping keyId {} because it is length {}; expected {}", new Object[]{entry.getKey(), Integer.valueOf(encodedKey.length), Integer.valueOf(AES_KEY_LEN)});
                continue;
            }
            try {
                Integer keyId = Integer.valueOf((String) entry.getKey());
                loaded.put(keyId, new KeyRecord(keyId.intValue(), new SecretKeySpec(encodedKey, "AES")));
            } catch (NumberFormatException e) {
                log.debug("Skipping keyId {} because this algorithm can only use numeric key ids", entry.getKey());
            }
        }
        // Assign the array first: a non-null map is what marks the cache as loaded.
        this.keyRecords_cache_arr = loaded.values().toArray(new KeyRecord[loaded.size()]);
        this.keyRecords_cache = loaded;
    }

    /** Returns the identifier of this codec, {@code "aes_rotating"}. */
    public String getTag() {
        return TAG;
    }
}
