package org.languagetool.rules;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Streams;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import javax.net.ssl.SSLException;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.Nullable;
import org.languagetool.AnalyzedSentence;
import org.languagetool.Language;
import org.languagetool.languagemodel.bert.RemoteLanguageModel;
import org.languagetool.rules.RemoteRule;
import org.languagetool.rules.SuggestedReplacement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Wraps another {@link Rule} and re-orders the suggestions of its matches using scores returned by
 * a remote BERT language model ({@link RemoteLanguageModel}).
 *
 * <p>{@link #prepareRequest} runs the wrapped rule locally and builds one scoring request per match
 * that has more than one suggestion. {@link #executeRequest} sends those requests to the model in a
 * single batch and sorts each match's suggestions in this order:
 * <ol>
 *   <li>a suggestion equal (ignoring case) to the matched text,</li>
 *   <li>curated suggestions,</li>
 *   <li>everything else, by descending model score.</li>
 * </ol>
 * When the model could not be created, the matches are passed through via {@link #fallbackResults}
 * without re-ordering.
 */
public class BERTSuggestionRanking extends RemoteRule {
    public static final String RULE_ID = "BERT_SUGGESTION_RANKING";
    private static final Logger logger = LoggerFactory.getLogger(BERTSuggestionRanking.class);

    /** Suggestions kept per match when none of them is a translation. */
    private static final int DEFAULT_SUGGESTION_LIMIT = 10;
    /** Suggestions kept per match when at least one of them is a translation. */
    private static final int TRANSLATION_SUGGESTION_LIMIT = 25;
    /** A sentence or request with this many words or fewer is never treated as having too many matches. */
    private static final int MIN_WORDS = 8;
    /** Matches-per-word ratio above which lazily computed suggestions are discarded. */
    private static final double MAX_ERROR_RATE = 0.5d;

    /**
     * Models shared by all instances, one per service configuration.
     *
     * <p>Each model is created on first use from the config's URL and port and from the options
     * "secure" (default "false"), "clientKey", "clientCertificate" and "rootCertificate". An
     * {@link SSLException} during construction is rethrown unchecked, so that cache load fails.
     */
    private static final LoadingCache<RemoteRuleConfig, RemoteLanguageModel> models = CacheBuilder.newBuilder().build(CacheLoader.from(config -> {
        try {
            return new RemoteLanguageModel(
                config.getUrl(),
                config.getPort(),
                Boolean.parseBoolean(config.getOptions().getOrDefault("secure", "false")),
                config.getOptions().get("clientKey"),
                config.getOptions().get("clientCertificate"),
                config.getOptions().get("rootCertificate"));
        } catch (SSLException e) {
            throw new RuntimeException(e);
        }
    }));

    /**
     * Maximum number of suggestions kept per match.
     *
     * <p>Overwritten by {@link #prepareSuggestions} on every call.
     */
    protected int suggestionLimit;
    /** Shared model for this rule's configuration, or {@code null} if it could not be created. */
    private final RemoteLanguageModel model;
    /** The rule whose matches are re-ranked. */
    private final Rule wrappedRule;

    /**
     * Orders (suggestion, score) pairs for display.
     *
     * <p>The order is: suggestions equal (ignoring case) to the user's text first, then curated
     * suggestions, then higher scores before lower ones. The ordering is consistent, so
     * {@code compare(a, a) == 0} and {@code compare(a, b) == -compare(b, a)}.
     */
    private static class CuratedAndSameCaseComparator implements Comparator<Pair<SuggestedReplacement, Double>> {
        private final String userWord;

        CuratedAndSameCaseComparator(String userWord) {
            this.userWord = userWord;
        }

        @Override
        public int compare(Pair<SuggestedReplacement, Double> first, Pair<SuggestedReplacement, Double> second) {
            // Compare the flags as a pair instead of returning early on the first element alone:
            // when both (or the same element twice) match the user's word, fall through to the
            // next criterion so the Comparator contract holds.
            boolean firstIsUserWord = isUserWord(first);
            boolean secondIsUserWord = isUserWord(second);
            if (firstIsUserWord != secondIsUserWord) {
                return firstIsUserWord ? -1 : 1;
            }
            boolean firstIsCurated = isCurated(first);
            boolean secondIsCurated = isCurated(second);
            if (firstIsCurated != secondIsCurated) {
                return firstIsCurated ? -1 : 1;
            }
            // Descending by score.
            return second.getRight().compareTo(first.getRight());
        }

        /** Whether the pair's replacement equals the user's text, ignoring case. */
        private boolean isUserWord(Pair<SuggestedReplacement, Double> pair) {
            return pair.getLeft().getReplacement().equalsIgnoreCase(this.userWord);
        }

        /** Whether the pair's replacement is of type {@code Curated}. */
        private static boolean isCurated(Pair<SuggestedReplacement, Double> pair) {
            return pair.getLeft().getType() == SuggestedReplacement.SuggestionType.Curated;
        }
    }

    /**
     * Request state carried from {@link #prepareRequest} to {@link #executeRequest}.
     *
     * <p>{@code requests} is either empty or parallel to {@code matches}: entry {@code i} is the
     * scoring request for match {@code i}, or {@code null} if that match needs no re-ranking.
     */
    public class MatchesForReordering extends RemoteRule.RemoteRequest {
        final List<AnalyzedSentence> sentences;
        final List<RuleMatch> matches;
        final List<RemoteLanguageModel.Request> requests;

        MatchesForReordering(List<AnalyzedSentence> sentences, List<RuleMatch> matches, List<RemoteLanguageModel.Request> requests) {
            this.sentences = sentences;
            this.matches = matches;
            this.requests = requests;
        }
    }

    /**
     * Creates a re-ranking wrapper around {@code rule}.
     *
     * @param language language passed to {@link RemoteRule}
     * @param rule the rule whose suggestions are re-ranked; its messages, id and category are reused
     * @param remoteRuleConfig configuration of the remote BERT service
     * @param inputLogging forwarded unchanged to the {@link RemoteRule} constructor
     *     (presumably controls input logging; confirm against RemoteRule)
     */
    public BERTSuggestionRanking(Language language, Rule rule, RemoteRuleConfig remoteRuleConfig, boolean inputLogging) {
        super(language, rule.messages, remoteRuleConfig, inputLogging, rule.getId());
        this.suggestionLimit = DEFAULT_SUGGESTION_LIMIT;
        this.wrappedRule = rule;
        super.setCategory(rule.getCategory());
        this.model = loadModel(this.serviceConfiguration);
    }

    /**
     * Returns the shared model for {@code config}, or {@code null} if it could not be created.
     *
     * <p>A failure is logged rather than thrown, so that the rule falls back to unranked
     * suggestions. {@link LoadingCache#get} is thread-safe, so no extra locking is needed.
     */
    @Nullable
    private static RemoteLanguageModel loadModel(RemoteRuleConfig config) {
        try {
            return models.get(config);
        } catch (Exception e) {
            logger.error("Could not connect to BERT service at {} for suggestion reranking", config, e);
            return null;
        }
    }

    /**
     * Truncates {@code suggestions} to the suggestion limit.
     *
     * <p>The limit is {@value #TRANSLATION_SUGGESTION_LIMIT} if any suggestion is a translation,
     * otherwise {@value #DEFAULT_SUGGESTION_LIMIT}. The chosen limit is also stored in
     * {@link #suggestionLimit}.
     *
     * @param suggestions the suggestions of one match
     * @return a prefix view of {@code suggestions} (backed by the input list)
     */
    protected List<SuggestedReplacement> prepareSuggestions(List<SuggestedReplacement> suggestions) {
        boolean hasTranslation = suggestions.stream()
            .anyMatch(suggestion -> suggestion.getType() == SuggestedReplacement.SuggestionType.Translation);
        // Use a local so a concurrent call cannot change the limit between assignment and use.
        int limit = hasTranslation ? TRANSLATION_SUGGESTION_LIMIT : DEFAULT_SUGGESTION_LIMIT;
        this.suggestionLimit = limit;
        return suggestions.subList(0, Math.min(suggestions.size(), limit));
    }

    /**
     * Runs the wrapped rule on {@code sentences} and builds scoring requests for its matches.
     *
     * <p>If a single sentence, or all sentences together, have more matches per word than
     * {@link #MAX_ERROR_RATE} (and more than {@link #MIN_WORDS} words), lazily computed suggestions
     * of the affected matches are discarded. In the whole-request case, no scoring requests are
     * built at all.
     *
     * @param sentences the sentences to check
     * @param textSessionId ignored by this rule
     * @return a {@link MatchesForReordering}; on an {@link IOException} from the wrapped rule, one
     *     with no matches and no requests
     */
    @Override
    protected RemoteRule.RemoteRequest prepareRequest(List<AnalyzedSentence> sentences, Long textSessionId) {
        List<RuleMatch> matches = new ArrayList<>();
        int totalWords = 0;
        try {
            for (AnalyzedSentence sentence : sentences) {
                RuleMatch[] sentenceMatches = this.wrappedRule.match(sentence);
                Collections.addAll(matches, sentenceMatches);
                int words = sentence.getTokensWithoutWhitespace().length;
                totalWords += words;
                if (exceedsErrorRate(sentenceMatches.length, words)) {
                    for (RuleMatch match : sentenceMatches) {
                        match.discardLazySuggestedReplacements();
                    }
                    logger.info("Skipping suggestion generation for sentence, too many matches ({} matches in {} words)", sentenceMatches.length, words);
                }
            }
            if (exceedsErrorRate(matches.size(), totalWords)) {
                logger.info("Skipping suggestion generation for request, too many matches ({} matches in {} words)", matches.size(), totalWords);
                for (RuleMatch match : matches) {
                    match.discardLazySuggestedReplacements();
                }
                return new MatchesForReordering(sentences, matches, Collections.emptyList());
            }
            List<RemoteLanguageModel.Request> requests = new ArrayList<>(matches.size());
            for (RuleMatch match : matches) {
                match.setSuggestedReplacementObjects(prepareSuggestions(match.getSuggestedReplacementObjects()));
                requests.add(buildRequest(match));
            }
            return new MatchesForReordering(sentences, matches, requests);
        } catch (IOException e) {
            logger.error("Error while executing rule " + this.wrappedRule.getId(), e);
            return new MatchesForReordering(sentences, Collections.emptyList(), Collections.emptyList());
        }
    }

    /**
     * Returns whether {@code matchCount} matches in {@code wordCount} words exceed the error rate.
     *
     * <p>This is true only when there are more than {@link #MIN_WORDS} words and the ratio is
     * above {@link #MAX_ERROR_RATE}.
     */
    private static boolean exceedsErrorRate(int matchCount, int wordCount) {
        // Floating-point division: integer division would truncate the ratio to 0 for any
        // matchCount < wordCount and make the rate check fire only at a ratio of at least 1.
        return wordCount > MIN_WORDS && (double) matchCount / wordCount > MAX_ERROR_RATE;
    }

    /** Returns the wrapped rule's matches unchanged, without any re-ranking. */
    @Override
    protected RemoteRuleResult fallbackResults(RemoteRule.RemoteRequest remoteRequest) {
        MatchesForReordering reordering = (MatchesForReordering) remoteRequest;
        return new RemoteRuleResult(false, false, reordering.matches, reordering.sentences);
    }

    /**
     * Returns a task that scores all non-null requests in one batch and re-sorts each affected
     * match's suggestions with {@link CuratedAndSameCaseComparator}.
     *
     * <p>Suggestions and scores are paired positionally; if the lists differ in length, the
     * surplus of the longer one is dropped.
     *
     * @param remoteRequest a {@link MatchesForReordering} produced by {@link #prepareRequest}
     * @param timeout forwarded unchanged to {@link RemoteLanguageModel#batchScore}
     */
    @Override
    protected Callable<RemoteRuleResult> executeRequest(RemoteRule.RemoteRequest remoteRequest, long timeout) throws TimeoutException {
        return () -> {
            if (this.model == null) {
                return fallbackResults(remoteRequest);
            }
            MatchesForReordering reordering = (MatchesForReordering) remoteRequest;
            List<RuleMatch> matches = reordering.matches;

            // Only non-null requests are sent to the model; remember which match each one belongs to.
            List<Integer> matchIndices = new ArrayList<>();
            List<RemoteLanguageModel.Request> requests = new ArrayList<>();
            int index = 0;
            for (RemoteLanguageModel.Request request : reordering.requests) {
                if (request != null) {
                    matchIndices.add(index);
                    requests.add(request);
                }
                index++;
            }
            if (requests.isEmpty()) {
                return new RemoteRuleResult(false, true, matches, reordering.sentences);
            }

            List<List<Double>> scores = this.model.batchScore(requests, timeout);
            for (int i = 0; i < requests.size(); i++) {
                RemoteLanguageModel.Request request = requests.get(i);
                String userWord = request.text.substring(request.start, request.end);
                RuleMatch match = matches.get(matchIndices.get(i));
                List<SuggestedReplacement> ranked = Streams.zip(
                        match.getSuggestedReplacementObjects().stream(),
                        scores.get(i).stream(),
                        (suggestion, score) -> Pair.of(suggestion, score))
                    .sorted(new CuratedAndSameCaseComparator(userWord))
                    .map(Pair::getLeft)
                    .collect(Collectors.toList());
                match.setSuggestedReplacementObjects(ranked);
            }
            return new RemoteRuleResult(true, true, matches, reordering.sentences);
        };
    }

    /**
     * Builds a scoring request for {@code ruleMatch}.
     *
     * <p>The request covers the match's sentence text, its from/to positions and its suggestion
     * strings.
     *
     * @return the request, or {@code null} if the match has fewer than two suggestions (nothing
     *     to rank)
     */
    @Nullable
    private RemoteLanguageModel.Request buildRequest(RuleMatch ruleMatch) {
        List<String> suggestions = ruleMatch.getSuggestedReplacements();
        if (suggestions == null || suggestions.size() <= 1) {
            return null;
        }
        return new RemoteLanguageModel.Request(ruleMatch.getSentence().getText(), ruleMatch.getFromPos(), ruleMatch.getToPos(), suggestions);
    }

    /** Returns the wrapped rule's id, so this wrapper reports matches under that id. */
    @Override
    public String getId() {
        return this.wrappedRule.getId();
    }

    /** Returns the wrapped rule's description. */
    @Override
    public String getDescription() {
        return this.wrappedRule.getDescription();
    }

    // Shut down every model created so far when RemoteRule's shutdown routines run.
    static {
        shutdownRoutines.add(() -> models.asMap().values().forEach(languageModel -> languageModel.shutdown()));
    }
}
