from sys import argv, exit
def help_msg():
    """Print the command-line usage text to standard output."""
    usage_lines = (
        "Usage:",
        "python compare_alignments.py a.fasta b.fasta out.txt",
        "",
        "Input:",
        "Two fasta alignments with the same sequence IDs",
        "",
        "Output:",
        "Matched columns and blocks",
    )
    # An empty entry prints as a bare newline, separating the sections.
    for text in usage_lines:
        print(text)
def read_fasta(f):
    """Read a FASTA file into a dict mapping sequence ID to sequence string.

    The ID is the first whitespace-delimited token of the header line (after
    the leading ">"), truncated at the first "/". Blank lines are skipped and
    multi-line sequences are concatenated. If an ID occurs more than once,
    the last record replaces the earlier ones.

    Args:
        f: Path of the FASTA file to read.

    Returns:
        A dict of ID -> sequence, in order of first appearance in the file.

    Raises:
        ValueError: If sequence data appears before any header line, or a
            header line carries no ID.
    """
    chunks = {}
    name = None  # None (not "") so that a genuinely empty ID stays usable
    # The context manager closes the handle even when parsing fails.
    with open(f) as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            if line[0] == ">":
                fields = line[1:].split()
                if not fields:
                    raise ValueError("FASTA header without a sequence ID")
                name = fields[0].split("/")[0]
                # Re-assigning restarts the record when an ID is repeated.
                chunks[name] = []
            elif name is None:
                raise ValueError(
                    "sequence data found before the first FASTA header"
                )
            else:
                chunks[name].append(line)
    # Join once per record instead of growing a string line by line.
    return {key: "".join(parts) for key, parts in chunks.items()}
def num(s):
    """Label every position of an aligned sequence.

    Gap characters ("-") are kept as the string "-"; every other character
    is replaced by its 1-based ordinal among the non-gap characters seen so
    far. Returns a list with one label per character of ``s``.
    """
    labels = []
    residues_seen = 0
    for symbol in s:
        if symbol != "-":
            residues_seen += 1
            labels.append(residues_seen)
            continue
        labels.append("-")
    return labels
def column_signatures(numbered, ids):
    """Return one tuple per alignment column.

    ``numbered`` maps sequence ID to its per-position label list; ``ids``
    fixes the order of the entries inside each tuple. All label lists are
    expected to have the same length (the caller validates this).
    """
    return list(zip(*(numbered[x] for x in ids)))


def match_columns(cols1, cols2):
    """Return 1-based (i, j) pairs where column i of cols1 equals column j of cols2.

    Pairs are ordered by i, then by j. A lookup table over ``cols2`` replaces
    a scan of every column pair.
    """
    positions = {}
    for j, col in enumerate(cols2, start=1):
        positions.setdefault(col, []).append(j)
    matches = []
    for i, col in enumerate(cols1, start=1):
        for j in positions.get(col, ()):
            matches.append((i, j))
    return matches


def find_blocks(matches):
    """Return maximal diagonal runs of at least two matched columns.

    Each block is ``(start1, end1, start2, end2)`` with 1-based inclusive
    column numbers. Runs are found by set lookup rather than by comparing
    neighbouring list entries, so a column that matches several columns
    (e.g. an all-gap column) cannot cut a run short.
    """
    matched = set(matches)
    blocks = []
    for x, y in matches:
        if (x - 1, y - 1) in matched:
            # Interior of a run that is reported from its first pair.
            continue
        end_x, end_y = x, y
        while (end_x + 1, end_y + 1) in matched:
            end_x += 1
            end_y += 1
        if end_x > x:
            blocks.append((x, end_x, y, end_y))
    return blocks


def percent(part, whole):
    """Return ``part`` as a percentage of ``whole`` (2 decimals); 0.0 if whole is 0."""
    if whole == 0:
        return 0.0
    return round(part / whole * 100, 2)


def main(args):
    """Compare two FASTA alignments and write matched columns and blocks.

    ``args`` has the shape of ``sys.argv``: program name, first alignment,
    second alignment, output path. ``-h`` as the only argument prints the
    usage text. Exits with status 1 on usage or input errors.
    """
    if len(args) == 2 and args[1] == "-h":
        help_msg()
        exit()
    if len(args) != 4:
        print("Use -h for help")
        exit(1)

    first = read_fasta(args[1])
    second = read_fasta(args[2])
    if set(first) != set(second):
        print("Sequence IDs differ")
        exit(1)
    ids = sorted(first)
    if not ids:
        # Both files are empty; there is no reference length to take.
        print("No sequences found")
        exit(1)

    len1 = len(next(iter(first.values())))
    len2 = len(next(iter(second.values())))
    for x in ids:
        if len(first[x]) != len1:
            print("Alignment 1 has sequences of different length")
            exit(1)
        if len(second[x]) != len2:
            print("Alignment 2 has sequences of different length")
            exit(1)

    cols1 = column_signatures({x: num(first[x]) for x in ids}, ids)
    cols2 = column_signatures({x: num(second[x]) for x in ids}, ids)
    matches = match_columns(cols1, cols2)
    blocks = find_blocks(matches)

    with open(args[3], "w") as out:
        out.write("Matched columns\n\n")
        for x, y in matches:
            out.write(f"{x} {y}\n")
        out.write("\nBlocks\n\n")
        for s1, f1, s2, f2 in blocks:
            out.write(f"({s1},{f1})=({s2},{f2})\n")

    same = len(matches)
    print("Alignment 1 length:", len1)
    print("Alignment 2 length:", len2)
    print("Matched columns:", same)
    print("Percent of alignment 1:", percent(same, len1))
    print("Percent of alignment 2:", percent(same, len2))
    print("Blocks:", len(blocks))


if __name__ == "__main__":
    main(argv)