from PIL import Image
import pathlib
def extract_colors(image_path, colors_file, index, output_path):
    """Copy only the pixels whose colors appear in a palette into a new image.

    Every non-transparent pixel of *image_path* whose RGB value, written as
    ``FFRRGGBB`` hex, is listed in the selected palette entries is copied
    (with its original alpha) onto a fully transparent RGBA canvas of the same
    size. That canvas is saved to *output_path*. Pixel statistics are printed.

    Args:
        image_path: Path of the source image; it is converted to RGBA.
        colors_file: Text file of palette colors. Each line may hold one or
            more comma-separated colors, compared case-insensitively. Only
            entries of the form ``FFRRGGBB`` can ever match a pixel.
        index: 0-based line of *colors_file* to use, or ``-1`` to use every
            color from every line.
        output_path: Where the extracted image is saved. It may be the same
            path as *image_path*: the source is fully decoded and its file
            closed before saving.

    Returns:
        ``(num_kept, num_discarded, num_transparent)`` pixel counts, or
        ``None`` when *index* is out of range (an error is printed instead).
    """
    # convert() forces a full decode, so the source file can be closed right
    # away; this is what makes overwriting the input (output_path == image_path) safe.
    with Image.open(image_path) as source:
        img = source.convert("RGBA")  # ensure the image has an alpha channel
    width, height = img.size

    # "utf-8-sig" also strips a leading BOM, which would otherwise stick to
    # the first color and stop it from ever matching.
    with open(colors_file, "r", encoding="utf-8-sig") as file:
        lines = file.readlines()

    # Valid indices are -1 ("all lines") and 0 .. len(lines) - 1.
    if index < -1 or index >= len(lines):
        print(f"Error: Index {index} is out of range. The file contains {len(lines)} entries.")
        return None

    # Lines may hold several comma-separated colors in both modes, so split
    # every selected line; empty entries (blank lines, trailing commas) are
    # dropped. A set gives O(1) lookups in the per-pixel loop below.
    selected_lines = lines if index == -1 else [lines[index]]
    target_colors = {
        color.strip().upper()
        for line in selected_lines
        for color in line.split(",")
        if color.strip()
    }

    # Blank canvas: fully transparent everywhere until a pixel is kept.
    new_img = Image.new("RGBA", (width, height), (0, 0, 0, 0))
    src_pixels = img.load()
    dst_pixels = new_img.load()
    num_kept = 0
    num_discarded = 0
    num_transparent = 0
    found_partial_transparent = False
    for y in range(height):
        for x in range(width):
            r, g, b, a = src_pixels[x, y]
            if a == 0:
                num_transparent += 1
                continue
            if a < 255 and not found_partial_transparent:
                print("WARNING: Found partial transparency in the image. This may affect color matching.")
                found_partial_transparent = True
            # Matching ignores the pixel's real alpha: it is compared as if
            # fully opaque ("FF" prefix), as the warning above points out.
            if f"FF{r:02X}{g:02X}{b:02X}" in target_colors:
                dst_pixels[x, y] = (r, g, b, a)  # keep the pixel, alpha included
                num_kept += 1
            else:
                num_discarded += 1

    print(f"Total pixels kept: {num_kept}")
    print(f"Total pixels discarded: {num_discarded}")
    print(f"Total transparent pixels: {num_transparent}")
    new_img.save(output_path)
    print(f"Extracted image saved to {output_path}")
    return num_kept, num_discarded, num_transparent


import argparse


def generate_paid_and_free(filename, colors_file_free="palette.txt", colors_file_paid="palette_only_paid.txt"):
    """Split an image into a "free colors" and a "paid colors" version.

    Calls ``extract_colors`` twice with index ``-1`` (every palette entry):
    once with the free palette and once with the paid-only palette. The
    results are written next to the input as ``<stem>_only_free.png`` and
    ``<stem>_only_paid.png``.

    The pixel counts of the two runs are then cross-checked. If the two
    palettes are disjoint and together cover every opaque color in the image,
    the pixels kept in one version are exactly those discarded in the other.
    A mismatch is reported as a warning.

    Args:
        filename: Path of the source image.
        colors_file_free: Palette file for the free version.
        colors_file_paid: Palette file for the paid version.
    """
    fpath = pathlib.Path(filename)
    # with_name() keeps the input's directory; a bare stem would drop it and
    # silently write the outputs into the current working directory instead.
    output_path_free = fpath.with_name(fpath.stem + "_only_free.png")
    output_path_paid = fpath.with_name(fpath.stem + "_only_paid.png")

    print("Free version:")
    kept, discard, transparent = extract_colors(filename, colors_file_free, -1, output_path_free)
    print("Paid version:")
    kept2, discard2, transparent2 = extract_colors(filename, colors_file_paid, -1, output_path_paid)

    if kept != discard2 or kept2 != discard:
        print("WARNING: The number of kept pixels does not match the number of discarded pixels in the free and paid images, there are likely pixels outside the palette.")
    if transparent != transparent2:
        print("WARNING: The number of transparent pixels does not match in the free and paid images.")
    print(f"Total number of non-transparent pixels in the original image: {kept + discard}")
    print(f"Total number of pixels in the original image: {kept + discard + transparent}")

def clean_qr_code(qr_path):
    """Keep only the colors listed on line 0 of ``palette.txt``, in place.

    The filtered image is written back over *qr_path* itself.
    """
    # Same path for input and output: the original file is overwritten.
    extract_colors(qr_path, "palette.txt", 0, qr_path)

# generate_paid_and_free()
def normal_mode(image_path, colors_file, output_path, index):
    """Run ``extract_colors`` with an explicit palette line (``-1`` = all lines).

    The statistics returned by ``extract_colors`` are discarded.
    """
    # Keyword arguments make the reordering explicit: this function takes
    # output_path before index, while extract_colors takes index first.
    extract_colors(
        image_path=image_path,
        colors_file=colors_file,
        index=index,
        output_path=output_path,
    )

# Set up argparse. The parser is built at module level so it stays importable;
# parsing and dispatch happen only in main().
parser = argparse.ArgumentParser(description="Extract colors from an image based on a color palette.")
subparsers = parser.add_subparsers(dest="command", help="Available commands")

# Subcommand: generate_paid_and_free
parser_generate = subparsers.add_parser("generate", help="Generate paid and free versions of the image.")
parser_generate.add_argument("filename", type=str, help="Path to the input image.")
parser_generate.set_defaults(func=lambda args: generate_paid_and_free(args.filename))

# Subcommand: clean_qr_code
parser_clean = subparsers.add_parser("clean", help="Clean QR code using the palette.")
parser_clean.add_argument("filename", type=str, help="Path to the QR code image.")
parser_clean.set_defaults(func=lambda args: clean_qr_code(args.filename))

# Subcommand: normal mode with index
parser_normal = subparsers.add_parser("normal", help="Extract colors using a specific palette index.")
parser_normal.add_argument("image_path", type=str, help="Path to the input image.")
parser_normal.add_argument("colors_file", type=str, help="Path to the colors file.")
parser_normal.add_argument("output_path", type=str, help="Path to save the output image.")
parser_normal.add_argument("index", type=int, help="Index of the color palette to use from the colors file. -1 for all colors.")
parser_normal.set_defaults(func=lambda args: normal_mode(args.image_path, args.colors_file, args.output_path, args.index))


def main(argv=None):
    """Parse command-line arguments and run the selected subcommand.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    With no subcommand, the top-level help text is printed instead.
    """
    args = parser.parse_args(argv)
    if args.command:
        args.func(args)
    else:
        parser.print_help()


# Only parse the command line when run as a script, not when imported.
if __name__ == "__main__":
    main()