#!/usr/bin/env python3
import argparse
import csv
import sys


# Characters that this script treats as alignment gaps rather than residues.
GAP_SYMBOLS = {"-", "."}


def read_fasta_alignment(filename):
    """Read a FASTA alignment file into a ``{name: sequence}`` dictionary.

    Every line is stripped of surrounding whitespace and blank lines are
    skipped.  A line starting with ``>`` opens a new record whose name is the
    first whitespace-delimited token after the ``>``; all following
    non-header lines are concatenated to form that record's sequence.

    Args:
        filename: Path of the FASTA file to read (UTF-8, with or without a
            byte order mark).

    Returns:
        A dict mapping each sequence name to its (possibly gapped) sequence,
        in the order the records appear in the file.

    Raises:
        ValueError: If a header carries no name, a name occurs twice,
            sequence data precedes the first header, or the file contains
            no records at all.
    """
    chunks_by_name = {}
    current_chunks = None

    # "utf-8-sig" behaves like "utf-8" but silently drops a leading BOM, which
    # would otherwise hide the ">" of the first header.
    with open(filename, "r", encoding="utf-8-sig") as file:
        for line in file:
            line = line.strip()

            if not line:
                continue

            if line.startswith(">"):
                fields = line[1:].split()
                if not fields:
                    # A bare ">" used to surface as an opaque IndexError.
                    raise ValueError(f"Empty sequence name in {filename}")

                name = fields[0]
                if name in chunks_by_name:
                    raise ValueError(f"Duplicate sequence name in {filename}: {name}")

                current_chunks = chunks_by_name[name] = []
            elif current_chunks is None:
                raise ValueError(f"Sequence without header in {filename}")
            else:
                current_chunks.append(line)

    if not chunks_by_name:
        raise ValueError(f"No sequences found in {filename}")

    return {name: "".join(chunks) for name, chunks in chunks_by_name.items()}


def check_alignment(alignment, filename):
    """Verify that all sequences have the same length and return that length.

    Args:
        alignment: Mapping of sequence name to aligned sequence.
        filename: File name used only to build error messages.

    Returns:
        The common length of the sequences (number of alignment columns).

    Raises:
        ValueError: If the mapping is empty or the sequences differ in length.
    """
    if not alignment:
        # Without this guard an empty mapping would be reported as
        # "different lengths", which points the user at the wrong problem.
        raise ValueError(f"No sequences found in {filename}")

    lengths = {len(seq) for seq in alignment.values()}

    if len(lengths) != 1:
        raise ValueError(f"Sequences have different lengths in {filename}")

    return lengths.pop()


def remove_gaps(sequence):
    """Return *sequence* with every character listed in GAP_SYMBOLS dropped."""

    def is_residue(symbol):
        return symbol not in GAP_SYMBOLS

    return "".join(filter(is_residue, sequence))


def check_same_sequences(aln1, aln2):
    """Ensure both alignments contain the same sequences under the same names.

    Two alignments are comparable when their name sets are equal and, for
    every name, the sequences are identical once gap characters are removed.

    Args:
        aln1: First alignment as a mapping of name to aligned sequence.
        aln2: Second alignment as a mapping of name to aligned sequence.

    Returns:
        The shared sequence names as a sorted list.

    Raises:
        ValueError: If the name sets differ, or if any sequence differs
            between the two alignments after gap removal.
    """
    names1 = set(aln1)
    names2 = set(aln2)

    if names1 != names2:
        only_1 = sorted(names1 - names2)
        only_2 = sorted(names2 - names1)
        raise ValueError(
            "Different sequence names in input files. "
            f"Only in file 1: {only_1}; only in file 2: {only_2}"
        )

    names = sorted(names1)

    # Walk the names in sorted order rather than raw set order, so the same
    # pair of inputs always reports the same offending sequence.
    for name in names:
        if remove_gaps(aln1[name]) != remove_gaps(aln2[name]):
            raise ValueError(
                f"Ungapped sequences are different for {name}. "
                "The program can compare only alignments of the same sequences."
            )

    return names


def get_column_signatures(alignment, names):
    """Describe every alignment column by the residues it places together.

    For each column a tuple is built with one entry per name in *names*
    (in that order): the string "-" when the sequence has a gap symbol in
    the column, otherwise the 1-based running count of non-gap characters
    seen so far in that sequence.

    Args:
        alignment: Mapping of sequence name to aligned sequence.
        names: Sequence names, in the order used inside each signature.

    Returns:
        A list of signature tuples, one per column; the column count is
        taken from the first sequence stored in *alignment*.
    """
    width = len(next(iter(alignment.values())))
    seen_residues = dict.fromkeys(names, 0)
    result = []

    for column in range(width):
        entries = []

        for name in names:
            if alignment[name][column] in GAP_SYMBOLS:
                entries.append("-")
                continue

            seen_residues[name] += 1
            entries.append(seen_residues[name])

        result.append(tuple(entries))

    return result


def find_equal_columns(signatures1, signatures2):
    """Pair up columns of two alignments that have identical signatures.

    Columns whose signature consists only of "-" entries are ignored on
    both sides.

    Args:
        signatures1: Column signatures of the first alignment.
        signatures2: Column signatures of the second alignment.

    Returns:
        A list of ``(i, j)`` pairs of 1-based column numbers, where column
        ``i`` of the first alignment equals column ``j`` of the second.
        Pairs are ordered by ``i`` and, for the same ``i``, by ``j``.
    """

    def is_all_gaps(signature):
        return all(value == "-" for value in signature)

    # Index the second alignment once so each lookup is a dict access.
    columns_by_signature = {}
    for column2, signature in enumerate(signatures2, start=1):
        if not is_all_gaps(signature):
            columns_by_signature.setdefault(signature, []).append(column2)

    return [
        (column1, column2)
        for column1, signature in enumerate(signatures1, start=1)
        if not is_all_gaps(signature)
        for column2 in columns_by_signature.get(signature, ())
    ]


def find_blocks(matches, min_block_length):
    """Group matched column pairs into runs of consecutive columns.

    The pairs are sorted, then split into runs in which both column numbers
    grow by exactly one from each pair to the next.

    Args:
        matches: Iterable of ``(i, j)`` column-number pairs.
        min_block_length: Minimal number of pairs for a run to be a block.

    Returns:
        A tuple ``(blocks, single_columns)``: *blocks* is a list of runs
        (each a list of pairs) with at least *min_block_length* pairs, and
        *single_columns* is a flat list of the pairs from all shorter runs.
    """
    blocks = []
    single_columns = []

    def close_run(run):
        # Short runs are dissolved into the flat list of leftover pairs.
        if len(run) < min_block_length:
            single_columns.extend(run)
        else:
            blocks.append(run)

    run = []

    for pair in sorted(matches):
        if run:
            last = run[-1]
            continues_run = pair[0] == last[0] + 1 and pair[1] == last[1] + 1
            if not continues_run:
                close_run(run)
                run = []

        run.append(pair)

    if run:
        close_run(run)

    return blocks, single_columns


def format_pairs(pairs):
    """Render column pairs as ``(i,j); (i,j); ...``, or ``-`` when there are none."""
    return "; ".join(f"({left},{right})" for left, right in pairs) if pairs else "-"


def format_blocks(blocks):
    """Render blocks as ``(start1,start2)-(end1,end2), length=N`` joined by ``; ``.

    Each block is a list of column pairs; its first and last pair give the
    start and end coordinates.  Returns ``-`` when *blocks* is empty.
    """
    if blocks:
        descriptions = []

        for columns in blocks:
            (first1, first2), (last1, last2) = columns[0], columns[-1]
            descriptions.append(
                f"({first1},{first2})-({last1},{last2}), length={len(columns)}"
            )

        return "; ".join(descriptions)

    return "-"


def write_matches(output_file, matches):
    """Write matched column pairs to *output_file* as a tab-separated table.

    The file is written as UTF-8 and starts with a header row naming the
    two columns; each following row holds one ``(i, j)`` pair.
    """
    header = ["column_in_alignment_1", "column_in_alignment_2"]

    with open(output_file, "w", newline="", encoding="utf-8") as handle:
        tsv = csv.writer(handle, delimiter="\t")
        tsv.writerow(header)
        tsv.writerows([first, second] for first, second in matches)


def compare_alignments(file1, file2, output_file, min_block_length):
    """Compare two alignments of the same sequences and print a TSV summary.

    Both files are read and validated, columns that align exactly the same
    residues are paired up, and the pairs are grouped into blocks of
    consecutive columns.  A two-column ``parameter``/``value`` table is
    written to standard output.

    Args:
        file1: Path of the first FASTA alignment.
        file2: Path of the second FASTA alignment.
        output_file: Path of a TSV file that receives every matched column
            pair; nothing is written when this is empty or ``None``.
        min_block_length: Minimal number of consecutive equal columns that
            counts as a block; must be at least 1.

    Raises:
        ValueError: If *min_block_length* is below 1, or if the input
            alignments are invalid or not comparable.
    """
    if min_block_length < 1:
        raise ValueError("Minimal block length must be at least 1")

    aln1 = read_fasta_alignment(file1)
    aln2 = read_fasta_alignment(file2)

    length1 = check_alignment(aln1, file1)
    length2 = check_alignment(aln2, file2)

    names = check_same_sequences(aln1, aln2)

    signatures1 = get_column_signatures(aln1, names)
    signatures2 = get_column_signatures(aln2, names)

    matches = find_equal_columns(signatures1, signatures2)
    blocks, single_columns = find_blocks(matches, min_block_length)

    # An alignment whose records carry no sequence data has zero columns;
    # report 0.0 instead of failing with ZeroDivisionError.
    percent1 = round(len(matches) / length1 * 100, 2) if length1 else 0.0
    percent2 = round(len(matches) / length2 * 100, 2) if length2 else 0.0

    summary_rows = [
        ["parameter", "value"],
        ["file_1", file1],
        ["file_2", file2],
        ["number_of_sequences", len(names)],
        ["alignment_1_length", length1],
        ["alignment_2_length", length2],
        ["equal_columns_number", len(matches)],
        ["equal_columns_percent_from_alignment_1", percent1],
        ["equal_columns_percent_from_alignment_2", percent2],
        ["minimal_block_length", min_block_length],
        ["equal_blocks", format_blocks(blocks)],
        ["equal_columns_not_in_blocks", format_pairs(single_columns)],
    ]

    writer = csv.writer(sys.stdout, delimiter="\t")
    writer.writerows(summary_rows)

    # The pair table is optional and written only after the summary.
    if output_file:
        write_matches(output_file, matches)


def main():
    """Parse command-line arguments and run the alignment comparison.

    Any exception raised by the comparison is printed to standard error as
    ``Error: <message>`` and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Compare two multiple sequence alignments in FASTA format"
    )

    positionals = (
        ("file1", "First alignment in FASTA format"),
        ("file2", "Second alignment in FASTA format"),
    )
    for dest, help_text in positionals:
        parser.add_argument(dest, help=help_text)

    parser.add_argument(
        "-o",
        "--output",
        help="Output TSV file with pairs of equally aligned columns",
    )
    parser.add_argument(
        "--min-block-length",
        type=int,
        default=2,
        help="Minimal length of a block of consecutive equal columns",
    )

    options = parser.parse_args()

    try:
        compare_alignments(
            file1=options.file1,
            file2=options.file2,
            output_file=options.output,
            min_block_length=options.min_block_length,
        )
    except Exception as failure:
        # Top-level boundary: show a short message instead of a traceback.
        print(f"Error: {failure}", file=sys.stderr)
        sys.exit(1)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
