import {AiApi, AiModel} from "./types";
import {AiContentType, ApiType} from "../";
import {ModelCategoryImageGeneration} from "./model-category";

/**
 * Provider descriptor for Stability AI.
 *
 * Declares the provider's identity and a static catalogue of models. The
 * catalogue currently holds a single entry, Stable Diffusion XL (`sdxl`),
 * together with the list of user-configurable parameters for that model.
 */
export class StabilityAI implements AiApi {
  /** Identifier of this provider. */
  id: ApiType = "stability-ai";
  // Presumably signals that input/output content types are declared on each
  // model entry (see `contentType` on the model below) — confirm in ./types.
  contentType: AiContentType = "per-model";
  /** Human-readable provider name. */
  name = "Stability AI";
  /** Short provider description. */
  description = "Open models in every modality, for everyone, everywhere.";
  /** Static model catalogue for this provider. */
  models: AiModel[] = [
    {
      id: "sdxl",
      displayName: "Stable Diffusion XL",
      description:
        "Stable Diffusion XL is a significant advancement in image generation capabilities, offering enhanced image composition and face generation that results in stunning visuals and realistic aesthetics.",
      categories: [ModelCategoryImageGeneration],
      // NOTE(review): unit/currency of the price is not visible here — confirm.
      prices: {image: 0.002},
      contentType: {
        input: ["text", "image"],
        output: ["image"],
      },
      parameters: [
        {
          key: "prompt",
          type: "text",
          mapTo: "prompt",
          displayName: "Prompt",
        },
        {
          key: "image",
          type: "image",
          displayName: "Image",
          mapTo: "image",
          description: "Input image for img2img or inpaint mode",
        },
        {
          key: "negative_prompt",
          displayName: "Negative prompt",
          type: "text",
          mapTo: "parameters",
          description: "Input Negative Prompt",
        },
        {
          key: "width",
          type: "range",
          displayName: "Width",
          defaultValue: 1024,
          limits: {
            step: 8,
            min: 128,
            max: 1600,
          },
        },
        {
          key: "height",
          type: "range",
          displayName: "Height",
          defaultValue: 1024,
          limits: {
            step: 8,
            min: 128,
            max: 1600,
          },
        },
        {
          key: "num_outputs",
          type: "range",
          displayName: "Number of generated images",
          defaultValue: 1,
          limits: {
            step: 1,
            min: 1,
            max: 4,
          },
        },
        {
          key: "scheduler",
          type: "select",
          // mapTo: "advanced",
          displayName: "Scheduler",
          defaultValue: "K_EULER",
          options: [
            {id: "DDIM"},
            {id: "DPMSolverMultistep"},
            {id: "HeunDiscrete"},
            {id: "KarrasDPM"},
            {id: "K_EULER_ANCESTRAL"},
            {id: "K_EULER"},
            {id: "PNDM"},
          ],
        },
        {
          key: "num_inference_steps",
          type: "range",
          // mapTo: "advanced",
          displayName: "Number of denoising steps",
          defaultValue: 50,
          limits: {
            step: 1,
            min: 1,
            max: 500,
          },
        },
        {
          key: "guidance_scale",
          type: "range",
          // mapTo: "advanced",
          displayName: "Guidance scale",
          description: "Scale for classifier-free guidance",
          defaultValue: 7.5,
          limits: {
            step: 0.1,
            min: 0,
            max: 50,
          },
        },
        {
          key: "prompt_strength",
          type: "range",
          // mapTo: "advanced",
          displayName: "Prompt strength",
          description:
            "Prompt strength when using img2img / inpaint. 1.0 corresponds to full destruction of information in image",
          defaultValue: 0.8,
          limits: {
            step: 0.1,
            min: 0,
            max: 1,
          },
        },
        // {
        //   key: "seed",
        //   type: "range",
        //   // mapTo: "advanced",
        //   displayName: "Seed",
        //   defaultValue: 0, // TODO: should be random by default
        //   limits: {
        //     step: 1,
        //     min: 0,
        //     max: 2147483647,
        //   },
        // },
        {
          key: "refine",
          type: "select",
          // mapTo: "advanced",
          displayName: "Refine",
          description: "Which refine style to use",
          defaultValue: "no_refiner",
          options: [
            {id: "no_refiner"},
            {id: "expert_ensemble_refiner"},
            {id: "base_image_refiner"},
          ],
        },
        {
          key: "high_noise_frac",
          type: "range",
          // mapTo: "advanced",
          displayName: "High noise fraction",
          description: "For expert_ensemble_refiner, the fraction of noise to use\n",
          // Default and bounds were missing although every other range
          // parameter declares both. Values follow the public SDXL input
          // schema (a fraction in [0, 1], default 0.8) — confirm against the
          // provider if the upstream schema changes. Step mirrors
          // `prompt_strength`.
          defaultValue: 0.8,
          limits: {
            step: 0.1,
            min: 0,
            max: 1,
          },
        },
        {
          key: "refine_steps",
          type: "number",
          // mapTo: "advanced",
          displayName: "Refine steps",
          description: "Number of refine steps to use",
          defaultValue: 10,
          limits: {
            step: 1,
            min: 1,
            max: 100,
          },
        },
        {
          key: "apply_watermark",
          type: "boolean",
          // mapTo: "advanced",
          displayName: "Apply watermark",
          description:
            "Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.",
          defaultValue: true,
        },
        {
          key: "disable_safety_checker",
          displayName: "Disable safety checker",
          type: "boolean",
          // mapTo: "advanced",
          description: "Disables the safety checker for generated images",
          defaultValue: false,
        },
      ],
    },
  ];
  /** Id of the model to use when none is specified; matches the only entry in `models`. */
  readonly defaultModelId: string = "sdxl";

  /**
   * Looks up a model in `models` by its id.
   *
   * @param modelId - Id to search for (e.g. `"sdxl"`).
   * @returns The first model whose `id` strictly equals `modelId`.
   *
   * NOTE: when no model matches, `Array.prototype.find` yields `undefined`,
   * and the `as AiModel` assertion hides that from the type checker. The
   * runtime behaviour is intentionally kept as-is for existing callers;
   * pass only ids that exist in `models`, or guard the result.
   */
  getModel(modelId: string): AiModel {
    return this.models.find((model) => model.id === modelId) as AiModel;
  }
}

/**
 * Union of the parameter key names recognised for the Stability AI SDXL
 * model, listed alphabetically.
 *
 * Note that `lora_scale`, `mask` and `seed` are part of this union but have
 * no active entry in the `StabilityAI` model's `parameters` list in this file
 * (the `seed` entry there is commented out).
 */
export type StabilityAIParameterKeys =
  | "apply_watermark"
  | "disable_safety_checker"
  | "guidance_scale"
  | "height"
  | "high_noise_frac"
  | "image"
  | "lora_scale"
  | "mask"
  | "negative_prompt"
  | "num_inference_steps"
  | "num_outputs"
  | "prompt"
  | "prompt_strength"
  | "refine"
  | "refine_steps"
  | "scheduler"
  | "seed"
  | "width";
