/*
 * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package org.graalvm.compiler.hotspot;

import static org.graalvm.compiler.hotspot.ProfileReplaySupport.Options.LoadProfiles;
import static org.graalvm.compiler.hotspot.ProfileReplaySupport.Options.ProfileMethodFilter;
import static org.graalvm.compiler.hotspot.ProfileReplaySupport.Options.SaveProfiles;
import static org.graalvm.compiler.hotspot.ProfileReplaySupport.Options.StrictProfiles;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.graalvm.collections.EconomicMap;
import org.graalvm.compiler.code.CompilationResult;
import org.graalvm.compiler.core.common.CompilationIdentifier;
import org.graalvm.compiler.core.common.cfg.BasicBlock;
import org.graalvm.compiler.debug.DebugContext;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.debug.MethodFilter;
import org.graalvm.compiler.debug.PathUtilities;
import org.graalvm.compiler.debug.TTY;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.NodeMap;
import org.graalvm.compiler.hotspot.CompilationTask.HotSpotCompilationWrapper;
import org.graalvm.compiler.java.LambdaUtils;
import org.graalvm.compiler.java.StableMethodNameFormatter;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FrameState;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.java.AccessFieldNode;
import org.graalvm.compiler.nodes.spi.StableProfileProvider;
import org.graalvm.compiler.nodes.spi.StableProfileProvider.LambdaNameFormatter;
import org.graalvm.compiler.nodes.spi.StableProfileProvider.TypeFilter;
import org.graalvm.compiler.options.Option;
import org.graalvm.compiler.options.OptionKey;
import org.graalvm.compiler.options.OptionType;
import org.graalvm.compiler.phases.schedule.SchedulePhase;
import org.graalvm.util.json.JSONParser;

import jdk.vm.ci.hotspot.HotSpotResolvedJavaMethod;
import jdk.vm.ci.meta.ResolvedJavaMethod;

/**
 * Support class encapsulating profile replay support. Contains functionality to save, load and
 * verify loaded profiles.
 */
public final class ProfileReplaySupport {

    public static class Options {
        // @formatter:off
        @Option(help = "Save per compilation profile information.", type = OptionType.User)
        public static final OptionKey<Boolean> SaveProfiles = new OptionKey<>(false);
        @Option(help = "Allow multiple compilations of the same method by overriding existing profiles.", type = OptionType.User)
        public static final OptionKey<Boolean> OverrideProfiles = new OptionKey<>(false);
        @Option(help = "Path for saving compilation profiles. "
                        + "If the value is omitted the debug dump path will be used.", type = OptionType.User)
        public static final OptionKey<String> SaveProfilesPath = new OptionKey<>(null);
        @Option(help = "Load per compilation profile information.", type = OptionType.User)
        public static final OptionKey<String> LoadProfiles = new OptionKey<>(null);
        @Option(help = "Restrict saving or loading of profiles based on this filter. "
                        + "See the MethodFilter option for the pattern syntax.", type = OptionType.User)
        public static final OptionKey<String> ProfileMethodFilter = new OptionKey<>(null);
        @Option(help = "Throw an error if an attempt is made to overwrite/update a profile loaded from disk.", type = OptionType.User)
        public static final OptionKey<Boolean> StrictProfiles = new OptionKey<>(true);
        @Option(help = "Print to stdout when a profile is loaded.", type = OptionType.User)
        public static final OptionKey<Boolean> PrintProfileLoading = new OptionKey<>(false);
        @Option(help = "Print to stdout when a compilation performed with different profiles generates different "
                        + "frontend IR.", type = OptionType.User)
        public static final OptionKey<Boolean> WarnAboutGraphSignatureMismatch = new OptionKey<>(true);
        @Option(help = "Print to stdout when a compilation performed with different profiles generates different "
                        + "backend code.", type = OptionType.User)
        public static final OptionKey<Boolean> WarnAboutCodeSignatureMismatch = new OptionKey<>(true);
        @Option(help = "Print to stdout when requesting profiling info not present in a loaded profile.", type = OptionType.User)
        public static final OptionKey<Boolean> WarnAboutNotCachedLoadedAccess = new OptionKey<>(true);
        // @formatter:on
    }

    /** Formatter handed to the profile provider when loading and recording profiles. */
    private final LambdaNameFormatter lambdaNameFormatter;
    /**
     * Tri-state capturing the expected result of the compilation. Potential values are
     * {@code null,True,False}.
     *
     * If we are running a regular compilation without loading profiles it will always be
     * {@code null}.
     *
     * If we are running a profile replay compilation this field will contain the result of reading
     * the entry of the profile file. The original value written represents the expression
     * {@code originalCodeResult!=null}. This means if the original compile produced a result
     * without errors it will be set to {@code True}, else {@code False}.
     */
    private final Boolean expectedResult;
    /** Value of the {@code codeSignature} entry of the loaded profile, or {@code null}. */
    private final String expectedCodeSignature;
    /** Value of the {@code graphSignature} entry of the loaded profile, or {@code null}. */
    private final String expectedGraphSignature;
    /** Filter built from {@link Options#ProfileMethodFilter}; matches everything if unset. */
    private final MethodFilter profileFilter;
    /** Filter passed through to {@link StableProfileProvider#recordProfiles} when saving. */
    private final TypeFilter profileSaveFilter;

    /**
     * Instances are only created by
     * {@link #profileReplayPrologue(DebugContext, HotSpotGraalRuntimeProvider, int, HotSpotResolvedJavaMethod, StableProfileProvider, TypeFilter)}.
     */
    private ProfileReplaySupport(LambdaNameFormatter lambdaNameFormatter, Boolean expectedResult, String expectedCodeSignature, String expectedGraphSignature, MethodFilter profileFilter,
                    TypeFilter profileSaveFilter) {
        this.lambdaNameFormatter = lambdaNameFormatter;
        this.expectedResult = expectedResult;
        this.expectedCodeSignature = expectedCodeSignature;
        this.expectedGraphSignature = expectedGraphSignature;
        this.profileFilter = profileFilter;
        this.profileSaveFilter = profileSaveFilter;
    }

    /**
     * @return the {@code result} entry read from the loaded profile, or {@code null} if no profile
     *         was loaded
     */
    public Boolean getExpectedResult() {
        return expectedResult;
    }

    /**
     * Prepares profile saving/loading for a compilation of {@code method}.
     *
     * If {@link Options#LoadProfiles} is set and {@code method} matches the profile method filter,
     * the given directory is scanned for {@code .glog} files whose path contains the sanitized
     * method name. The first such file whose {@code entryBCI} entry equals {@code entryBCI} is
     * loaded into {@code profileProvider}; with {@link Options#StrictProfiles} the provider is
     * frozen afterwards.
     *
     * @return a support object to be used for the matching
     *         {@link #profileReplayEpilogue(DebugContext, HotSpotCompilationWrapper, StableProfileProvider, CompilationIdentifier, int, HotSpotResolvedJavaMethod)
     *         epilogue}, or {@code null} if neither saving nor loading of profiles is enabled
     * @throws GraalError if {@link Options#StrictProfiles} is set and either no matching profile
     *             file is found or reading the profile directory/file fails
     */
    public static ProfileReplaySupport profileReplayPrologue(DebugContext debug, HotSpotGraalRuntimeProvider graalRuntime, int entryBCI, HotSpotResolvedJavaMethod method,
                    StableProfileProvider profileProvider, TypeFilter profileSaveFilter) {
        String loadPath = LoadProfiles.getValue(debug.getOptions());
        if (!SaveProfiles.getValue(debug.getOptions()) && loadPath == null) {
            return null;
        }
        LambdaNameFormatter lambdaNameFormatter = new LambdaNameFormatter() {
            private final StableMethodNameFormatter stableFormatter = new StableMethodNameFormatter(graalRuntime.getHostBackend().getProviders(), debug, true);

            @Override
            public boolean isLambda(ResolvedJavaMethod m) {
                // Include method handles here as well
                return LambdaUtils.isLambdaType(m.getDeclaringClass()) || StableMethodNameFormatter.isMethodHandle(m.getDeclaringClass());
            }

            @Override
            public String formatLamdaName(ResolvedJavaMethod m) {
                return stableFormatter.apply(m);
            }
        };
        Boolean expectedResult = null;
        String expectedCodeSignature = null;
        String expectedGraphSignature = null;
        String filterString = ProfileMethodFilter.getValue(debug.getOptions());
        MethodFilter profileFilter = filterString == null || filterString.isEmpty() ? MethodFilter.matchAll() : MethodFilter.parse(filterString);
        if (loadPath != null && profileFilter.matches(method)) {
            Path loadDir = Paths.get(loadPath);
            String s = PathUtilities.sanitizeFileName(method.format("%h.%n(%p)%r"));
            try (Stream<Path> files = Files.list(loadDir)) {
                boolean foundOne = false;
                for (Path path : files.filter(x -> x.toString().contains(s)).filter(x -> x.toString().endsWith(".glog")).collect(Collectors.toList())) {
                    EconomicMap<String, Object> map;
                    // Close the reader deterministically; several candidate files may be opened.
                    try (FileReader reader = new FileReader(path.toFile())) {
                        map = JSONParser.parseDict(reader);
                    }
                    if (entryBCI == (int) map.get("entryBCI")) {
                        foundOne = true;
                        expectedResult = (Boolean) map.get("result");
                        expectedCodeSignature = (String) map.get("codeSignature");
                        expectedGraphSignature = (String) map.get("graphSignature");
                        profileProvider.load(map, method.getDeclaringClass(), Options.WarnAboutNotCachedLoadedAccess.getValue(debug.getOptions()), lambdaNameFormatter);
                        if (StrictProfiles.getValue(debug.getOptions())) {
                            profileProvider.freeze();
                        }
                        if (Options.PrintProfileLoading.getValue(debug.getOptions())) {
                            TTY.println("Loaded profile data from " + path);
                        }
                        break;
                    }
                }
                if (StrictProfiles.getValue(debug.getOptions()) && !foundOne) {
                    throw GraalError.shouldNotReachHere(String.format("No file for method %s found in %s, strict profiles, abort", s, loadDir)); // ExcludeFromJacocoGeneratedReport
                }
            } catch (IOException e) {
                if (StrictProfiles.getValue(debug.getOptions())) {
                    // In strict mode a profile that cannot be read is as fatal as a missing one.
                    throw GraalError.shouldNotReachHere(e, String.format("Failed to load profile for method %s from %s, strict profiles, abort", s, loadDir)); // ExcludeFromJacocoGeneratedReport
                }
                // Best effort in non-strict mode: report and compile without a loaded profile.
                e.printStackTrace();
            }
        }
        return new ProfileReplaySupport(lambdaNameFormatter, expectedResult, expectedCodeSignature, expectedGraphSignature, profileFilter, profileSaveFilter);
    }

    /**
     * Finishes profile handling for a compilation of {@code method}.
     *
     * Does nothing unless saving or loading of profiles is enabled and {@code method} matches the
     * profile method filter. Otherwise the code and graph signatures of the compilation are
     * computed (both stay {@code null} if {@code compilation.result} is {@code null}), compared
     * against the signatures of the loaded profile with a message printed to {@link TTY} on
     * mismatch if the corresponding warning option is set, and, if {@link Options#SaveProfiles} is
     * set, the profile is written as JSON either into {@link Options#SaveProfilesPath} or to the
     * debug dump path.
     *
     * Any failure while saving (including an already existing profile file without
     * {@link Options#OverrideProfiles}) is routed through {@link DebugContext#handle}.
     */
    public void profileReplayEpilogue(DebugContext debug, HotSpotCompilationWrapper compilation, StableProfileProvider profileProvider, CompilationIdentifier compilationId, int entryBCI,
                    HotSpotResolvedJavaMethod method) {
        boolean enabled = SaveProfiles.getValue(debug.getOptions()) || LoadProfiles.getValue(debug.getOptions()) != null;
        if (!enabled || !profileFilter.matches(method)) {
            return;
        }
        String codeSignature = null;
        String graphSignature = null;
        if (compilation.result != null) {
            codeSignature = compilation.result.getCodeSignature();
            assert compilation.graph != null;
            String s = getCanonicalGraphString(compilation.graph);
            graphSignature = CompilationResult.getSignature(s.getBytes(StandardCharsets.UTF_8));
        }
        if (Options.WarnAboutCodeSignatureMismatch.getValue(debug.getOptions())) {
            if (expectedCodeSignature != null && !Objects.equals(codeSignature, expectedCodeSignature)) {
                TTY.printf("%s %s codeSignature differs %s != %s%n", method.format("%H.%n(%P)%R"), entryBCI, codeSignature, expectedCodeSignature);
            }
        }
        if (Options.WarnAboutGraphSignatureMismatch.getValue(debug.getOptions())) {
            if (expectedGraphSignature != null && !Objects.equals(graphSignature, expectedGraphSignature)) {
                TTY.printf("%s %s graphSignature differs %s != %s%n", method.format("%H.%n(%P)%R"), entryBCI, graphSignature, expectedGraphSignature);
            }
        }
        if (SaveProfiles.getValue(debug.getOptions())) {
            try {
                EconomicMap<String, Object> map = EconomicMap.create();
                map.put("identifier", compilationId.toString());
                map.put("method", method.format("%H.%n(%P)%R"));
                map.put("entryBCI", entryBCI);
                map.put("codeSignature", codeSignature);
                map.put("graphSignature", graphSignature);
                map.put("result", compilation.result != null);
                profileProvider.recordProfiles(map, profileSaveFilter, lambdaNameFormatter);
                String path;
                String dirName = Options.SaveProfilesPath.getValue(debug.getOptions());
                if (dirName != null) {
                    String fileName = PathUtilities.sanitizeFileName(method.format("%h.%n(%p)%r") + ".glog");
                    path = Paths.get(dirName).resolve(fileName).toString();
                    if (new File(path).exists() && !Options.OverrideProfiles.getValue(debug.getOptions())) {
                        throw new InternalError("Profile file for path " + path + " exists already");
                    }
                } else {
                    path = debug.getDumpPath(".glog", false, false);
                }
                try (PrintStream out = new PrintStream(new BufferedOutputStream(PathUtilities.openOutputStream(path)))) {
                    out.println(org.graalvm.util.json.JSONFormatter.formatJSON(map, true));
                    // PrintStream swallows I/O errors; surface them so a truncated profile is noticed.
                    if (out.checkError()) {
                        throw new IOException("Failed to write profile file " + path);
                    }
                }
            } catch (Throwable t) {
                throw debug.handle(t);
            }
        }
    }

    /**
     * Builds a textual representation of {@code graph} that is used as input for the graph
     * signature. The graph is scheduled with the {@code EARLIEST} strategy and then, block by
     * block, one line is emitted per block (with its successors, the start block marked with
     * {@code *}) followed by one line per alive, non-constant {@link ValueNode} in that block. Node
     * lines consist of an id assigned in emission order, the simple class name, the accessed field
     * for {@link AccessFieldNode}s and the number of usages that are not {@link FrameState}s.
     *
     * Lines are terminated with the platform line separator.
     */
    private static String getCanonicalGraphString(StructuredGraph graph) {
        SchedulePhase.runWithoutContextOptimizations(graph, SchedulePhase.SchedulingStrategy.EARLIEST);
        StructuredGraph.ScheduleResult scheduleResult = graph.getLastSchedule();
        NodeMap<Integer> canonicalId = graph.createNodeMap();
        // Same string that the "%n" format specifier produces, without running a formatter per line.
        final String newLine = System.lineSeparator();
        int nextId = 0;
        StringBuilder result = new StringBuilder();
        for (BasicBlock<?> block : scheduleResult.getCFG().getBlocks()) {
            result.append("Block ").append(block).append(' ');
            if (block == scheduleResult.getCFG().getStartBlock()) {
                result.append("* ");
            }
            result.append("-> ");
            for (int i = 0; i < block.getSuccessorCount(); i++) {
                BasicBlock<?> succ = block.getSuccessorAt(i);
                result.append(succ).append(' ');
            }
            result.append(newLine);
            for (Node node : scheduleResult.getBlockToNodesMap().get(block)) {
                if (!(node instanceof ValueNode) || !node.isAlive() || node instanceof ConstantNode) {
                    continue;
                }
                Integer knownId = canonicalId.get(node);
                int id;
                if (knownId != null) {
                    id = knownId;
                } else {
                    id = nextId++;
                    canonicalId.set(node, id);
                }
                String name = node.getClass().getSimpleName();
                result.append("  ").append(id).append('|').append(name);
                if (node instanceof AccessFieldNode) {
                    result.append('#');
                    result.append(((AccessFieldNode) node).field());
                }
                result.append("    (");
                result.append(node.usages().filter(n -> !(n instanceof FrameState)).count());
                result.append(')');
                result.append(newLine);
            }
        }
        return result.toString();
    }

}
