import styled from "styled-components";
import { useState } from "react";
import { round } from "lodash";

import ExpandableOnTitleClick from "components/widgets/ExpandableOnTitleClick";
import Modal from "components/ui/Modal";
import CurveAreaChart from "components/charts/CurveAreaChart";
import { lightTheme as theme } from "App";
import TextInput from "components/ui/TextInput";
import { Gap } from "components/Layout";

/**
 * Turns raw curve points into chart records and picks the axis labels.
 *
 * "rocCurve" is plotted as False/True Positive Rate; every other curve metric
 * is plotted as Recall/Precision.
 *
 * @param {string} metricTitle - Metric key, e.g. "rocCurve" or "prCurve".
 * @param {Array<Array<number>>|null|undefined} curveData - Points as [x, y] pairs.
 * @returns {{xAxisKey: string, yAxisKey: string, data: Array<object>}}
 *   Axis keys plus one `{ [xAxisKey]: x, [yAxisKey]: y }` record per point.
 */
const getPlotAxisKeysAndData = (metricTitle, curveData) => {
  const [xAxisKey, yAxisKey] =
    metricTitle === "rocCurve" ? ["False Positive Rate", "True Positive Rate"] : ["Recall", "Precision"];

  // Guard against a missing / non-array curve value: plot nothing rather than
  // throwing inside render and taking down the whole evaluation view.
  const points = Array.isArray(curveData) ? curveData : [];

  return {
    xAxisKey,
    yAxisKey,
    data: points.map(([x, y]) => ({ [xAxisKey]: x, [yAxisKey]: y })),
  };
};

/**
 * Tint color for a column card, derived from its score.
 *
 * Scores below 0.5 tint red, scores at or above 0.5 tint green. Opacity grows
 * with the distance from 0.5 (scaled by 0.7) and is capped at 1.
 *
 * @param {{colAcc?: number}} props - styled-components props; `colAcc` is the column score.
 * @returns {string} A CSS `rgba(...)` color. Fully transparent when `colAcc` is not a
 *   finite number, instead of emitting an invalid `NaN` alpha.
 */
const getAccCol = ({ colAcc }) => {
  if (!Number.isFinite(colAcc)) {
    return "rgba(0, 0, 0, 0)";
  }
  // Scores like a very negative r2 would otherwise yield alpha values far above 1.
  const alpha = Math.min(Math.abs(colAcc - 0.5) * 0.7, 1);
  return colAcc < 0.5 ? `rgba(255, 0, 0, ${alpha})` : `rgba(0, 255, 0, ${alpha})`;
};

// Reads the theme's closer1 color; reused by every closer1 slot in ColumnDetails.
const closer1Color = ({ theme }) => theme.color.closer1;

/**
 * Card holding one column's metrics.
 *
 * The top background layer is a flat tint from getAccCol (driven by the `colAcc`
 * prop). It is painted over a flat layer of the theme's closer1 color.
 *
 * NOTE(review): with border-box sizing, three cards of calc(33% - 5px) plus the two
 * 10px gaps of ColumnsContainer only fit when the container is at least 500px wide;
 * calc((100% - 20px) / 3) would be exact — confirm intended layout.
 * NOTE(review): `colAcc` is not a DOM attribute; depending on the styled-components
 * version it may be forwarded to the <div> (React unknown-prop warning). A transient
 * `$colAcc` prop would avoid that — verify against the installed version.
 */
const ColumnDetails = styled.div`
  background: linear-gradient(${getAccCol}, ${getAccCol}),
    linear-gradient(${closer1Color}, ${closer1Color});
  width: calc(33% - 5px);
  background-color: ${closer1Color};
  margin-bottom: 10px;
  border-radius: 5px;
  padding: 10px;
`;

/** Heading above a column's metrics table: 18px text, uppercased via CSS. */
const ColumnName = styled.div`
  font-size: 18px;
  text-transform: uppercase;
  padding-bottom: 10px;
`;

/** Wrapping flex row that lays out the per-column ColumnDetails cards with 10px gaps. */
const ColumnsContainer = styled.div`
  display: flex;
  gap: 10px;
  flex-wrap: wrap;
`;

/**
 * Fixed-width (600px) padded, rounded backdrop for the curve chart shown in the
 * modal, filled with the theme's "furthest" color.
 */
const PlotContainer = styled.div`
  padding: 10px;
  border-radius: 5px;
  background-color: ${({ theme: { color } }) => color.furthest};
  width: 600px;
`;

/**
 * Table cell that is styled like a link (pointer cursor, underline, theme primary color).
 * It is used as the clickable "Show" trigger for curve plots.
 */
const ShowCurveLink = styled.td`
  cursor: pointer;
  text-decoration: underline;
  color: ${({ theme }) => theme.color.primary};
  width: max-content;
`;

const CurveMetricRow = ({ title, curveData }) => {
  const [isPlotModalOpen, setIsPlotModalOpen] = useState(false);

  const { xAxisKey, yAxisKey, data } = getPlotAxisKeysAndData(title, curveData);

  return (
    <tr>
      <MetricName>{metricNamesToMetricLabels[title] || title}</MetricName>
      <ShowCurveLink onClick={() => setIsPlotModalOpen(true)}>Show</ShowCurveLink>
      <Modal
        title={title === "prCurve" ? "Precision-Recall" : "ROC"}
        open={isPlotModalOpen}
        handleClose={() => setIsPlotModalOpen(false)}
      >
        <PlotContainer>
          <CurveAreaChart
            data={data}
            xAxisKey={xAxisKey}
            yAxisKey={yAxisKey}
            chartSize={{ height: 300 }}
            color={theme.color.primary}
          />
        </PlotContainer>
      </Modal>
    </tr>
  );
};

/** Bold label cell on the left of a metric row. */
const MetricName = styled.td`
  font-weight: bold;
  padding-bottom: 5px;
`;

/** Value cell of a metric row; adds no styling of its own. */
const MetricValue = styled.td``;

const TextInputContainer = styled.div`
  display: flex;
  align-items: center;
  gap: 10px;
  font-weight: bold;
  padding: 10px;
  width: 200px;
`;

/**
 * Display labels for metric keys in a column's scores; keys without an entry
 * are shown verbatim.
 *
 * NOTE: "averageAcccuracy" (three c's) matches the key this file reads from
 * `columnBasedScores`; do not correct the spelling here without changing the producer.
 */
const metricNamesToMetricLabels = Object.freeze({
  type: "Type",
  averageAcccuracy: "Average accuracy",
  r2Score: "R² score",
  rocCurve: "ROC curve",
  prCurve: "Precision-Recall curve",
  rocAuc: "ROC AUC",
  averageError: "Average error",
});

/** Metric keys that are never rendered in a column's metrics table. */
const metricsToIgnore = Object.freeze(["averageLogLoss", "totalQueries"]);

const ExpandableTaskEvaluation = ({ taskEvaluation }) => {
  const columnNames = Object.keys(taskEvaluation.columnBasedScores);

  const [filterBy, setFilterBy] = useState("");

  return (
    <ExpandableOnTitleClick title={taskEvaluation.taskName} isInitiallyExpanded={false}>
      {columnNames.length > 3 && (
        <TextInputContainer>
          <span>Filter: </span>
          <TextInput value={filterBy} onNewInput={newFilterBy => setFilterBy(newFilterBy)} />
        </TextInputContainer>
      )}
      <ColumnsContainer>
        {columnNames
          .filter(colName => colName.toLowerCase().includes(filterBy.toLowerCase()))
          .sort((colName1, colName2) => {
            const colAcc1 =
              taskEvaluation.columnBasedScores[colName1]?.averageAcccuracy ||
              taskEvaluation.columnBasedScores[colName1]?.r2Score;
            const colAcc2 =
              taskEvaluation.columnBasedScores[colName2]?.averageAcccuracy ||
              taskEvaluation.columnBasedScores[colName2]?.r2Score;
            return colAcc2 - colAcc1;
          })
          .map(columnName => {
            const columnMetrics = taskEvaluation.columnBasedScores[columnName];
            const metricTitles = Object.keys(columnMetrics);

            const colAcc = columnMetrics?.averageAcccuracy || columnMetrics?.r2Score;
            return (
              <ColumnDetails colAcc={colAcc} key={columnName}>
                <ColumnName>{columnName}</ColumnName>
                <table>
                  <tbody>
                    {metricTitles
                      .filter(titleOfMetric => !metricsToIgnore.includes(titleOfMetric))
                      .map(titleOfMetric =>
                        titleOfMetric.includes("Curve") ? (
                          <CurveMetricRow
                            key={titleOfMetric}
                            title={titleOfMetric}
                            curveData={columnMetrics[titleOfMetric]}
                          />
                        ) : (
                          <tr key={titleOfMetric}>
                            <MetricName>{metricNamesToMetricLabels[titleOfMetric] || titleOfMetric}</MetricName>
                            <MetricValue>
                              {typeof columnMetrics[titleOfMetric] === "number"
                                ? round(columnMetrics[titleOfMetric], 3)
                                : columnMetrics[titleOfMetric]}
                            </MetricValue>
                          </tr>
                        )
                      )}
                  </tbody>
                </table>
              </ColumnDetails>
            );
          })}
      </ColumnsContainer>
      <Gap height="40px" />
    </ExpandableOnTitleClick>
  );
};

export default ExpandableTaskEvaluation;
