package com.quasiris.qsf.commons.ai.embedding;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quasiris.qsf.commons.ai.dto.Document;
import com.quasiris.qsf.commons.ai.dto.TextVector;
import com.quasiris.qsf.commons.ai.dto.TextVectorDocument;
import com.quasiris.qsf.commons.exception.NormalizerNotSupportedException;
import com.quasiris.qsf.commons.nlp.SentenceSplitter;
import com.quasiris.qsf.commons.text.normalizer.TextNormalizerService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpEntity;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/quasiris/qsf/commons/ai/embedding/BertAsAServiceEncoder.class */
/**
 * {@link TextEmbeddingEncoder} that obtains embedding vectors from a remote HTTP service.
 *
 * <p>Texts are optionally split into sentences and normalized, then sent in a single JSON
 * {@code POST} request ({@code {"id": "", "texts": [...], "is_tokenized": false}}) to the
 * configured URL. The response is expected to be a JSON object whose {@code "result"} entry
 * holds one list of numbers per submitted text, in the same order.
 *
 * <p>Remote failures are handled on a best-effort basis: they are logged and the affected
 * {@link TextVector}s are returned without a vector instead of an exception being thrown.
 */
public class BertAsAServiceEncoder implements TextEmbeddingEncoder {
    private static final Logger logger = LoggerFactory.getLogger(BertAsAServiceEncoder.class);

    /** Internal field name under which the vectors of one plain input text are collected. */
    private static final String BULK_FIELD = "_bulk";

    private final String baseUrl;
    private final Integer timeout;
    private final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * Creates an encoder that talks to the given endpoint.
     *
     * @param str URL the embedding requests are posted to
     * @param num value used for the connect, connection-request and socket timeouts of each
     *            request (passed to Apache HttpClient's {@code RequestConfig});
     *            {@code null} leaves the HTTP client defaults in place
     */
    public BertAsAServiceEncoder(String str, Integer num) {
        this.baseUrl = str;
        this.timeout = num;
    }

    /**
     * Embeds a single text.
     *
     * @param str                   text to embed
     * @param textNormalizerService normalizer applied to every text part before it is sent;
     *                              may be {@code null} to send the parts unchanged
     * @param z                     {@code true} to split the text into sentences and embed
     *                              each sentence separately, {@code false} to embed it as a whole
     * @return one {@link TextVector} per non-blank (normalized) part of the text
     */
    @Override // com.quasiris.qsf.commons.ai.embedding.TextEmbeddingEncoder
    public List<TextVector> embed(String str, TextNormalizerService textNormalizerService, boolean z) {
        List<TextVectorDocument> documents = embedTextBulk(Arrays.asList(str), textNormalizerService, z);
        if (documents.size() == 1) {
            return documents.get(0).getFields().get(BULK_FIELD);
        }
        return new ArrayList<>();
    }

    /**
     * Embeds several plain texts with a single remote call.
     *
     * @param texts          texts to embed
     * @param normalizer     optional normalizer, may be {@code null}
     * @param splitSentences whether each text is split into sentences first
     * @return one document per input text, in input order; each holds its vectors under
     *         {@link #BULK_FIELD}
     */
    private List<TextVectorDocument> embedTextBulk(List<String> texts, TextNormalizerService normalizer, boolean splitSentences) {
        SentenceSplitter sentenceSplitter = new SentenceSplitter();
        List<String> normalizedTexts = new ArrayList<>();
        // Flat list of all vectors in request order; position i receives response row i.
        List<TextVector> pendingVectors = new ArrayList<>();
        List<TextVectorDocument> documents = new ArrayList<>();

        for (String text : texts) {
            List<String> parts = splitSentences ? sentenceSplitter.split(text) : Arrays.asList(text);
            List<TextVector> documentVectors = new ArrayList<>();
            TextVectorDocument document = new TextVectorDocument();
            document.getFields().put(BULK_FIELD, documentVectors);
            for (String part : parts) {
                String normalized = normalizer != null ? normalizer.normalize(part) : part;
                if (StringUtils.isNotBlank(normalized)) {
                    TextVector textVector = new TextVector(part, normalized, null);
                    documentVectors.add(textVector);
                    pendingVectors.add(textVector);
                    normalizedTexts.add(normalized);
                }
            }
            documents.add(document);
        }

        // Nothing to encode: skip the remote round trip entirely.
        if (!normalizedTexts.isEmpty()) {
            fetchVectors(normalizedTexts, pendingVectors);
        }
        return documents;
    }

    /**
     * Posts the texts to the remote service and stores the returned vectors in {@code targets}.
     * Any failure is logged and leaves the targets without vectors.
     *
     * @param normalizedTexts texts to send, one per target
     * @param targets         vectors to fill; {@code targets.get(i)} belongs to
     *                        {@code normalizedTexts.get(i)}
     */
    private void fetchVectors(List<String> normalizedTexts, List<TextVector> targets) {
        HttpPost request;
        try {
            request = buildRequest(normalizedTexts);
        } catch (JsonProcessingException e) {
            logger.warn(e.getMessage(), e);
            return;
        }

        try (CloseableHttpClient client = HttpClients.createDefault();
             CloseableHttpResponse response = client.execute(request)) {
            HttpEntity entity = response.getEntity();
            int status = response.getStatusLine().getStatusCode();
            if (entity == null || status != 200) {
                logger.warn("Embedding request to {} returned no usable response (HTTP status {})", baseUrl, status);
                return;
            }

            Map<?, ?> body = objectMapper.readValue(EntityUtils.toString(entity, "UTF-8"), Map.class);
            Object result = body.get("result");
            if (!(result instanceof List)) {
                logger.warn("Embedding response from {} contains no \"result\" list", baseUrl);
                return;
            }

            List<?> rows = (List<?>) result;
            if (rows.size() != targets.size()) {
                logger.warn("Input and output vectors does not match! Sent {} texts, received {} vectors", targets.size(), rows.size());
                return;
            }

            // Convert everything first so a malformed row cannot leave a half-filled result.
            List<Double[]> vectors = new ArrayList<>(rows.size());
            for (Object row : rows) {
                vectors.add(toVector(row));
            }
            for (int i = 0; i < vectors.size(); i++) {
                targets.get(i).setVector(vectors.get(i));
            }
        } catch (Exception e) {
            // Best effort: callers get the documents back without vectors.
            logger.warn("Something gone wrong in GET document for BertAsAServiceEncoder!", e);
        }
    }

    /**
     * Converts one response row into a {@code Double[]}.
     *
     * <p>Elements are read as {@link Number} because the JSON parser may produce
     * {@code Integer} for integral values, which cannot be stored in a {@code Double[]} directly.
     *
     * @param row a list of numbers
     * @return the row as boxed doubles; {@code null} elements are kept as {@code null}
     * @throws ClassCastException if the row is not a list or contains a non-numeric element
     */
    private static Double[] toVector(Object row) {
        List<?> values = (List<?>) row;
        Double[] vector = new Double[values.size()];
        for (int i = 0; i < vector.length; i++) {
            Object value = values.get(i);
            if (value != null) {
                vector[i] = ((Number) value).doubleValue();
            }
        }
        return vector;
    }

    /**
     * Builds the JSON {@code POST} request for the given texts.
     *
     * @param normalizedTexts texts to send
     * @return the configured request
     * @throws JsonProcessingException if the payload cannot be serialized
     */
    private HttpPost buildRequest(List<String> normalizedTexts) throws JsonProcessingException {
        HttpPost request = new HttpPost(this.baseUrl);
        request.setEntity(buildEntity(normalizedTexts));
        request.addHeader("Content-Type", "application/json");
        // Without a configured timeout the HTTP client defaults apply.
        if (this.timeout != null) {
            int timeoutValue = this.timeout;
            request.setConfig(RequestConfig.custom()
                    .setConnectTimeout(timeoutValue)
                    .setConnectionRequestTimeout(timeoutValue)
                    .setSocketTimeout(timeoutValue)
                    .build());
        }
        return request;
    }

    /**
     * Embeds all fields of a single document.
     *
     * @param document              document whose field values are embedded
     * @param textNormalizerService optional normalizer, may be {@code null}
     * @param z                     {@code true} to embed each sentence of a field separately
     * @return a document with the same id, mapping each field name to its vectors
     */
    @Override // com.quasiris.qsf.commons.ai.embedding.TextEmbeddingEncoder
    public TextVectorDocument embed(Document<String> document, TextNormalizerService textNormalizerService, boolean z) {
        List<TextVectorDocument> embedded = embedBulk(Arrays.asList(document), textNormalizerService, z);
        if (embedded.size() == 1) {
            return embedded.get(0);
        }
        return new TextVectorDocument(document.getId());
    }

    /**
     * Not supported by this encoder.
     *
     * @throws NotImplementedException always
     */
    @Override // com.quasiris.qsf.commons.ai.embedding.TextEmbeddingEncoder
    public TextVectorDocument embedDoc(Document<List<String>> document, TextNormalizerService textNormalizerService) throws NormalizerNotSupportedException {
        throw new NotImplementedException("This method not supported yet!");
    }

    /**
     * Embeds all fields of several documents. One remote call is made per document.
     *
     * @param list                  documents to embed
     * @param textNormalizerService optional normalizer, may be {@code null}
     * @param z                     {@code true} to embed each sentence of a field separately
     * @return one result document per input document, in input order, carrying the input id
     */
    @Override // com.quasiris.qsf.commons.ai.embedding.TextEmbeddingEncoder
    public List<TextVectorDocument> embedBulk(List<Document<String>> list, TextNormalizerService textNormalizerService, boolean z) {
        List<TextVectorDocument> results = new ArrayList<>();
        for (Document<String> document : list) {
            TextVectorDocument vectorDocument = new TextVectorDocument(document.getId());

            // Snapshot the entries once so names and values are paired by the same ordering.
            List<Map.Entry<String, String>> entries = new ArrayList<>(document.getFields().entrySet());
            List<String> fieldTexts = new ArrayList<>(entries.size());
            for (Map.Entry<String, String> entry : entries) {
                fieldTexts.add(entry.getValue());
            }

            List<TextVectorDocument> embedded = embedTextBulk(fieldTexts, textNormalizerService, z);
            if (entries.size() == embedded.size()) {
                for (int i = 0; i < entries.size(); i++) {
                    vectorDocument.getFields().put(entries.get(i).getKey(), embedded.get(i).getFields().get(BULK_FIELD));
                }
            }
            results.add(vectorDocument);
        }
        return results;
    }

    /**
     * Serializes the request payload.
     *
     * @param list texts to embed
     * @return a UTF-8 encoded JSON entity
     * @throws JsonProcessingException if the payload cannot be serialized
     */
    private HttpEntity buildEntity(List<String> list) throws JsonProcessingException {
        Map<String, Object> payload = new HashMap<>();
        payload.put("id", "");
        payload.put("texts", list);
        payload.put("is_tokenized", false);
        return new StringEntity(this.objectMapper.writeValueAsString(payload), "UTF-8");
    }
}
