import collections as col


# Pair SD hits from a GFF file with the nearest downstream CDS taken from a
# tab-separated feature table, and write one TSV row per pairing.
# NOTE(review): "SD" presumably stands for Shine-Dalgarno sites and the
# feature table looks like an NCBI *_feature_table.txt -- confirm.

FEATURE_TABLE_PATH = "FT.txt"
SD_GFF_PATH = "SD_weaker_new.gff"
OUTPUT_PATH = "SD_weaker_result.tsv"
OUTPUT_HEADER = "Replicon\tStrand\tSD start\tSD end\tCDS start\tCDS end\tProtein\tSpacing\tLikelihood\n"

# Placeholder written to the output wherever a value does not exist.
MISSING = "None"
NO_SD = (MISSING, MISSING)        # SD start, SD end
NO_SPACING = (MISSING, MISSING)   # spacing, likelihood
NO_CDS = (MISSING,) * 5           # CDS start, CDS end, protein, spacing, likelihood

# A pairing gets likelihood flag 1 when its spacing is strictly below this.
LIKELY_SPACING_LIMIT = 15


def parse_cds(rows):
    """Group CDS intervals of the feature table by (replicon, strand).

    Only rows whose text starts with "CDS" are used.  Columns are read by
    position: 1 = class (compared with "with_protein"), 6 = replicon,
    7 = start, 8 = end, 9 = strand.

    Returns a mapping ``(replicon, strand) -> [(start, end, protein), ...]``
    where ``protein`` is 1 for "with_protein" rows and 0 otherwise.
    """
    cds_by_key = col.defaultdict(list)
    for row in rows:
        if not row.startswith("CDS"):
            continue
        fields = row.split("\t")
        has_protein = int(fields[1] == "with_protein")
        key = (fields[6], fields[9])
        cds_by_key[key].append((int(fields[7]), int(fields[8]), has_protein))
    return cds_by_key


def parse_sd(lines):
    """Group SD intervals of a GFF file by (replicon, strand).

    Blank lines and lines starting with "#" (GFF comments and directives)
    are skipped, so the result does not depend on a trailing newline or a
    header being present.  Columns are read by position: 0 = replicon,
    3 = start, 4 = end, 6 = strand.

    Returns a mapping ``(replicon, strand) -> [(start, end), ...]``.
    """
    sd_by_key = col.defaultdict(list)
    for line in lines:
        if not line.strip() or line.startswith("#"):
            continue
        fields = line.split("\t")
        key = (fields[0], fields[6])
        sd_by_key[key].append((int(fields[3]), int(fields[4])))
    return sd_by_key


def pair_sd_with_cds(key, cds_list, sd_list):
    """Pair every SD of one (replicon, strand) with the next downstream CDS.

    ``key`` is the ``(replicon, strand)`` tuple; strand "+" is walked in
    ascending coordinates, any other strand in descending coordinates.
    ``cds_list`` holds ``(start, end, protein)`` tuples and ``sd_list``
    holds ``(start, end)`` tuples; neither list is modified.

    Returns a list of 9-tuples
    ``(replicon, strand, sd_start, sd_end, cds_start, cds_end, protein,
    spacing, likelihood)`` in walking order.  A CDS with no SD in front of
    it and an SD with no CDS behind it are both reported, with "None" in
    the fields that do not apply.  Several SDs may be paired with the same
    CDS.
    """
    if key[1] == "+":
        cds_sorted = sorted(cds_list, key=lambda c: c[0])
        sd_sorted = sorted(sd_list, key=lambda s: s[1])

        def spacing(cds, sd):
            # Bases strictly between the SD end and the CDS start.
            return cds[0] - sd[1] - 1
    else:
        cds_sorted = sorted(cds_list, key=lambda c: c[1], reverse=True)
        sd_sorted = sorted(sd_list, key=lambda s: s[0], reverse=True)

        def spacing(cds, sd):
            # Mirror image: the CDS "start" is its higher coordinate here.
            return sd[0] - cds[1] - 1

    def unpaired_cds(cds):
        return key + NO_SD + cds + NO_SPACING

    def unpaired_sd(sd):
        return key + sd + NO_CDS

    # Without this guard a group that has CDS but no SD hits would either
    # crash on sd_sorted[0] or silently lose its first CDS.
    if not sd_sorted:
        return [unpaired_cds(cds) for cds in cds_sorted]
    if not cds_sorted:
        return [unpaired_sd(sd) for sd in sd_sorted]

    rows = []
    cds_iter = iter(cds_sorted)
    sd_iter = iter(sd_sorted)

    cds = next(cds_iter)
    # The first CDS begins before the first SD ends, so nothing precedes it.
    if spacing(cds, sd_sorted[0]) < 0:
        rows.append(unpaired_cds(cds))

    for sd in sd_iter:
        try:
            gap = spacing(cds, sd)
            # Advance to the first CDS that begins after this SD; every
            # CDS stepped over on the way has no SD of its own.
            while gap < 0:
                cds = next(cds_iter)
                gap = spacing(cds, sd)
                if gap < 0:
                    rows.append(unpaired_cds(cds))
        except StopIteration:
            # No CDS left: this SD and all remaining ones stay unpaired.
            rows.append(unpaired_sd(sd))
            rows.extend(unpaired_sd(rest) for rest in sd_iter)
            break
        rows.append(key + sd + cds + (gap, int(gap < LIKELY_SPACING_LIMIT)))

    # CDS located beyond the last SD.
    rows.extend(unpaired_cds(rest) for rest in cds_iter)
    return rows


def build_result(cds_by_key, sd_by_key):
    """Pair SDs and CDS for every (replicon, strand) that has CDS entries.

    Groups present only in ``sd_by_key`` are not reported.  Rows are
    returned sorted by replicon, strand and CDS end, with rows lacking a
    CDS placed after the others of the same replicon and strand.
    """
    rows = []
    for key, cds_list in cds_by_key.items():
        rows.extend(pair_sd_with_cds(key, cds_list, sd_by_key.get(key, [])))
    rows.sort(
        key=lambda r: (r[0], r[1], r[5] == MISSING, r[5] if r[5] != MISSING else 0)
    )
    return rows


def write_result(rows, path):
    """Write the header line followed by one tab-separated line per row."""
    with open(path, "w") as outfile:
        outfile.write(OUTPUT_HEADER)
        for row in rows:
            outfile.write("\t".join(map(str, row)) + "\n")


def main():
    """Read both input files, pair SDs with CDS and write the result TSV."""
    with open(FEATURE_TABLE_PATH, "r") as handle:
        table = handle.read().split("\n")
    with open(SD_GFF_PATH, "r") as handle:
        found = handle.read().split("\n")

    rows = build_result(parse_cds(table), parse_sd(found))
    write_result(rows, OUTPUT_PATH)


if __name__ == "__main__":
    main()