import styled from "styled-components";
import { useState } from "react";
import { round } from "lodash";

import ExpandableOnTitleClick from "components/widgets/ExpandableOnTitleClick";
import Modal from "components/ui/Modal";
import AreaChart from "components/charts/AreaChart";
import { lightTheme as theme } from "App";
import { Gap } from "components/Layout";
import TextInput from "components/ui/TextInput";

/**
 * Converts raw curve points into the series format consumed by `AreaChart`.
 *
 * @param {string} metricTitle - metric key; "rocCurve" is plotted as False Positive Rate vs
 *   True Positive Rate, any other key as Recall vs Precision.
 * @param {*} curveData - array of `[x, y]` pairs. Any non-array value (e.g. the "-" placeholder
 *   used for missing metrics, `undefined` or `null`) yields an empty series instead of throwing.
 * @returns {{xAxisKey: string, yAxisKey: string, data: object[]}} axis names and one object per
 *   point keyed by those axis names.
 */
const getPlotAxisKeysAndData = (metricTitle, curveData) => {
  // Missing metrics can arrive as "-", undefined or null; only real arrays can be plotted.
  if (!Array.isArray(curveData)) {
    return { xAxisKey: "-", yAxisKey: "-", data: [] };
  }

  const [xAxisKey, yAxisKey] =
    metricTitle === "rocCurve" ? ["False Positive Rate", "True Positive Rate"] : ["Recall", "Precision"];

  return {
    xAxisKey,
    yAxisKey,
    data: curveData.map(([x, y]) => ({ [xAxisKey]: x, [yAxisKey]: y })),
  };
};

const ColumnDetails = styled.div`
  min-width: calc(50% - 5px);
  background-color: ${props => props.theme.color.closer1};
  margin-bottom: 10px;
  border-radius: 5px;
  padding: 10px;
`;

const ColumnName = styled.div`
  font-size: 18px;
  text-transform: uppercase;
  padding-bottom: 10px;
`;

const ColumnsContainer = styled.div`
  display: flex;
  gap: 10px;
  flex-wrap: wrap;
`;

const PlotContainer = styled.div`
  padding: 10px;
  border-radius: 5px;
  background-color: ${props => props.theme.color.furthest};
  width: 600px;
`;

const ShowCurveLink = styled.a`
  cursor: pointer;
  text-decoration: underline;
  color: ${props => props.theme.color.primary};
  width: max-content;
`;

const CurveMetricTd = ({ title, curveData }) => {
  const [isPlotModalOpen, setIsPlotModalOpen] = useState(false);

  const { xAxisKey, yAxisKey, data } = getPlotAxisKeysAndData(title, curveData);

  if (data.length === 0) {
    return "-";
  }

  return (
    <td>
      <ShowCurveLink onClick={() => setIsPlotModalOpen(true)}>Show</ShowCurveLink>
      <Modal
        title={title === "prCurve" ? "Precision-Recall" : "ROC"}
        open={isPlotModalOpen}
        handleClose={() => setIsPlotModalOpen(false)}
      >
        <PlotContainer>
          <AreaChart
            data={data}
            xAxisKey={xAxisKey}
            yAxisKey={yAxisKey}
            chartSize={{ height: 300 }}
            color={theme.color.primary}
          />
        </PlotContainer>
      </Modal>
    </td>
  );
};

/** Bold first cell of each metric row, showing the metric's label. */
const MetricName = styled.td`
  font-weight: bold;
  padding-bottom: 5px;
`;

/** Bold header cell showing an abbreviated pipeline, evaluation-job or commit id. */
const PipelineId = styled.td`
  font-weight: bold;
  padding-bottom: 10px;
`;

/** Unstyled cell holding one run's value for a non-curve metric. */
const MetricValue = styled.td``;

/**
 * Reads one metric of one column of one task from an evaluation job's results.
 *
 * @param {string} columnName - key into the task evaluation's `columnBasedScores`.
 * @param {*} taskId - id of the task evaluation to read.
 * @param {string} titleOfMetric - metric key within the column's scores.
 * @param {object} [evalJob] - evaluation job; may be missing or may not have results yet.
 * @returns {*} the metric rounded to 3 decimals when it is a number, the raw value otherwise
 *   (e.g. a curve array or a string), or "-" when the task, column or metric is absent.
 */
const getMetricValueForEvalJob = (columnName, taskId, titleOfMetric, evalJob) => {
  const taskEvaluation = evalJob?.result?.taskEvaluations?.find(task => task.taskId === taskId);
  // Compared jobs need not have scored the same columns/metrics, so every level is optional.
  const val = taskEvaluation?.columnBasedScores?.[columnName]?.[titleOfMetric];
  if (val === undefined || val === null) {
    return "-";
  }
  return typeof val === "number" ? round(val, 3) : val;
};

/**
 * Reads one metric for a pipeline by locating the pipeline's evaluation job in `evaluationJobs`.
 *
 * @param {object} pipeline - pipeline whose `evaluationJob.id` identifies its evaluation job.
 * @param {string} columnName - key into the task evaluation's `columnBasedScores`.
 * @param {*} taskId - id of the task evaluation to read.
 * @param {string} titleOfMetric - metric key within the column's scores.
 * @param {object[]} [evaluationJobs] - candidate evaluation jobs; may be undefined.
 * @returns {*} same as `getMetricValueForEvalJob`; "-" when the pipeline has no matching job.
 */
const getMetricValueForPipelineTaskIdAndColumnName = (pipeline, columnName, taskId, titleOfMetric, evaluationJobs) => {
  const evalJobId = pipeline?.evaluationJob?.id;
  if (evalJobId == null) {
    return "-";
  }
  const evalJob = evaluationJobs?.find(job => job.id === evalJobId);
  return getMetricValueForEvalJob(columnName, taskId, titleOfMetric, evalJob);
};

const TextInputContainer = styled.div`
  display: flex;
  align-items: center;
  gap: 10px;
  font-weight: bold;
  padding: 10px;
  width: 200px;
`;

/**
 * Human-readable labels for metric keys. Keys without an entry are displayed as-is.
 */
const metricNamesToMetricLabels = Object.freeze({
  type: "Type",
  totalQueries: "Number of instances",
  averageLogLoss: "Average log loss",
  averageAccuracy: "Average accuracy",
  // Misspelled key from the original table, kept so anything still emitting it keeps its label.
  // NOTE(review): drop once confirmed no producer emits "averageAcccuracy".
  averageAcccuracy: "Average accuracy",
  rocCurve: "ROC curve",
  prCurve: "Precision-Recall curve",
  rocAuc: "ROC AUC",
  averageError: "Average error",
});

/**
 * Shortens a pipeline's id to its first and last three characters (e.g. "abc...xyz") so that
 * several pipelines fit side by side in a table header.
 *
 * @param {{pipelineId: string}} pipeline
 * @returns {string}
 */
const getFormattedPipelineId = ({ pipelineId }) => {
  const head = pipelineId.slice(0, 3);
  const tail = pipelineId.slice(-3);
  return `${head}...${tail}`;
};

const MultiExpandableTaskEvaluation = ({
  taskEvaluation,
  pipelineOutputs,
  commitShas,
  evaluationJobs,
  usingEvalJobIds = false,
}) => {
  const columnNames = Object.keys(taskEvaluation.columnBasedScores);

  const [filterBy, setFilterBy] = useState("");

  const idIndicators = usingEvalJobIds
    ? evaluationJobs?.map(evalJob => (
        <PipelineId key={`${evalJob.id}`}>{`${evalJob.id.slice(0, 3)}...${evalJob.id.slice(-3)}`}</PipelineId>
      ))
    : [
        commitShas
          ? commitShas?.map(sha => <PipelineId key={`${sha}`}>{`${sha.slice(0, 3)}...${sha.slice(-3)}`}</PipelineId>)
          : pipelineOutputs?.map(p => <PipelineId key={`${p.pipelineId}`}>{getFormattedPipelineId(p)}</PipelineId>),
      ];

  return (
    <ExpandableOnTitleClick title={taskEvaluation.taskName} isInitiallyExpanded={false}>
      {columnNames.length > 3 && (
        <TextInputContainer>
          <span>Filter: </span>
          <TextInput onNewInput={newFilterBy => setFilterBy(newFilterBy)} />
        </TextInputContainer>
      )}
      <ColumnsContainer>
        {columnNames
          .filter(colName => colName.toLowerCase().includes(filterBy.toLowerCase()))
          .map(columnName => {
            const columnMetrics = taskEvaluation.columnBasedScores[columnName];
            const metricTitles = Object.keys(columnMetrics);

            return (
              <ColumnDetails key={columnName}>
                <ColumnName>{columnName}</ColumnName>
                <table>
                  <tbody>
                    <tr>
                      <td></td>
                      {idIndicators}
                    </tr>
                    {metricTitles.map(titleOfMetric =>
                      titleOfMetric.includes("Curve") ? (
                        <tr key={titleOfMetric}>
                          <MetricName>{metricNamesToMetricLabels[titleOfMetric] || titleOfMetric}</MetricName>
                          {usingEvalJobIds &&
                            evaluationJobs.map(evalJob => (
                              <CurveMetricTd
                              key={`${titleOfMetric}-`}
                              title={titleOfMetric}
                              curveData={getMetricValueForEvalJob(columnName, taskEvaluation.taskId, titleOfMetric, evalJob)}
                            />))}
                          {pipelineOutputs?.map(pipeline => (
                            <CurveMetricTd
                              key={`${titleOfMetric}-${pipeline.pipelineId}`}
                              title={titleOfMetric}
                              curveData={getMetricValueForPipelineTaskIdAndColumnName(
                                pipeline,
                                columnName,
                                taskEvaluation.taskId,
                                titleOfMetric,
                                evaluationJobs
                              )}
                            />
                          ))}
                        </tr>
                      ) : (
                        <tr key={titleOfMetric}>
                          <MetricName>{metricNamesToMetricLabels[titleOfMetric] || titleOfMetric}</MetricName>
                          {usingEvalJobIds &&
                            evaluationJobs.map(evalJob => (
                              <MetricValue key={evalJob.id}>
                                {getMetricValueForEvalJob(columnName, taskEvaluation.taskId, titleOfMetric, evalJob)}
                              </MetricValue>
                            ))}

                          {pipelineOutputs?.map(pipeline => (
                            <MetricValue key={`${pipeline.pipelineId}`}>
                              {getMetricValueForPipelineTaskIdAndColumnName(
                                pipeline,
                                columnName,
                                taskEvaluation.taskId,
                                titleOfMetric,
                                evaluationJobs
                              )}
                            </MetricValue>
                          ))}
                        </tr>
                      )
                    )}
                  </tbody>
                </table>
              </ColumnDetails>
            );
          })}
        <Gap height="40px" />
      </ColumnsContainer>
    </ExpandableOnTitleClick>
  );
};

export default MultiExpandableTaskEvaluation;
