import axios, { AxiosInstance } from "axios";
import { pick } from "lodash";
import { config } from "../../config";
import Constants from "../../constants";
import { RequestType } from "../../types/request-type";
import { uuid } from "../../utils/crypto-utils";
import {
  ApiOptions,
  CommonInferenceResult,
  ExperimentOptions,
  AllOptions,
  Params,
  RhymesInferenceResult,
  ThesaurusInferenceResult,
  IDiversity,
} from "./types";

/**
 * Axios client for the regular (non-demo) inference API host.
 * `withCredentials` makes cross-site requests include credentials (cookies).
 */
export const inferenceApiClient: AxiosInstance = axios.create({
  withCredentials: true,
  baseURL: config.inferenceApiHost,
});

/**
 * Axios client for the demo inference API host.
 * Unlike `inferenceApiClient`, it does not enable `withCredentials`.
 */
export const demoInferenceApiClient: AxiosInstance = axios.create({
  baseURL: config.demoInferenceApiHost,
});

/**
 * Broadcasts a `Constants.INFERENCE_API_REQUEST_EVENT` CustomEvent on the
 * global event target (presumably `window` in the browser — confirm), with
 * the route label in `event.detail.route`.
 *
 * @param route - label identifying which inference call was made.
 */
function dispatchRequestEvent(route: string) {
  const requestEvent = new CustomEvent(Constants.INFERENCE_API_REQUEST_EVENT, {
    detail: { route },
  });
  dispatchEvent(requestEvent);
}

/**
 * Builds the request payload for an inference endpoint.
 *
 * Every supported field is computed from the arguments, `config` and
 * `Constants` defaults. Only the keys named in `paramsList` are then kept,
 * using lodash `pick`. Keys in `paramsList` that are not built here
 * (e.g. `model`, `random`) do not appear in the result.
 *
 * @param word - target word; sent as both `word` and `rhyme-word`.
 * @param context - lyrics so far; sent as `lyrics`.
 * @param options - request options; when `null`, fields with a
 *   `Constants` default fall back to it and the rest stay `undefined`.
 * @param paramsList - names of the payload keys the target endpoint accepts.
 * @returns the filtered payload.
 */
function getParams(
  word: string | null,
  context: string | null,
  options: AllOptions | null,
  paramsList: string[]
): Params {
  const targetWord = word ?? undefined;

  const candidates = {
    lyrics: context ?? undefined,
    // `||`: any falsy genre (e.g. empty string) is replaced by the default.
    genre: options?.genre || Constants.DEFAULT_GENRE,
    n_verses: config.nSuggestionsPerFetch,
    // `??`: an explicit `false` is preserved; only null/undefined defaults.
    remove_profanity:
      options?.filterProfanity ?? Constants.DEFAULT_FILTER_PROFANITY,
    // `||`: any falsy diversity value is replaced by the default.
    diversity: options?.diversity || Constants.DEFAULT_DIVERSITY,
    word: targetWord,
    "rhyme-word": targetWord,
    group: options?.group,
    beta: options?.beta,
    // Always sent as null from here.
    part_of_speech: null,
    syllables: options?.metric ?? undefined,
    mood: options?.mood || undefined,
    // A falsy configured sample size is sent as undefined.
    sample_size: config.rhymesSample || undefined,
  };

  // lodash `pick` accepts the key list either spread or as one array.
  return pick(candidates, paramsList);
}

/**
 * Chooses the axios client for a request.
 *
 * @param options - when `options.demo` is truthy the demo client is used;
 *   otherwise the regular inference client is returned.
 */
function getClient(options: ApiOptions): AxiosInstance {
  if (options.demo) {
    return demoInferenceApiClient;
  }
  return inferenceApiClient;
}

/**
 * Requests full-line suggestions by POSTing to the `line` endpoint.
 *
 * @param context - lyrics written so far, sent as `lyrics`.
 * @param options - request options; `options.demo` selects the demo client.
 * @returns the response body and headers, together with the payload sent,
 *   tagged with a fresh request id and `RequestType.AUTO_LINES`.
 */
async function fetchLines(
  context: string,
  options: AllOptions
): Promise<CommonInferenceResult> {
  const requestedKeys = [
    "lyrics",
    "genre",
    "n_verses",
    "remove_profanity",
    "diversity",
    "group",
    "beta",
    "syllables",
    "mood",
  ];
  const params = getParams(null, context, options, requestedKeys);

  const client = getClient(options);
  const response = await client.post("line", params);

  // Only reached once the POST resolves; a rejected request emits no event.
  dispatchRequestEvent("auto_lines");

  return {
    requestId: uuid(),
    requestType: RequestType.AUTO_LINES,
    params,
    data: response.data,
    headers: response.headers,
  };
}

/**
 * Requests next-word suggestions by POSTing to the `word` endpoint.
 *
 * @param context - lyrics written so far, sent as `lyrics`.
 * @param options - request options; `options.demo` selects the demo client.
 * @returns the response body and headers, together with the payload sent,
 *   tagged with a fresh request id and `RequestType.AUTO_WORDS`.
 */
async function fetchWords(
  context: string,
  options: AllOptions
): Promise<CommonInferenceResult> {
  const requestedKeys = [
    "lyrics",
    "genre",
    "n_verses",
    "remove_profanity",
    "diversity",
    // NOTE(review): getParams builds no `model` value, so this key is
    // currently never sent — confirm whether it should be.
    "model",
    "group",
    "beta",
  ];
  const params = getParams(null, context, options, requestedKeys);

  const client = getClient(options);
  const response = await client.post("word", params);

  // Only reached once the POST resolves; a rejected request emits no event.
  dispatchRequestEvent("auto_words");

  return {
    requestId: uuid(),
    requestType: RequestType.AUTO_WORDS,
    params,
    data: response.data,
    headers: response.headers,
  };
}

/**
 * Requests rhyming-line suggestions by POSTing to the `rhyme` endpoint.
 *
 * @param word - word to rhyme with, sent as `rhyme-word`.
 * @param context - lyrics written so far, sent as `lyrics`.
 * @param options - request options; `options.demo` selects the demo client.
 * @returns the response body and headers, together with the payload sent,
 *   tagged with a fresh request id and `RequestType.RHYMES`.
 */
async function fetchRhymes(
  word: string,
  context: string,
  options: AllOptions
): Promise<RhymesInferenceResult> {
  const requestedKeys = [
    "lyrics",
    "genre",
    "n_verses",
    "remove_profanity",
    "diversity",
    "rhyme-word",
    // NOTE(review): getParams builds neither `model` nor `random`, so these
    // two keys are currently never sent — confirm whether they should be.
    "model",
    "group",
    "beta",
    "syllables",
    "mood",
    "random",
    "sample_size",
  ];
  const params = getParams(word, context, options, requestedKeys);

  const client = getClient(options);
  const response = await client.post("rhyme", params);

  // Only reached once the POST resolves; a rejected request emits no event.
  dispatchRequestEvent("rhymes");

  return {
    requestId: uuid(),
    requestType: RequestType.RHYMES,
    params,
    data: response.data,
    headers: response.headers,
  };
}

/**
 * Looks up rhyming words by POSTing to the `query_rhyme_db` endpoint.
 *
 * @param word - word to find rhymes for, sent as `word`.
 * @param context - lyrics written so far, sent as `lyrics`.
 * @param options - request options; `options.demo` selects the demo client.
 * @returns the response body and headers, together with the payload sent,
 *   tagged with a fresh request id and `RequestType.RHYME_WORDS`.
 */
async function fetchRhymeWords(
  word: string,
  context: string,
  options: AllOptions
): Promise<CommonInferenceResult<string>> {
  const requestedKeys = [
    "lyrics",
    "word",
    "remove_profanity",
    // NOTE(review): getParams builds no `model` value, so this key is
    // currently never sent — confirm whether it should be.
    "model",
    "group",
    "beta",
  ];
  const params = getParams(word, context, options, requestedKeys);

  const client = getClient(options);
  const response = await client.post("query_rhyme_db", params);

  // Only reached once the POST resolves; a rejected request emits no event.
  dispatchRequestEvent("rhyme_words");

  return {
    requestId: uuid(),
    requestType: RequestType.RHYME_WORDS,
    params,
    data: response.data,
    headers: response.headers,
  };
}

/**
 * Looks up thesaurus entries by POSTing to the `query_thesaurus_db` endpoint.
 *
 * @param word - word to look up, sent as `word`.
 * @param context - lyrics written so far, sent as `lyrics`.
 * @param options - experiment/API options; `options.demo` selects the demo
 *   client. `part_of_speech` is always sent as null by getParams.
 * @returns the response body and headers, together with the payload sent,
 *   tagged with a fresh request id and `RequestType.THESAURUS`.
 */
async function fetchThesaurus(
  word: string,
  context: string,
  options: ExperimentOptions & ApiOptions
): Promise<ThesaurusInferenceResult> {
  const requestedKeys = ["lyrics", "word", "part_of_speech", "group", "beta"];
  const params = getParams(word, context, options, requestedKeys);

  const client = getClient(options);
  const response = await client.post("query_thesaurus_db", params);

  // Only reached once the POST resolves; a rejected request emits no event.
  dispatchRequestEvent("thesaurus");

  return {
    requestId: uuid(),
    requestType: RequestType.THESAURUS,
    params,
    data: response.data,
    headers: response.headers,
  };
}

/**
 * Fetches the list of available genres with a GET to `get_genres`.
 * Unlike the POST helpers, this does not dispatch a request event.
 *
 * @param options - `options.demo` selects the demo client.
 * @returns the response body and headers, with empty `params`, a fresh
 *   request id and `RequestType.GENRES`.
 */
async function fetchGenres(
  options: ApiOptions
): Promise<CommonInferenceResult<string>> {
  const client = getClient(options);
  const response = await client.get("get_genres");

  return {
    requestId: uuid(),
    requestType: RequestType.GENRES,
    params: {},
    data: response.data,
    headers: response.headers,
  };
}

/**
 * Fetches the available diversity settings with a GET to `get_diversities`.
 * Unlike the POST helpers, this does not dispatch a request event.
 *
 * @param options - `options.demo` selects the demo client.
 * @returns the response body and headers, with empty `params` and a fresh
 *   request id.
 */
async function fetchDiversity(
  options: ApiOptions
): Promise<CommonInferenceResult<IDiversity>> {
  const client = getClient(options);
  const response = await client.get("get_diversities");

  return {
    requestId: uuid(),
    // NOTE(review): tagged as GENRES, identical to fetchGenres — looks like a
    // copy-paste slip; confirm whether a dedicated request type was intended.
    requestType: RequestType.GENRES,
    params: {},
    data: response.data,
    headers: response.headers,
  };
}

/**
 * Pings the root of the selected inference API with a GET to "/".
 * Resolves with no value on success. A failed request rejects, because the
 * awaited axios call's rejection propagates.
 *
 * @param options - `options.demo` selects the demo client.
 */
async function checkHealth(options: ApiOptions): Promise<void> {
  const client = getClient(options);
  await client.get("/");
}

/**
 * Facade bundling this module's inference API calls. Every method selects
 * the demo or regular axios client from the `demo` flag of its options.
 */
const InferenceApi = {
  fetchLines,
  fetchWords,
  fetchRhymes,
  fetchRhymeWords,
  fetchThesaurus,
  fetchGenres,
  checkHealth,
  fetchDiversity,
};

export { InferenceApi as default };
