from PIL import Image
import csv

from math import ceil
def rjust_f(s, width):
    """Return *s* right-aligned in a field of *width* characters, padded with spaces."""
    padded = s.rjust(width, " ")
    return padded
def ljust_f(s, width):
    """Return *s* left-aligned in a field of *width* characters, padded with spaces."""
    padded = s.ljust(width, " ")
    return padded
# def just_f_factory(just,width):
#     return lambda s: just(s,width)
def repeat_to_length(string, target_length):
    """Repeat *string* cyclically until it is exactly *target_length* characters long.

    Returns ``""`` when *target_length* is not positive, or when *string* is
    empty (an empty pattern can never fill the requested length; previously
    this raised ``ZeroDivisionError``).

    >>> repeat_to_length("ab", 5)
    'ababa'
    """
    if target_length <= 0 or not string:
        return ""
    # Whole copies plus a partial prefix: exact integer arithmetic, no float ceil().
    full, remainder = divmod(target_length, len(string))
    return string * full + string[:remainder]

def spreadsheet_format(L, sep="|", minjust=0, justf=rjust_f):
    """Render a mix of table rows and separator patterns as aligned plain text.

    Args:
        L: Sequence of entries. A ``list`` entry is a table row: each cell is
            converted with ``str()`` (``None`` becomes an empty cell) and short
            rows are padded with empty cells. Any other (string) entry is a
            separator pattern, repeated/truncated to the full table width
            (e.g. ``"-"`` yields a horizontal rule). An empty or ``None``
            separator entry produces a blank line.
        sep: String placed between adjacent columns.
        minjust: Minimum width of every column.
        justf: ``(text, width) -> str`` callable used to pad each cell.

    Returns:
        The rendered lines joined with ``"\\n"``, in the same order as *L*;
        ``""`` if *L* is empty.
    """
    if not L:
        return ""
    # Stringify the cells of every table row (list entries only).
    rows = [["" if cell is None else str(cell) for cell in entry]
            for entry in L if isinstance(entry, list)]
    # default=0: L may contain only separator entries and no table rows.
    ncols = max((len(row) for row in rows), default=0)
    rows = [row + [""] * (ncols - len(row)) for row in rows]
    widths = [max([minjust] + [len(row[col]) for row in rows]) for col in range(ncols)]
    # Rendered table rows, consumed in order as list entries are met in L.
    rendered = iter([sep.join(justf(cell, width) for cell, width in zip(row, widths))
                     for row in rows])
    # Separator entries span all columns plus the separators between them.
    rowlen = sum(widths) + len(sep) * max(ncols - 1, 0)
    lines = []
    for entry in L:
        if isinstance(entry, list):
            lines.append(next(rendered))
        elif entry:
            lines.append(repeat_to_length(entry, rowlen))
        else:
            lines.append("")
    return "\n".join(lines)


def _read_color_info(color_info_path):
    """Load palette entries from a CSV file of ``hex,name,is_paid`` rows.

    Rows that do not have exactly three fields are skipped. ``hex`` is stripped
    and upper-cased, ``name`` is stripped, and ``is_paid`` is parsed with
    ``int()`` (a non-numeric value raises ``ValueError``).

    Returns:
        A list of dicts with keys ``hex_color``, ``name``, ``is_paid`` and a
        ``count`` initialised to 0, in file order.
    """
    color_info = []
    # newline="" is what the csv module expects; utf-8-sig also strips a
    # leading byte-order mark, which would otherwise corrupt the first hex code.
    with open(color_info_path, "r", encoding="utf-8-sig", newline="") as file:
        for row in csv.reader(file):
            if len(row) != 3:
                continue
            hex_color, name, is_paid = row
            color_info.append({
                "hex_color": hex_color.strip().upper(),
                "name": name.strip(),
                "is_paid": int(is_paid.strip()),
                "count": 0,
            })
    return color_info


def _tally_pixels(img, color_dict):
    """Classify every pixel of an RGBA image and update matching palette counts.

    Fully transparent pixels (alpha 0) are counted separately. Every other
    pixel is keyed as ``"FF" + RRGGBB`` (its own alpha is ignored). When that
    key is in *color_dict*, the entry's ``count`` is incremented in place.

    Returns:
        ``(num_matched, num_unmatched, num_transparent, num_partial_transparent)``.
        Partially transparent pixels are also included in matched/unmatched.
    """
    width, height = img.size
    num_matched = num_unmatched = num_transparent = num_partial_transparent = 0
    # getcolors() aggregates identical pixels in C instead of one getpixel()
    # call per pixel. width*height bounds the number of distinct colors, so
    # with that limit it does not return None.
    for count, (r, g, b, a) in img.getcolors(maxcolors=max(width * height, 1)) or []:
        if a == 0:
            num_transparent += count
            continue
        if a < 255:
            num_partial_transparent += count
        entry = color_dict.get(f"FF{r:02X}{g:02X}{b:02X}")
        if entry is None:
            num_unmatched += count
        else:
            entry["count"] += count
            num_matched += count
    return num_matched, num_unmatched, num_transparent, num_partial_transparent


def _color_table(colors):
    """Format *colors* as a left-justified count/name/hex table (first 2 hex chars dropped)."""
    header = [["count", "name", "hex_color"]]
    body = [[color["count"], color["name"], color["hex_color"][2:]] for color in colors]
    return spreadsheet_format(header + body, justf=ljust_f)


def _percent(part, whole):
    """Return *part* as a percentage of *whole*, or 0 when *whole* is not positive."""
    return part / whole * 100 if whole > 0 else 0


def analyze_color_usage(image_path, color_info_path):
    """Print how many pixels of an image use each color from a palette file.

    The palette file is CSV with ``hex,name,is_paid`` rows. Pixels are
    matched against keys of the form ``"FF" + RRGGBB`` (upper-case), so
    palette codes must use that 8-digit form to match. Entries with
    ``is_paid == 0`` are reported as free and ``== 1`` as paid. Other values
    appear in neither table.

    Any exception raised while loading or analysing (missing file,
    unreadable image, malformed ``is_paid`` ...) is caught and printed as
    ``Error: <message>``. Nothing is returned.
    """
    try:
        # Close the image file promptly; convert() loads the pixel data, so
        # the RGBA copy stays usable after the file is closed.
        with Image.open(image_path) as source:
            img = source.convert("RGBA")  # Ensure image has an alpha channel
        width, height = img.size

        color_info = _read_color_info(color_info_path)
        # Lookup by hex code; if a code appears twice, the later row wins.
        color_dict = {color["hex_color"]: color for color in color_info}
        (num_matched, num_unmatched,
         num_transparent, num_partial_transparent) = _tally_pixels(img, color_dict)

        free_colors = [color for color in color_info if color["is_paid"] == 0]
        paid_colors = [color for color in color_info if color["is_paid"] == 1]

        print("Free Colors:")
        print(_color_table(free_colors))
        s = sum(color["count"] for color in free_colors)
        print(f"Total: {s} pixels ({_percent(s, num_matched):.2f}%)")

        print("\nPaid Colors:")
        print(_color_table(paid_colors))
        s2 = sum(color["count"] for color in paid_colors)
        print(f"Total: {s2} pixels matched ({_percent(s2, num_matched):.2f}%)")
        used_paid = sum(1 for color in paid_colors if color["count"] > 0)
        print(f"Total unique paid colors used: {used_paid}")

        print("\nSummary:")
        print(f"Total pixels matched: {num_matched}")
        print(f"Total pixels unmatched: {num_unmatched}")
        print(f"Total transparent pixels: {num_transparent}")
        # Internal consistency check: every pixel falls in exactly one category.
        assert width*height == num_matched + num_unmatched + num_transparent, f"Pixel count mismatch! Expected {width*height}, got {num_matched + num_unmatched + num_transparent}."
        print(f"Total pixels in image: {width * height}")
        if num_partial_transparent > 0:
            print(f"WARNING: Found {num_partial_transparent} pixels with partial transparency. This may affect color matching.")

    except Exception as e:
        # Top-level boundary of the script: report the failure instead of a traceback.
        print(f"Error: {e}")

# Example usage
# image_path = "shijima.png"  # Replace with your image file path
import argparse
parser = argparse.ArgumentParser(description="Analyze color usage in an image based on a color information file.")
parser.add_argument("image_path", type=str, help="Path to the image file.")

args = parser.parse_args()
image_path = args.image_path  # Get the image path from command line arguments

color_info_path = "color_info.txt"  # Replace with your color info file path
analyze_color_usage(image_path, color_info_path)