import React, { useEffect, useState } from 'react';
import { DownloadOutlined } from '@ant-design/icons';
import { Button, Drawer, Modal, Tabs, Tag } from 'antd';
import { formatNumberToLocaleString } from '@utils/math';
import { useUserEvent } from '@hooks/useUserEvent';
import { FilterProp, FiltersProp, FilterType, SelectedFilters } from '@constants/data-table';
import {
  EventControlComponent,
  EventControlDataTableElement,
  EventControlElement,
  UserEventType,
} from '@constants/event';
import {
  DOWNLOAD_PREDICTION_BUTTON_LABEL,
  EXPERIMENT_FILTER_CONTEXT_MATCH_LABEL,
  EXPERIMENT_FILTER_CONTEXT_SIMILARITY_LABEL,
  EXPERIMENT_FILTER_EVAL_MODE_LABEL,
  EXPERIMENT_FILTER_EXACT_MATCH_LABEL,
  EXPERIMENT_FILTER_F1_SCORE_LABEL,
  EXPERIMENT_FILTER_GROUNDEDNESS_LABEL,
  EXPERIMENT_FILTER_NO_ANSWER_LABEL,
  EXPERIMENT_FILTER_RANK_LABEL,
  FILTERS_HEADER_LABEL,
  FILTERS_LABEL,
  FILTERS_MODAL_HEADER_LABEL,
  PREDICTIONS_EVAL_MODE_NAMES,
  PREDICTIONS_FILTERS_FIELD_KEYS,
  PREDICTIONS_FILTERS_OPTIONS_KEYS,
  PREDICTIONS_HEADER,
  PREDICTIONS_SORTING_DATATABLE_OPTIONS_BY_NODE,
  PredictionType,
} from '@constants/experiments';
import { PipelineNodeTypes } from '@constants/pipelines';
import { contextRender } from '@components/common/contextRender/contextRender';
import DataTable from '@components/dataTable/DataTable';
import styles from './predictions.module.scss';

/**
 * Callback used by the table to load one page of predictions.
 * Arguments, in order: node name, node type, page number, page size, search text,
 * and the optional sort value and selected filters.
 */
type GetPredictionsHandler = (
  nodeName: string,
  nodeType: string,
  currentPage: number,
  pageSize: number,
  searchValue: string,
  sortValue?: string,
  filterValues?: SelectedFilters,
) => void;

/** Props accepted by the `Predictions` component. */
interface IPredictionsProps {
  /** Evaluation run id; when truthy it is appended to the user-event properties as `eval_run_id`. */
  evalRunId: string;
  /**
   * Per-node results. Each entry's `node_name` becomes a tab, and the last entry is
   * selected when no node is selected yet.
   */
  results: any;
  /** Current page of predictions; read as `{ data, total }`. */
  predictions: any;
  /** Upper bound of the rank range filter. */
  retrieverNodeTopK: number;
  /** Looked up by node name; an entry whose `type` is `'PromptNode'` marks a generative node. */
  pipelineParameters: any;
  /** Invoked by the table whenever it needs data for the selected node. */
  getPredictions: GetPredictionsHandler;
  /** Invoked with the selected node name when the download action is clicked. */
  downloadPredictions: (nodeName: string) => void;
  /** Disables the download action when true. */
  downloadDisabled: boolean;
  /** Selected sort value, keyed by node type. */
  selectedSortValueByNode: Record<string, string>;
}

const Predictions = (props: IPredictionsProps) => {
  // Unpack the data props first, then the callbacks and remaining flags.
  const { evalRunId, results, predictions, retrieverNodeTopK, pipelineParameters } = props;
  const { getPredictions, downloadPredictions, downloadDisabled, selectedSortValueByNode } = props;
  const { data, total } = predictions;

  const hasRows = data?.length > 0;

  // A row counts as "having filters" when its `filters` object has at least one key.
  const rowHasFilters = (row: any) => row.filters && Object.keys(row.filters).length > 0;
  const hasFilters = hasRows && data.some(rowHasFilters);

  // A row counts as "labelled" when its first label carries an answer.
  const rowHasLabelledAnswer = (row: any) =>
    row.labels && row.labels.length > 0 && row.labels[0].answer;
  const hasLabels = hasRows && data.some(rowHasLabelledAnswer);

  const { trackUserEvent, appendEventProperties } = useUserEvent();

  // Selected node tab; an empty `node_name` means nothing has been selected yet.
  const unselectedNode = { node_name: '', node_type: '' };
  const [selectedPredictionNode, setSelectedPredictionNodeInternal] = useState(unselectedNode);
  /**
   * Stores the selected prediction node, re-tagging it as a generative node when the
   * pipeline parameters declare that node as a `PromptNode`.
   *
   * @param orig_node Node name and type to select.
   */
  const setSelectedPredictionNode = (orig_node: { node_name: string; node_type: string }) => {
    // Optional chaining covers both a missing `pipelineParameters` object and a node without
    // an entry in it; the unguarded lookup used to throw a TypeError in the latter case.
    const isGenerative = pipelineParameters?.[orig_node.node_name]?.type === 'PromptNode';
    const node = isGenerative
      ? { node_name: orig_node.node_name, node_type: PipelineNodeTypes.GENERATIVE_NODE }
      : orig_node;
    setSelectedPredictionNodeInternal(node);
  };
  // Filters modal: `open` toggles visibility, `filters` holds the value shown inside it.
  const [filtersModal, setFiltersModal] = useState({ open: false, filters: '' });
  // Content of the details drawer; an empty string keeps the drawer closed.
  const [predictionDrawerContent, setPredictionDrawerContent] = useState('');

  // Default selection: while no node is selected, pick the last entry of `results`.
  // NOTE: this state update happens during render, not inside an effect.
  const hasSelectedNode = Boolean(selectedPredictionNode?.node_name);
  if (!hasSelectedNode && results.length > 0) {
    const lastResult = results[results.length - 1];
    setSelectedPredictionNode(lastResult);
  }

  // Attach the evaluation run id to subsequently tracked user events.
  useEffect(() => {
    if (!evalRunId) return;
    appendEventProperties({ eval_run_id: evalRunId });
  }, [evalRunId]);

  /**
   * Returns the filter definitions offered by the data table for the selected node's type:
   *
   * - document nodes: rank, context match, context similarity
   * - answer nodes: rank, F1, exact match, eval mode
   * - generative nodes: rank, groundedness, no answer, plus F1 only when `hasLabels` is truthy
   *
   * An empty list is returned when no node type is selected or the type is none of the above.
   */
  const getFilters = (): FiltersProp | [] => {
    const activeNodeType = selectedPredictionNode.node_type;
    if (!activeNodeType) return [];

    // ---- Range filters ----

    // Rank goes from 1 up to the configured top-k.
    const byRank = {
      type: FilterType.RANGE,
      key: PREDICTIONS_FILTERS_FIELD_KEYS.RANK,
      title: EXPERIMENT_FILTER_RANK_LABEL,
      rangeBoundaries: {
        min: 1,
        max: retrieverNodeTopK,
      },
      option: { key: 'rank', label: EXPERIMENT_FILTER_RANK_LABEL },
    } as FilterProp<FilterType.RANGE>;

    const byF1Score = {
      type: FilterType.RANGE,
      key: PREDICTIONS_FILTERS_FIELD_KEYS.F1,
      title: EXPERIMENT_FILTER_F1_SCORE_LABEL,
      rangeBoundaries: {
        min: 0.0,
        max: 1.0,
      },
      rangeStep: 0.01,
      option: { key: 'f1', label: EXPERIMENT_FILTER_F1_SCORE_LABEL },
      style: { minWidth: '100px' },
    } as FilterProp<FilterType.RANGE>;

    const byGroundedness = {
      type: FilterType.RANGE,
      key: PREDICTIONS_FILTERS_FIELD_KEYS.GROUNDEDNESS,
      title: EXPERIMENT_FILTER_GROUNDEDNESS_LABEL,
      rangeBoundaries: {
        min: 0.0,
        max: 1.0,
      },
      rangeStep: 0.01,
      option: { key: 'groundedness', label: EXPERIMENT_FILTER_GROUNDEDNESS_LABEL },
      style: { minWidth: '100px' },
    } as FilterProp<FilterType.RANGE>;

    const byContextSimilarity = {
      type: FilterType.RANGE,
      key: PREDICTIONS_FILTERS_FIELD_KEYS.CONTEXT_SIMILARITY,
      title: EXPERIMENT_FILTER_CONTEXT_SIMILARITY_LABEL,
      rangeBoundaries: {
        min: 0,
        max: 100,
      },
      rangeStep: 0.01,
      option: { key: 'context_similarity', label: EXPERIMENT_FILTER_CONTEXT_SIMILARITY_LABEL },
      style: { minWidth: '160px' },
    } as FilterProp<FilterType.RANGE>;

    // ---- Yes/No select filters ----

    const byExactMatch = {
      type: FilterType.SELECT,
      key: PREDICTIONS_FILTERS_FIELD_KEYS.EXACT_MATCH,
      title: EXPERIMENT_FILTER_EXACT_MATCH_LABEL,
      options: [
        { key: 'exact_match_yes', label: 'Yes', value: true },
        { key: 'exact_match_no', label: 'No', value: false },
      ],
      style: { minWidth: '135px' },
    } as FilterProp<FilterType.SELECT>;

    const byNoAnswer = {
      type: FilterType.SELECT,
      key: PREDICTIONS_FILTERS_FIELD_KEYS.NO_ANSWER,
      title: EXPERIMENT_FILTER_NO_ANSWER_LABEL,
      options: [
        { key: 'no_answer_yes', label: 'Yes', value: 1.0 },
        { key: 'no_answer_no', label: 'No', value: 0.0 },
      ],
      style: { minWidth: '135px' },
    } as FilterProp<FilterType.SELECT>;

    const byContextMatch = {
      type: FilterType.SELECT,
      key: PREDICTIONS_FILTERS_FIELD_KEYS.CONTEXT_MATCH,
      title: EXPERIMENT_FILTER_CONTEXT_MATCH_LABEL,
      options: [
        { key: 'context_match_yes', label: 'Yes', value: true },
        { key: 'context_match_no', label: 'No', value: false },
      ],
      style: { minWidth: '140px' },
    } as FilterProp<FilterType.SELECT>;

    // ---- Multi-select filter ----

    const byEvalMode = {
      type: FilterType.MULTI_SELECT,
      key: PREDICTIONS_FILTERS_FIELD_KEYS.EVAL_MODE,
      title: EXPERIMENT_FILTER_EVAL_MODE_LABEL,
      style: { minWidth: '115px' },
      options: [
        {
          key: PREDICTIONS_FILTERS_OPTIONS_KEYS.INTEGRATED,
          label: PREDICTIONS_EVAL_MODE_NAMES.INTEGRATED,
        },
        {
          key: PREDICTIONS_FILTERS_OPTIONS_KEYS.ISOLATED,
          label: PREDICTIONS_EVAL_MODE_NAMES.ISOLATED,
        },
      ],
    } as FilterProp<FilterType.MULTI_SELECT>;

    if (activeNodeType === PipelineNodeTypes.DOCUMENT_NODE) {
      const documentSet = [byRank, byContextMatch, byContextSimilarity];
      return documentSet;
    }

    if (activeNodeType === PipelineNodeTypes.ANSWER_NODE) {
      const extractiveSet = [byRank, byF1Score, byExactMatch, byEvalMode];
      return extractiveSet;
    }

    if (activeNodeType === PipelineNodeTypes.GENERATIVE_NODE) {
      // F1 is only offered when the loaded rows carry labelled answers.
      const generativeSet = [
        byRank,
        byGroundedness,
        byNoAnswer,
        ...(hasLabels ? [byF1Score] : []),
      ];
      return generativeSet;
    }

    return [];
  };

  // TODO: Save state in redux
  /**
   * Returns the filters that are pre-selected in the data table.
   * Answer nodes start with the "integrated" eval mode selected; every other node type
   * starts without any selected filter.
   */
  const getDefaultSelectFilterValues = (): SelectedFilters => {
    const isAnswerNode = selectedPredictionNode.node_type === PipelineNodeTypes.ANSWER_NODE;
    if (!isAnswerNode) return {};

    const defaults: SelectedFilters = {
      [PREDICTIONS_FILTERS_FIELD_KEYS.EVAL_MODE]: [
        {
          key: PREDICTIONS_FILTERS_OPTIONS_KEYS.INTEGRATED,
          label: PREDICTIONS_EVAL_MODE_NAMES.INTEGRATED,
          value: PREDICTIONS_FILTERS_OPTIONS_KEYS.INTEGRATED,
          type: FilterType.MULTI_SELECT,
        },
      ],
    };
    return defaults;
  };

  /**
   * Returns the `context` of the first label whose `state` is `'MATCHED'`.
   * Yields an empty string when the prediction has no labels, no label matched,
   * or the matched label carries no context.
   */
  const getRetrieverExpectedContext = (prediction: any) => {
    const candidateLabels = prediction.labels;
    if (candidateLabels?.length) {
      const matched = candidateLabels.find(
        (candidate: { state: string }) => candidate.state === 'MATCHED',
      );
      return matched?.context ?? '';
    }
    return '';
  };

  // Opens the details drawer by storing `item` as its content.
  const previewPredictionItem = (item: any) => setPredictionDrawerContent(item);

  /**
   * Handles a click on a node tab: selects the result whose `node_name` equals `key`
   * and tracks the click.
   *
   * @param key Key of the clicked tab (a node name).
   */
  const onPredictionsTabChange = (key: string) => {
    const selectedResult = results.find((result: any) => result.node_name === key);
    // `find` yields undefined when no result carries this name; passing that on would throw
    // inside setSelectedPredictionNode, so the current selection is kept instead.
    if (selectedResult) {
      setSelectedPredictionNode(selectedResult);
    }
    trackUserEvent({
      type: UserEventType.CLICK,
      control: `${EventControlComponent.PREDICTIONS_TABLE}/${EventControlElement.NODE_TAB}`,
      properties: {
        value: key,
      },
    });
  };

  /**
   * Starts the predictions download for `nodeName`, then tracks the click.
   */
  const onDownloadPredictionsClick = (nodeName: string) => {
    downloadPredictions(nodeName);

    const control = `${EventControlComponent.PREDICTIONS_TABLE}/${EventControlElement.DOWNLOAD}`;
    const properties = { value: nodeName };
    trackUserEvent({ type: UserEventType.CLICK, control, properties });
  };

  // Closes the filters modal and clears the filters it was showing.
  const handleCancelFiltersModal = () => {
    const closedModal = { open: false, filters: '' };
    setFiltersModal(closedModal);
  };

  // Renders a plain context cell (no answer span) with a preview button.
  const renderRetrieverPredictedContext = (prediction: { context: string }) => {
    const { context: content } = prediction;
    return contextRender({ content, previewButtonHandler: previewPredictionItem });
  };

  // Renders a context cell and forwards the answer offsets to contextRender.
  const renderReaderPredictedContext = (prediction: {
    context: string;
    answer_start: number;
    answer_end: number;
  }) => {
    const { context: content, answer_start: answerStart, answer_end: answerEnd } = prediction;
    return contextRender({
      content,
      answerStart,
      answerEnd,
      previewButtonHandler: previewPredictionItem,
    });
  };

  const nodeType = selectedPredictionNode.node_type;
  const nodeName = selectedPredictionNode.node_name;

  // The query column is narrower for answer nodes (300) than for other node types (400).
  const queryColumnWidth = nodeType === PipelineNodeTypes.ANSWER_NODE ? 300 : 400;
  const queryColumn = {
    title: 'Query',
    dataIndex: 'query',
    key: 'query',
    width: queryColumnWidth,
    render: (query: string) => {
      return contextRender({
        content: query,
        maxCharacters: 250,
        previewButtonHandler: previewPredictionItem,
      });
    },
  };

  // Column with a link button that opens the filters modal for the row's filters.
  // Rows without filters (missing or empty object) render an empty cell.
  const filtersColumn = {
    title: FILTERS_HEADER_LABEL,
    dataIndex: 'filters',
    key: 'filters',
    width: 200,
    render: (rowFilters: any) => {
      const filterCount = rowFilters ? Object.keys(rowFilters).length : 0;
      if (filterCount === 0) return null;

      return (
        <Button type="link" onClick={() => setFiltersModal({ open: true, filters: rowFilters })}>
          {FILTERS_LABEL} ({filterCount})
        </Button>
      );
    },
  };

  // Predicted answer text, rendered through contextRender with a 250 character limit.
  const predictedAnswerColumn = {
    title: 'Predicted Answer',
    dataIndex: 'answer',
    key: 'answer',
    width: 300,
    render: (answer: string) => {
      return contextRender({
        content: answer,
        maxCharacters: 250,
        previewButtonHandler: previewPredictionItem,
      });
    },
  };

  // Expected answers of a row: the `answer` of every label, comma-joined.
  const expectedAnswerColumn = {
    title: 'Expected Answers',
    dataIndex: 'labels',
    key: 'expected_answers',
    width: 300,
    // NOTE(review): this is the string 'true', not the boolean `true`. If these columns end up
    // on an antd Table, its documented values are boolean | 'left' | 'right' — confirm intent.
    fixed: 'true' as const,
    render: (labels: any) => {
      // Rows may come without labels; treat that as "no expected answers" instead of
      // throwing on `.map`.
      const answers = (labels ?? []).map((label: any) => label.answer).join();
      return contextRender({
        content: answers,
        maxCharacters: 250,
        previewButtonHandler: previewPredictionItem,
      });
    },
  };

  // Prompt text, rendered through contextRender with a 250 character limit.
  const promptColumn = {
    title: 'Prompt',
    dataIndex: 'prompt',
    key: 'prompt',
    width: 300,
    render: (prompt: string) => {
      return contextRender({
        content: prompt,
        maxCharacters: 250,
        previewButtonHandler: previewPredictionItem,
      });
    },
  };

  // Predicted context: answer-type predictions get the reader renderer (with answer offsets),
  // everything else gets the plain renderer. Document nodes get a wider column.
  const predictedContextColumn = {
    title: 'Predicted Context',
    dataIndex: 'context',
    key: 'context',
    width: nodeType === PipelineNodeTypes.DOCUMENT_NODE ? 600 : 500,
    render: (_cellValue: string, row: any) => {
      if (row.prediction_type === PredictionType.ANSWER) {
        return renderReaderPredictedContext(row);
      }
      return renderRetrieverPredictedContext(row);
    },
  };

  // Expected context: the context of the row's matched label, rendered as a plain context cell.
  const expectedContextColumn = {
    title: 'Expected Context',
    width: 600,
    render: (_cellValue: string, row: any) => {
      const context = getRetrieverExpectedContext(row);
      return renderRetrieverPredictedContext({ context });
    },
  };

  // Shared cell renderer for numeric metrics: an empty cell for null/undefined, otherwise
  // the value passed through formatNumberToLocaleString.
  const renderOptionalNumber = (metric: number) => {
    const isMissing = metric === null || metric === undefined;
    return <span>{isMissing ? '' : formatNumberToLocaleString(metric)}</span>;
  };

  // Shared cell renderer for truthy/falsy values: a "Yes" success tag or a "No" default tag.
  const renderYesNoTag = (flag: boolean) => {
    if (flag) {
      return <Tag color="success">Yes</Tag>;
    }
    return <Tag color="default">No</Tag>;
  };

  const f1Column = {
    title: 'F1',
    dataIndex: 'f1',
    key: 'f1',
    width: 90,
    fixed: 'true',
    render: renderOptionalNumber,
  };

  const sasColumn = {
    title: 'SAS',
    dataIndex: 'sas',
    key: 'sas',
    width: 90,
    fixed: 'true',
    render: renderOptionalNumber,
  };

  const exactMatchColumn = {
    title: 'Exact Match',
    dataIndex: 'exact_match',
    key: 'exact_match',
    width: 90,
    render: renderYesNoTag,
  };

  // Shows the display name registered for the row's eval mode key.
  const evalModeColumn = {
    title: 'Eval Mode',
    dataIndex: 'eval_mode',
    key: 'eval_mode',
    width: 120,
    render: (modeKey: keyof typeof PREDICTIONS_EVAL_MODE_NAMES) => {
      const modeName = PREDICTIONS_EVAL_MODE_NAMES[modeKey];
      return <span>{modeName}</span>;
    },
  };

  const groundednessColumn = {
    title: 'Groundedness',
    dataIndex: 'groundedness',
    key: 'groundedness',
    width: 90,
    fixed: 'true',
    render: renderOptionalNumber,
  };

  const noAnswerColumn = {
    title: 'No Answer',
    dataIndex: 'no_answer_ratio',
    key: 'no_answer_ratio',
    width: 90,
    render: renderYesNoTag,
  };

  const queryLatencyColumn = {
    title: 'Query Latency',
    dataIndex: 'query_latency',
    key: 'query_latency',
    width: 90,
    fixed: 'true',
    render: renderOptionalNumber,
  };

  // Highest `context_similarity` among the row's labels, shown as a percentage.
  const contextSimilarityColumn = {
    title: 'Context Similarity',
    dataIndex: 'labels',
    key: 'label',
    width: 100,
    render: (labels: any) => {
      let highestSimilarity = 0;
      // Rows may come without labels; treat that like an empty list (renders "0 %")
      // instead of throwing on `.forEach`.
      (labels ?? []).forEach((label: any) => {
        // Labels without a numeric similarity never win the comparison and are skipped.
        if (label.context_similarity > highestSimilarity)
          highestSimilarity = label.context_similarity;
      });
      return `${formatNumberToLocaleString(highestSimilarity)} %`;
    },
  };

  // Context match flag, shown as a "Yes" success tag or a "No" default tag.
  const contextMatchColumn = {
    title: 'Context Match',
    dataIndex: 'context_match',
    key: 'context_match',
    width: 90,
    render: (isMatch: boolean) => {
      const tagColor = isMatch ? 'success' : 'default';
      const tagText = isMatch ? 'Yes' : 'No';
      return <Tag color={tagColor}>{tagText}</Tag>;
    },
  };

  // Rank of the prediction; shown as-is without a custom renderer.
  const rankColumn = {
    title: 'Rank',
    dataIndex: 'rank',
    key: 'rank',
    width: 65,
  };

  // Columns shown for document nodes.
  const columnsForDocumentNode = [
    predictedContextColumn,
    expectedContextColumn,
    contextSimilarityColumn,
    contextMatchColumn,
  ];

  // Columns shown for answer (extractive) nodes.
  const columnsForAnswerNode = [
    predictedAnswerColumn,
    expectedAnswerColumn,
    predictedContextColumn,
    f1Column,
    sasColumn,
    exactMatchColumn,
    evalModeColumn,
  ];

  // Columns shown for generative nodes; label-based columns only appear when the loaded
  // rows carry labelled answers.
  const labelOnlyAnswerColumns = hasLabels ? [expectedAnswerColumn] : [];
  const labelOnlyMetricColumns = hasLabels ? [sasColumn, f1Column] : [];
  const columnsForGenerativeNode = [
    predictedAnswerColumn,
    ...labelOnlyAnswerColumns,
    promptColumn,
    groundednessColumn,
    noAnswerColumn,
    ...labelOnlyMetricColumns,
    queryLatencyColumn,
  ];

  // Pick the column set matching the selected node type; unknown types get none.
  let nodeSpecificColumns: any[] = [];
  if (nodeType === PipelineNodeTypes.DOCUMENT_NODE) {
    nodeSpecificColumns = columnsForDocumentNode;
  } else if (nodeType === PipelineNodeTypes.ANSWER_NODE) {
    nodeSpecificColumns = columnsForAnswerNode;
  } else if (nodeType === PipelineNodeTypes.GENERATIVE_NODE) {
    nodeSpecificColumns = columnsForGenerativeNode;
  }

  // Final layout: query, optional filters column, node-specific columns, rank.
  const leadingColumns = hasFilters ? [queryColumn, filtersColumn] : [queryColumn];
  const columns = [...leadingColumns, ...nodeSpecificColumns, rankColumn];

  /**
   * Renders the modal that shows the filters of a row as pretty-printed JSON.
   * Visibility and content both come from the `filtersModal` state.
   */
  const renderFiltersModal = () => {
    const { open: filtersModalOpen, filters: filtersModalFilters } = filtersModal;
    return (
      <Modal
        title={FILTERS_MODAL_HEADER_LABEL}
        open={filtersModalOpen}
        centered
        footer={null}
        onCancel={handleCancelFiltersModal}
      >
        {/* <pre> is rendered directly: nesting it inside <p> is invalid HTML and makes
            React DOM emit a validateDOMNesting warning. */}
        <pre className={styles.jsonView}>{JSON.stringify(filtersModalFilters, null, 2)}</pre>
      </Modal>
    );
  };

  // Sort options registered for the selected node type, and the sort value selected for it.
  const sortingOptions =
    PREDICTIONS_SORTING_DATATABLE_OPTIONS_BY_NODE[
      nodeType as keyof typeof PREDICTIONS_SORTING_DATATABLE_OPTIONS_BY_NODE
    ];
  const selectedSortValue = selectedSortValueByNode[nodeType];

  // One tab per result; the node name is both the label and the key.
  const tabItems = results.map(({ node_name: tabNodeName }: any) => ({
    label: tabNodeName,
    key: tabNodeName,
  }));

  // Renders the details drawer; it is open whenever there is content to show, and closing
  // it clears that content.
  const renderPredictionDetailsDrawer = () => {
    const isDrawerOpen = Boolean(predictionDrawerContent);
    const closeDrawer = () => setPredictionDrawerContent('');

    return (
      <Drawer
        open={isDrawerOpen}
        onClose={closeDrawer}
        size="large"
        className={styles.detailsDrawer}
      >
        <pre>{predictionDrawerContent}</pre>
      </Drawer>
    );
  };

  // The table is only rendered once a node is selected.
  const showTable = selectedPredictionNode && selectedPredictionNode.node_name;

  // Sorting is only offered when the node type has sort options and a sort value is selected.
  const sortingConfig =
    sortingOptions && selectedSortValue
      ? { selectedValue: selectedSortValue, options: sortingOptions }
      : undefined;

  // Loads table data for the selected node, forwarding the table's paging/search/sort/filter state.
  const loadPage = (
    currentPage: number,
    pageSize: number,
    searchValue: string,
    sortValue?: string,
    filterValues?: SelectedFilters,
  ) =>
    getPredictions(
      nodeName,
      nodeType,
      currentPage,
      pageSize,
      searchValue,
      sortValue,
      filterValues,
    );

  const downloadAction = {
    label: DOWNLOAD_PREDICTION_BUTTON_LABEL,
    onClick: () => onDownloadPredictionsClick(nodeName),
    secondary: true,
    disabled: downloadDisabled,
    icon: <DownloadOutlined />,
  };

  // Every table event control is "<predictions table component>/<table element>".
  const controlPrefix = `${EventControlComponent.PREDICTIONS_TABLE}/`;

  return (
    <section className={styles.section}>
      <div>
        <h5> {PREDICTIONS_HEADER} </h5>

        <Tabs activeKey={nodeName} type="card" onChange={onPredictionsTabChange} items={tabItems} />
      </div>
      {showTable && (
        <DataTable
          id={nodeName}
          data={data}
          columns={columns}
          total={total}
          sorting={sortingConfig}
          filters={getFilters()}
          selectedFiltersValues={getDefaultSelectFilterValues()}
          searchAvailable
          getData={loadPage}
          rowSelection={false}
          rowKey="prediction_id"
          primaryAction={downloadAction}
          scroll={{ x: 'max-content' }}
          userEventsTrackingHandlers={{
            onSearch: (value) =>
              trackUserEvent({
                type: UserEventType.KEYDOWN,
                control: `${controlPrefix}${EventControlDataTableElement.SEARCH}`,
                properties: { value, node_name: nodeName },
              }),
            onPageChange: (value) =>
              trackUserEvent({
                type: UserEventType.CLICK,
                control: `${controlPrefix}${EventControlDataTableElement.PAGINATION}`,
                properties: { value, node_name: nodeName },
              }),
            onSort: (value) =>
              trackUserEvent({
                type: UserEventType.CLICK,
                control: `${controlPrefix}${EventControlDataTableElement.SORT}`,
                properties: { value, node_name: nodeName },
              }),
            onFilter: (filterKey) =>
              trackUserEvent({
                type: UserEventType.CLICK,
                control: `${controlPrefix}${EventControlDataTableElement.FILTER}`,
                properties: { filter_key: filterKey, node_name: nodeName },
              }),
            onClearFilter: (filterKey) =>
              trackUserEvent({
                type: UserEventType.CLICK,
                control: `${controlPrefix}${EventControlDataTableElement.CLEAR_FILTER}`,
                properties: { filter_key: filterKey, node_name: nodeName },
              }),
            onClearAllFilters: () =>
              trackUserEvent({
                type: UserEventType.CLICK,
                control: `${controlPrefix}${EventControlDataTableElement.CLEAR_ALL_FILTERS}`,
                properties: { node_name: nodeName },
              }),
          }}
          border
        />
      )}
      {renderFiltersModal()}
      {renderPredictionDetailsDrawer()}
    </section>
  );
};

export default Predictions;
