import compact from 'lodash/compact'
import countBy from 'lodash/countBy'
import extend from 'lodash/extend'
import flatten from 'lodash/flatten'
import toNumber from 'lodash/toNumber'
import { createSlice } from '@reduxjs/toolkit'
import { navigate } from 'gatsby'

import callFlask from '../callFlask'
import callHasura from '../callHasura'
import { MedicalImageViews_medical_image_views } from '../graphQlQueries/types/MedicalImageViews'
import { MedicalImage_medical_images_by_pk } from '../graphQlQueries/types/MedicalImage'
import { MedicalImagesForCondition_medical_images } from '../graphQlQueries/types/MedicalImagesForCondition'
import { MedicalImagesUrls_medical_images } from '../graphQlQueries/types/MedicalImagesUrls'
import { MedicalImages_medical_images } from '../graphQlQueries/types/MedicalImages'
import { NotificationId, setNotificationAction } from './notifications'
import { TrainingIteration_training_iterations_by_pk } from '../graphQlQueries/types/TrainingIteration'
import { TrainingIterations_training_iterations } from '../graphQlQueries/types/TrainingIterations'
import { User_users } from '../graphQlQueries/types/User'
import { defaultSetLoading, defaultNetworkingFailure, defaultNetworkingSuccess } from './common'
import { fetchImagePermutationsAction, fetchImagesAction } from './consultations'
import { fetchUserAction, postSlackMessageAction } from './users'
import { keyForAwsS3Url } from '../../lib/helpers'
import { TrainingIterationMetrics_training_iterations } from '../graphQlQueries/types/TrainingIterationMetrics'

import {
  claimTrainingIterationQuery,
  completeTrainingIterationQuery,
  fetchMedicalImageQuery,
  FetchMedicalImagesParams,
  fetchMedicalImagesQuery,
  fetchMedicalImageUrlsQuery,
  fetchMedicalImageViewsQuery,
  fetchTrainingIterationMetricsQuery,
  fetchTrainingIterationsQuery,
  medicalImagesForConditionQuery,
  trainingIterationQuery,
  updateMedicalImageFlaggedQuery,
  updateMedicalImageSpeciesQuery,
  updateMedicalImageViewQuery,
  updateTrainingIterationFeedbackQuery,
} from '../graphQlQueries/MedicalImage'

/**
 * Shape of the `medicalImages` Redux slice.
 */
export interface MedicalImagesState {
  // Networking status. `isQuerying` is left untyped here; setLoading is called
  // with a query name in several thunks, so presumably it is keyed by query
  // name — TODO confirm against the helpers in ./common.
  isQuerying: any
  loading: boolean
  hasErrors: boolean

  // Medical image data.
  medicalImage?: MedicalImage_medical_images_by_pk
  medicalImages: MedicalImages_medical_images[]
  medicalImageViews: MedicalImageViews_medical_image_views[]
  medicalImagesForCondition?: MedicalImagesForCondition_medical_images[]
  // Presigned URLs stored by preloadMedicalImagesAction, ordered like the requested ids.
  binaryTaggingPresignedUrls: string[]

  // Training iteration data.
  trainingIteration?: TrainingIteration_training_iterations_by_pk
  trainingIterations?: TrainingIterations_training_iterations[]
  trainingIterationMetrics?: TrainingIterationMetrics_training_iterations[]
}

/**
 * Starting state of the slice: nothing loading, no errors, empty collections.
 * The optional fields (medicalImage, trainingIteration, …) start out unset.
 */
const initialState: MedicalImagesState = {
  isQuerying: {},
  loading: false,
  hasErrors: false,
  medicalImages: [],
  medicalImageViews: [],
  binaryTaggingPresignedUrls: [],
}

/**
 * Redux slice for medical images, medical image views, training iterations and
 * the presigned URLs used by binary tagging. The three networking reducers are
 * the shared defaults imported from `./common`.
 */
const medicalImagesSlice = createSlice({
  name: 'medicalImages',
  initialState,
  reducers: {
    setLoading: defaultSetLoading,
    networkingFailure: defaultNetworkingFailure,
    networkingSuccess: defaultNetworkingSuccess,

    /** Stores one medical image and clears the loading and error flags. */
    fetchMedicalImageSuccess: (state, action: { payload: MedicalImage_medical_images_by_pk }) => {
      state.medicalImage = action.payload
      state.hasErrors = false
      state.loading = false
    },

    /** Replaces the medical image list and clears the loading and error flags. */
    fetchMedicalImagesSuccess: (state, action: { payload: MedicalImages_medical_images[] }) => {
      state.medicalImages = action.payload
      state.hasErrors = false
      state.loading = false
    },

    /** Drops both the image list and the single stored image. */
    unsetMedicalImages: (state) => {
      state.medicalImages = []
      state.medicalImage = undefined
    },

    /** Replaces the stored medical image views. */
    fetchMedicalImageViewsSuccess: (state, action: { payload: MedicalImageViews_medical_image_views[] }) => {
      state.medicalImageViews = action.payload
    },

    /** Stores a single training iteration. */
    fetchTrainingIterationSuccess: (state, action: { payload: TrainingIteration_training_iterations_by_pk }) => {
      state.trainingIteration = action.payload
    },

    /** Replaces the stored training iteration metrics. */
    fetchTrainingIterationMetricsSuccess: (state, action: { payload: TrainingIterationMetrics_training_iterations[] }) => {
      state.trainingIterationMetrics = action.payload
    },

    /** Replaces the stored list of training iterations. */
    fetchTrainingIterationsSuccess: (state, action: { payload: TrainingIterations_training_iterations[] }) => {
      state.trainingIterations = action.payload
    },

    /** Replaces the images stored for the currently viewed condition. */
    fetchMedicalImagesForConditionSuccess: (state, action: { payload: MedicalImagesForCondition_medical_images[] }) => {
      state.medicalImagesForCondition = action.payload
    },

    /** Replaces the presigned URLs used by binary tagging. */
    setBinaryTaggingPresignedUrls: (state, action: { payload: string[] }) => {
      state.binaryTaggingPresignedUrls = action.payload
    },
  },
})

// Action creators generated by the slice, grouped by the state they touch.
export const {
  // Networking status (shared reducers from ./common)
  setLoading,
  networkingSuccess,
  networkingFailure,
  // Medical images
  fetchMedicalImageSuccess,
  fetchMedicalImagesSuccess,
  fetchMedicalImageViewsSuccess,
  fetchMedicalImagesForConditionSuccess,
  unsetMedicalImages,
  setBinaryTaggingPresignedUrls,
  // Training iterations
  fetchTrainingIterationSuccess,
  fetchTrainingIterationsSuccess,
  fetchTrainingIterationMetricsSuccess,
} = medicalImagesSlice.actions

/** Reads this slice out of the (untyped) root store state. */
export const medicalImagesSelector = (rootState: any) => rootState['medicalImages']

// The slice reducer. medicalImagesSelector reads `state.medicalImages`, so this is
// presumably mounted under the `medicalImages` key of the root reducer — verify.
export default medicalImagesSlice.reducer

/**
 * Fetches one medical image by id and stores it in the slice. When the image has
 * an S3 URL from which a key can be derived, it also dispatches fetchImagesAction
 * for that key.
 *
 * @param accessToken - Token passed through to callHasura.
 * @param id - Id of the medical image to fetch.
 */
export function fetchMedicalImageAction(accessToken: string, id: number) {
  return async (dispatch: any) => {
    const query = fetchMedicalImageQuery(id)
    dispatch(setLoading(query.name))

    try {
      const result: MedicalImage_medical_images_by_pk = await callHasura(accessToken, query)
      // Only request the image when a key was actually derived; a non-null
      // assertion here would forward `[undefined]` for an unparseable URL.
      const key = result.aws_s3_url ? keyForAwsS3Url(result.aws_s3_url) : undefined
      if (key) dispatch(fetchImagesAction([key], undefined, true))
      dispatch(fetchMedicalImageSuccess(result))
      dispatch(networkingSuccess(query.name))
    } catch (error) {
      dispatch(networkingFailure(query.name))
    }
  }
}

/**
 * Sets the `flagged` value of medical image `id`, then waits for the image to be
 * re-fetched before reporting success.
 *
 * The mutation query is built when the action is created, not when it runs.
 */
export function updateMedicalImageFlaggedAction(accessToken: string, id: number, flagged: boolean) {
  const mutation = updateMedicalImageFlaggedQuery(id, flagged)

  return async (dispatch: any) => {
    dispatch(setLoading(mutation.name))

    try {
      await callHasura(accessToken, mutation)
      // Refresh the stored image so the new flag is visible.
      await dispatch(fetchMedicalImageAction(accessToken, id))
      dispatch(networkingSuccess(mutation.name))
    } catch {
      dispatch(networkingFailure(mutation.name))
    }
  }
}

/**
 * Assigns view `view_id` to medical image `id`, then waits for the image to be
 * re-fetched before reporting success.
 *
 * The mutation query is built when the action is created, not when it runs.
 */
export function updateMedicalImageViewAction(accessToken: string, id: number, view_id: number) {
  const mutation = updateMedicalImageViewQuery(id, view_id)

  return async (dispatch: any) => {
    const refreshImage = () => dispatch(fetchMedicalImageAction(accessToken, id))
    dispatch(setLoading(mutation.name))

    try {
      await callHasura(accessToken, mutation)
      await refreshImage()
      dispatch(networkingSuccess(mutation.name))
    } catch {
      dispatch(networkingFailure(mutation.name))
    }
  }
}

/**
 * Fetches medical images for `params` (and `ids`, when given) and stores them.
 *
 * If nothing comes back while `params.inConsult` is set, the search is retried
 * once with `inConsult: false`, keeping the same `ids`. If the final search is
 * empty, a NoResults notification is dispatched and the stored list is left as is.
 *
 * @param accessToken - Token passed through to callHasura.
 * @param params - Search parameters for fetchMedicalImagesQuery.
 * @param ids - Image ids forwarded to fetchMedicalImagesQuery (also on the fallback).
 */
export function fetchMedicalImagesAction(accessToken: string, params?: FetchMedicalImagesParams, ids?: number[]) {
  return async (dispatch: any) => {
    const query = fetchMedicalImagesQuery(params, ids)
    dispatch(setLoading(query.name))

    try {
      const result: MedicalImages_medical_images[] = await callHasura(accessToken, query)
      // Settle this query before any fallback starts, so its loading flag cannot
      // stay set if the fallback query has a different name.
      dispatch(networkingSuccess(query.name))

      if (!result.length && params?.inConsult) {
        // Keep `ids`: dropping them would widen the fallback to every matching image.
        return dispatch(fetchMedicalImagesAction(accessToken, extend({}, params, { inConsult: false }), ids))
      }

      if (result.length) {
        dispatch(fetchMedicalImagesSuccess(result))
      } else {
        dispatch(setNotificationAction(NotificationId.NoResults))
      }
    } catch (error) {
      dispatch(networkingFailure(query.name))
    }
  }
}

/**
 * Runs fetchMedicalImageViewsQuery and stores the returned views in the slice.
 *
 * @param accessToken - Token passed through to callHasura.
 */
export function fetchMedicalImageViewsAction(accessToken: string) {
  return async (dispatch: any) => {
    const viewsQuery = fetchMedicalImageViewsQuery()
    dispatch(setLoading(viewsQuery.name))

    try {
      const views: MedicalImageViews_medical_image_views[] = await callHasura(accessToken, viewsQuery)
      dispatch(fetchMedicalImageViewsSuccess(views))
      dispatch(networkingSuccess(viewsQuery.name))
    } catch {
      dispatch(networkingFailure(viewsQuery.name))
    }
  }
}

/**
 * Fetches training iteration `id` and stores it in the slice.
 *
 * @param accessToken - Token passed through to callHasura.
 * @param id - Id of the training iteration.
 */
export function fetchTrainingIterationAction(accessToken: string, id: number) {
  return async (dispatch: any) => {
    dispatch(setLoading())

    try {
      const iterationQuery = trainingIterationQuery(id)
      const iteration: TrainingIteration_training_iterations_by_pk = await callHasura(accessToken, iterationQuery)
      dispatch(fetchTrainingIterationSuccess(iteration))
      dispatch(networkingSuccess())
    } catch {
      dispatch(networkingFailure())
    }
  }
}

/**
 * Claims training iteration `id` for `userId` and navigates to its binary tagging
 * page. If the iteration already has a `vet_id`, it refreshes the iteration list
 * and shows an AlreadyClaimed notification instead.
 *
 * NOTE(review): the read and the claim are two separate requests, so two users
 * could both see the iteration as unclaimed — confirm the claim mutation guards this.
 */
export function claimTrainingIterationAction(accessToken: string, id: number, userId: string) {
  return async (dispatch: any) => {
    dispatch(setLoading())

    try {
      const current: TrainingIteration_training_iterations_by_pk = await callHasura(accessToken, trainingIterationQuery(id))
      const alreadyClaimed = Boolean(current.vet_id)

      if (!alreadyClaimed) {
        await callHasura(accessToken, claimTrainingIterationQuery(id, userId))
        navigate(`/tagging/binary/?i=${id}`)
      } else {
        // Not awaited: the list refresh reports its own networking status.
        dispatch(fetchTrainingIterationsAction(accessToken))
        dispatch(setNotificationAction(NotificationId.AlreadyClaimed))
      }

      dispatch(networkingSuccess())
    } catch {
      dispatch(networkingFailure())
    }
  }
}

/**
 * Saves `feedback` on training iteration `id`. Nothing is stored in the slice.
 *
 * @param accessToken - Token passed through to callHasura.
 * @param id - Id of the training iteration.
 * @param feedback - Feedback text to save.
 */
export function updateTrainingIterationFeedbackAction(accessToken: string, id: number, feedback: string) {
  return async (dispatch: any) => {
    dispatch(setLoading())

    try {
      const mutation = updateTrainingIterationFeedbackQuery(id, feedback)
      await callHasura(accessToken, mutation)
      dispatch(networkingSuccess())
    } catch {
      dispatch(networkingFailure())
    }
  }
}

/**
 * Runs fetchTrainingIterationsQuery and stores the returned iterations.
 *
 * @param accessToken - Token passed through to callHasura.
 */
export function fetchTrainingIterationsAction(accessToken: string) {
  return async (dispatch: any) => {
    const listQuery = fetchTrainingIterationsQuery()
    dispatch(setLoading(listQuery.name))

    try {
      const iterations: TrainingIterations_training_iterations[] = await callHasura(accessToken, listQuery)
      dispatch(fetchTrainingIterationsSuccess(iterations))
      dispatch(networkingSuccess(listQuery.name))
    } catch {
      dispatch(networkingFailure(listQuery.name))
    }
  }
}

/**
 * Runs fetchTrainingIterationMetricsQuery and stores the returned metrics.
 *
 * @param accessToken - Token passed through to callHasura.
 */
export function fetchTrainingIterationMetricsAction(accessToken: string) {
  return async (dispatch: any) => {
    const query = fetchTrainingIterationMetricsQuery()
    dispatch(setLoading(query.name))

    try {
      // Typed as the metrics rows that fetchTrainingIterationMetricsSuccess expects,
      // not the plain training-iteration list type.
      const result: TrainingIterationMetrics_training_iterations[] = await callHasura(accessToken, query)
      dispatch(fetchTrainingIterationMetricsSuccess(result))
      dispatch(networkingSuccess(query.name))
    } catch (error) {
      dispatch(networkingFailure(query.name))
    }
  }
}

/**
 * Checks whether `user` has a prediction for every image in training iteration
 * `id`. If so, and the iteration has no `completed_at` yet, it:
 * - marks the iteration completed,
 * - posts a Slack summary,
 * - re-fetches the user.
 *
 * @param accessToken - Token passed through to callHasura.
 * @param id - Id of the training iteration.
 * @param user - The user whose predictions are checked.
 */
export function checkTrainingIterationCompletedAction(accessToken: string, id: number, user: User_users) {
  return async (dispatch: any) => {
    dispatch(setLoading())

    try {
      const trainingIteration: TrainingIteration_training_iterations_by_pk = await callHasura(accessToken, trainingIterationQuery(id))
      // A Set keeps the check linear instead of scanning the prediction list per image.
      const completedIds = new Set(
        trainingIteration.predictions_normalizeds.filter((p) => p.vet_id === user.id).map((p) => p.medical_images_id)
      )
      // The image ids are stored as a comma-separated string.
      const isComplete = trainingIteration.medical_image_ids_denormalized
        .split(',')
        .map(toNumber)
        .every((n) => completedIds.has(n))

      if (!trainingIteration.completed_at && isComplete) {
        await callHasura(accessToken, completeTrainingIterationQuery(id))
        // NOTE(review): counts every prediction on the iteration, not only this user's — confirm intended.
        const answerCounts = JSON.stringify(countBy(trainingIteration.predictions_normalizeds.map((p) => p.grade || p.issue)))
        const { condition, species } = trainingIteration
        const msg = `${user.display_name} finished ${condition.display_name} ${species} tagging set (${answerCounts}) (id: ${trainingIteration.id}).`
        dispatch(postSlackMessageAction(msg))
        dispatch(fetchUserAction(accessToken, user.id))
      }

      dispatch(networkingSuccess())
    } catch (error) {
      dispatch(networkingFailure())
    }
  }
}

/**
 * Sets the species of medical image `id`. Nothing is stored in the slice.
 *
 * @param accessToken - Token passed through to callHasura.
 * @param id - Id of the medical image.
 * @param species - New species value.
 */
export function updateMedicalImageSpeciesAction(accessToken: string, id: number, species: string) {
  return async (dispatch: any) => {
    dispatch(setLoading())

    try {
      const mutation = updateMedicalImageSpeciesQuery(id, species)
      await callHasura(accessToken, mutation)
      dispatch(networkingSuccess())
    } catch {
      dispatch(networkingFailure())
    }
  }
}

/** Returns an async thunk that clears the stored image list and the single stored image. */
export function unsetMedicalImagesAction() {
  return async (dispatch: any) => {
    const clearAction = unsetMedicalImages()
    dispatch(clearAction)
  }
}

/**
 * Fetches medical images for condition `id` and `species`. Before storing them
 * in the slice, it dispatches fetchImagesAction for each image's S3 key and
 * fetchImagePermutationsAction for each permutation's [label, key] pair.
 *
 * @param accessToken - Token passed through to callHasura.
 * @param id - Condition id.
 * @param species - Species filter passed to the query.
 */
export function medicalImagesForConditionAction(accessToken: string, id: number, species: string) {
  return async (dispatch: any) => {
    const query = medicalImagesForConditionQuery(id, species)
    dispatch(setLoading(query.name))

    try {
      const result: MedicalImagesForCondition_medical_images[] = await callHasura(accessToken, query)
      const keys = compact(result.map((r) => keyForAwsS3Url(r.aws_s3_url)))
      const permutations = flatten(
        result.map((r) =>
          r.medical_image_permutations
            // Filter on the derived key, not just the URL, so a URL that yields no
            // key is dropped (like `compact` above) instead of becoming [label, undefined].
            .filter((p) => p.aws_s3_url && keyForAwsS3Url(p.aws_s3_url))
            .map((p) => [p.label, keyForAwsS3Url(p.aws_s3_url)!])
        )
      )
      dispatch(fetchImagesAction(keys))
      dispatch(fetchImagePermutationsAction(permutations))
      dispatch(fetchMedicalImagesForConditionSuccess(result))
      dispatch(networkingSuccess(query.name))
    } catch (error) {
      dispatch(networkingFailure(query.name))
    }
  }
}

/**
 * Resolves presigned URLs for medical images `ids` via the Flask
 * `/images/presigned` endpoint and stores them in the same order as `ids`.
 * Ids with no S3 key or no matching URL are skipped.
 *
 * @param accessToken - Token passed through to callHasura.
 * @param ids - Medical image ids, in the order the URLs should be stored.
 * @param yoloLabel - Any value other than null or `'original'` selects the
 *   `radimal-model-images-cropped-<label>` bucket; otherwise
 *   `radimal-model-images` is used.
 */
export function preloadMedicalImagesAction(accessToken: string, ids: number[], yoloLabel: string | null = null) {
  return async (dispatch: any) => {
    dispatch(setLoading())

    try {
      const images: MedicalImagesUrls_medical_images[] = await callHasura(accessToken, fetchMedicalImageUrlsQuery(ids))
      const bucket =
        yoloLabel && yoloLabel !== 'original' ? `radimal-model-images-cropped-${yoloLabel}` : 'radimal-model-images'
      const params = { keys: compact(images.map((i) => keyForAwsS3Url(i.aws_s3_url || undefined))), bucket }
      const presignedUrls = await callFlask(`/images/presigned`, 'POST', params)
      // Derive each id's key once, rather than re-searching `images` for every
      // presigned URL inside the inner `find`.
      const ordered = compact(
        ids.map((id) => {
          const key = keyForAwsS3Url(images.find((i) => i.id === id)?.aws_s3_url)
          return key ? presignedUrls.find((p: string) => p.includes(key)) : undefined
        })
      )
      dispatch(setBinaryTaggingPresignedUrls(ordered))
      dispatch(networkingSuccess())
    } catch (error) {
      dispatch(networkingFailure())
    }
  }
}
