import dagre from "dagre";

/**
 * Report whether `nodeId` is a sink: no edge in `edges` uses it as its `source`.
 *
 * @param {*} nodeId - Id to look up; compared to each `edge.source` with SameValueZero.
 * @param {Array<{source: *}>} edges - Edges to scan.
 * @returns {boolean} `true` when no edge starts at `nodeId`.
 */
const isOutputNode = (nodeId, edges) => {
  const sourceIds = new Set(edges.map(({ source }) => source));
  return !sourceIds.has(nodeId);
};

/**
 * Run a left-to-right dagre layout and return copies of `nodes` with a `position` attached.
 *
 * Every node is registered with a fixed 120x80 box. Each edge is added as `source -> target`.
 * NOTE(review): the x/y reported by dagre are copied through unshifted. dagre documents them
 * as the node centre, while React Flow positions are top-left. Because every box is the same
 * size this is only a uniform offset; confirm that is intended.
 *
 * @param {Array<{id: string}>} nodes - Nodes to lay out; they are not mutated.
 * @param {Array<{source: string, target: string}>} edges - Connections between node ids.
 * @returns {Array<object>} Shallow copies of `nodes`, each with `position: { x, y }`.
 */
const getDagreLayoutedNodes = (nodes, edges) => {
  const graph = new dagre.graphlib.Graph();
  graph.setGraph({ rankdir: "LR" });
  graph.setDefaultEdgeLabel(() => ({}));

  for (const { id } of nodes) {
    // A fresh size object per node: dagre writes the computed coordinates onto it.
    graph.setNode(id, { width: 120, height: 80 });
  }
  for (const { source, target } of edges) {
    graph.setEdge(source, target);
  }

  dagre.layout(graph);

  return nodes.map(node => {
    const { x, y } = graph.node(node.id);
    return { ...node, position: { x, y } };
  });
};

/**
 * Look up the dims shown for a node.
 *
 * Uses the first `layerConfig` entry whose `resIndices` contains `nodeIndex`. It returns that
 * entry's `inputDims` when the node is the entry's first index, and `outputDims` otherwise.
 *
 * @param {number} nodeIndex - Node index to look up.
 * @param {Array<{resIndices: number[], inputDims: *, outputDims: *}>|undefined} layerConfig
 * @returns {*} The matching dims, or `undefined` when nothing matches or `layerConfig` is nullish.
 */
const getDimsForNodeIndex = (nodeIndex, layerConfig) => {
  const matchingEdge = layerConfig?.find(edge => edge.resIndices.includes(nodeIndex));
  if (matchingEdge?.resIndices?.indexOf(nodeIndex) === 0) {
    return matchingEdge.inputDims;
  }
  return matchingEdge?.outputDims;
};

/**
 * Compute the smallest and largest dim across every entry's `inputDims` and `outputDims`.
 *
 * Both extrema are seeded with 0, so `minModelDim <= 0` and `maxModelDim >= 0` always hold.
 * An empty `layerConfig` yields `{ minModelDim: 0, maxModelDim: 0 }`.
 *
 * @param {Array<{inputDims: number[], outputDims: number[]}>} layerConfig
 * @returns {{minModelDim: number, maxModelDim: number}}
 */
const getMinAndMaxDimInModel = layerConfig =>
  layerConfig.reduce(
    ({ minModelDim, maxModelDim }, { inputDims, outputDims }) => {
      const edgeDims = [...inputDims, ...outputDims];
      return {
        minModelDim: Math.min(minModelDim, ...edgeDims),
        maxModelDim: Math.max(maxModelDim, ...edgeDims),
      };
    },
    { minModelDim: 0, maxModelDim: 0 }
  );

/**
 * Produce the display label for a node.
 *
 * - Node 0 is labelled "Data".
 * - A sink node (no outgoing edge in `edges`) is labelled with the `key` of
 *   `featureTypeDescriptors[|nodeIndex| - 1]`.
 * - Any other node is labelled "<type> layer", using `data.type` of the first edge that targets
 *   it. If no edge targets it, the result is the literal string "undefined layer".
 *
 * @param {number} nodeIndex - Node index; compared to edge ids via `String(nodeIndex)`.
 * @param {Array<{source: string, target: string, data?: {type?: *}}>} edges
 * @param {Array<{key: *}>|undefined} featureTypeDescriptors
 * @returns {string|*} The label.
 */
const getLabelForNodeIndex = (nodeIndex, edges, featureTypeDescriptors) => {
  if (nodeIndex === 0) {
    return "Data";
  }

  const nodeId = String(nodeIndex);
  if (isOutputNode(nodeId, edges)) {
    return featureTypeDescriptors?.[Math.abs(nodeIndex) - 1]?.key;
  }

  const layerType = edges.find(({ target }) => target === nodeId)?.data?.type;
  return `${layerType} layer`;
};

/**
 * Report whether `nodeId` has no children, i.e. no `layerConfig` entry starts at it.
 *
 * @param {number} nodeId - Node index compared strictly to each `resIndices[0]`.
 * @param {Array<{resIndices?: number[]}>} layerConfig
 * @returns {boolean} `true` when no entry has `resIndices[0] === nodeId`.
 */
const isLeaf = (nodeId, layerConfig) =>
  layerConfig.filter(edge => edge.resIndices?.[0] === nodeId).length === 0;

/**
 * Map a node id to its column name: the `key` of `featureTypeDescriptors[|nodeId| - 1]`.
 *
 * @param {number} nodeId - Node index; its absolute value is used, so -1 and 1 share a column.
 * @param {Array<{key: *}>|null|undefined} featureTypeDescriptors
 * @returns {*} The descriptor's `key`, or `undefined` if there is no descriptor.
 */
const getColumnNameOfNode = (nodeId, featureTypeDescriptors) => {
  if (featureTypeDescriptors == null) {
    return undefined;
  }
  const descriptor = featureTypeDescriptors[Math.abs(nodeId) - 1];
  return descriptor?.key;
};

/**
 * List the child ids of `nodeId`: the `resIndices[1]` of every entry whose `resIndices[0]`
 * is `nodeId`, in `layerConfig` order.
 *
 * @param {number} nodeId
 * @param {Array<{resIndices?: number[]}>} layerConfig
 * @returns {Array<number|undefined>} Child ids; empty when `nodeId` has none.
 */
const getChildrenIds = (nodeId, layerConfig) =>
  layerConfig.flatMap(edge => (edge.resIndices?.[0] === nodeId ? [edge.resIndices?.[1]] : []));

/**
 * Classify every node reachable from node 0 as "input", "output", "both", or `null`.
 *
 * Leaves are classified by whether their column name (see `getColumnNameOfNode`) appears in
 * `inputCols`, `outputCols`, or both. An inner node is classified from its children:
 * - "both" if any child is "both", or if the children include both an "input" and an "output";
 * - otherwise "output" if every child is "output";
 * - otherwise "input" if every child is "input";
 * - otherwise `null`.
 * NOTE(review): the recursion has no visited set; it assumes `layerConfig` is acyclic.
 *
 * @param {Array<{resIndices?: number[]}>} layerConfig
 * @param {Array<*>} inputCols - Column names selected as inputs.
 * @param {Array<*>} outputCols - Column names selected as outputs.
 * @param {Array<{key: *}>|undefined} featureTypeDescriptors
 * @returns {Object<string, ("input"|"output"|"both"|null)>} Node id to type.
 */
const getNodeIdToTypeMap = (layerConfig, inputCols, outputCols, featureTypeDescriptors) => {
  const typeById = {};

  // A leaf's type comes straight from the column selection.
  const classifyLeaf = nodeId => {
    const columnName = getColumnNameOfNode(nodeId, featureTypeDescriptors);
    const selectedAsInput = inputCols.includes(columnName);
    const selectedAsOutput = outputCols.includes(columnName);
    if (selectedAsInput && selectedAsOutput) {
      return "both";
    }
    if (selectedAsOutput) {
      return "output";
    }
    if (selectedAsInput) {
      return "input";
    }
    return null;
  };

  // An inner node's type is derived from its already-classified children.
  const classifyParent = childTypes => {
    const anyInput = childTypes.includes("input");
    const anyOutput = childTypes.includes("output");
    if (childTypes.includes("both") || (anyInput && anyOutput)) {
      return "both";
    }
    if (childTypes.every(type => type === "output")) {
      return "output";
    }
    if (childTypes.every(type => type === "input")) {
      return "input";
    }
    return null;
  };

  // Post-order walk: children are classified before their parent.
  const visit = nodeId => {
    if (isLeaf(nodeId, layerConfig)) {
      typeById[nodeId] = classifyLeaf(nodeId);
      return;
    }
    const childIds = getChildrenIds(nodeId, layerConfig);
    childIds.forEach(childId => visit(childId));
    typeById[nodeId] = classifyParent(childIds.map(childId => typeById[childId]));
  };

  visit(0);

  return typeById;
};

/**
 * Translate the type of an edge's target node into an edge direction.
 *
 * "input" becomes "backward", "output" becomes "forward", "both" becomes "twoway", and any
 * other value (including a missing entry) becomes `null`.
 *
 * @param {*} targetId - Key looked up in `nodeIdToType`.
 * @param {Object<string, *>} nodeIdToType
 * @returns {("backward"|"forward"|"twoway"|null)}
 */
const getEdgeDirection = (targetId, nodeIdToType) => {
  switch (nodeIdToType[targetId]) {
    case "input":
      return "backward";
    case "output":
      return "forward";
    case "both":
      return "twoway";
    default:
      return null;
  }
};

/**
 * Convert a model config into a flat React Flow element list: all nodes (laid out with dagre)
 * followed by all edges.
 *
 * Each `layerConfig` entry becomes one "operationEdge". Its id is `resIndices.join("-")` and
 * its `data` is the entry plus a `direction`. Every node that has no outgoing edge (a sink) is
 * collapsed into one shared node with id "output". Edges pointing at sinks are retargeted to
 * it, and sinks are not emitted as separate nodes.
 * When both `inputCols` and `outputCols` are empty, every edge direction is "twoway" and every
 * node type is "both".
 *
 * @param {{layerConfig?: Array<{resIndices: number[], inputDims: *, outputDims: *}>}} modelConfig
 *   A missing `layerConfig` is treated as empty.
 * @param {Array<{key: *}>|undefined} featureTypeDescriptors - Indexed by |nodeIndex| - 1.
 * @param {Array<*>} [inputCols=[]] - Column names selected as inputs.
 * @param {Array<*>} [outputCols=[]] - Column names selected as outputs.
 * @returns {Array<object>} Layouted nodes, then edges.
 */
export const getReactFlowElementsFromModelConfig = (
  modelConfig,
  featureTypeDescriptors,
  inputCols = [],
  outputCols = []
) => {
  const { layerConfig = [] } = modelConfig;
  // Loop-invariant: with nothing selected, everything is drawn as bidirectional.
  const noColumnsSelected = inputCols.length === 0 && outputCols.length === 0;

  const nodeIdToType = getNodeIdToTypeMap(layerConfig, inputCols, outputCols, featureTypeDescriptors);

  // A node is a sink when no edge leaves it. Sources are never rewritten below, so this one
  // set answers the question for the edges both before and after retargeting, without an
  // O(E) scan per lookup.
  const sourceIds = new Set(layerConfig.map(edge => String(edge.resIndices[0])));
  const isSink = nodeId => !sourceIds.has(nodeId);

  const edges = layerConfig.map(edge => {
    const target = String(edge.resIndices[1]);
    return {
      id: edge.resIndices.join("-"),
      source: String(edge.resIndices[0]),
      // All sinks share the single "output" node.
      target: isSink(target) ? "output" : target,
      data: {
        ...edge,
        direction: noColumnsSelected ? "twoway" : getEdgeDirection(edge.resIndices[1], nodeIdToType),
      },
      type: "operationEdge",
    };
  });

  const nodeIndices = [...new Set(layerConfig.flatMap(edge => edge.resIndices))];

  const { minModelDim, maxModelDim } = getMinAndMaxDimInModel(layerConfig);

  const initialNodes = nodeIndices
    .filter(nodeIndex => !isSink(String(nodeIndex)))
    .map(nodeIndex => ({
      id: String(nodeIndex),
      data: {
        label: getLabelForNodeIndex(nodeIndex, edges, featureTypeDescriptors),
        dims: getDimsForNodeIndex(nodeIndex, layerConfig),
        // Sinks were filtered out above (they are represented by the "output" node), so every
        // node emitted here has an outgoing edge.
        isOutputNode: false,
        type: noColumnsSelected ? "both" : nodeIdToType[nodeIndex], // input, output, both
        minModelDim,
        maxModelDim,
      },
      type: "dataNode",
    }));

  const nodes = [
    ...initialNodes,
    {
      id: "output",
      data: {
        label: "Data",
        // The shared output node reuses node 0's dims.
        dims: getDimsForNodeIndex(0, layerConfig),
        isOutputNode: true,
        minModelDim,
        maxModelDim,
      },
      type: "dataNode",
    },
  ];

  const layoutedNodes = getDagreLayoutedNodes(nodes, edges);

  return [...layoutedNodes, ...edges];
};
