"""
parse_water.py

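Parses EMBOSS water (pairwise local alignment, srspair format) output files
and extracts, per file:
  Score, %Identity, %Similarity, Gaps, Indels (number of contiguous gap runs
  in either aligned sequence), Coverage1, Coverage2.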

Coverage = (last_coord - first_coord + 1) / full_sequence_length * 100%

Full sequence lengths are looked up by entry name in UniProt TSV files
(ecoli.txt, bacsu.txt), which must have 'Entry Name' and 'Length' header columns.

Usage:
    python parse_water.py ecoli.txt bacsu.txt file1.water [file2.water ...]
"""

import sys
import re

def count_indels(sequence):
    """Return how many maximal runs of '-' gap characters *sequence* contains.

    Each contiguous stretch of gaps counts once, whatever its length, so
    "AB--C-D" has two indels.
    """
    gap_runs = re.finditer(r'-+', sequence)
    return sum(1 for _ in gap_runs)

def load_lengths(tsv_file):
    """Load a dict of UniProt entry name -> sequence length from a TSV file.

    The first row must be a tab-separated header containing 'Entry Name'
    and 'Length' columns. The columns may appear in any order, and other
    columns are ignored.

    Args:
        tsv_file: path to the TSV file (UTF-8, optional byte-order mark).

    Returns:
        dict mapping entry name (str) to length (int). Rows are skipped when
        they are too short, have an empty entry name, or have a non-integer
        length. The count of non-integer lengths is reported on stderr.

    Exits the process with status 1 if the required columns are missing.
    """
    lengths = {}
    bad_rows = 0
    # utf-8-sig drops a leading BOM. Without it, the BOM would stick to the
    # first header name and make header.index() fail.
    with open(tsv_file, encoding='utf-8-sig') as f:
        header = [col.strip() for col in f.readline().rstrip('\r\n').split('\t')]
        try:
            name_col = header.index('Entry Name')
            len_col = header.index('Length')
        except ValueError:
            print(f"ERROR: {tsv_file} must have 'Entry Name' and 'Length' columns",
                  file=sys.stderr)
            sys.exit(1)
        last_col = max(name_col, len_col)
        for line in f:
            # Remove only the line terminator. str.strip() would also eat
            # leading/trailing tabs, which shifts every column after an
            # empty first field.
            parts = line.rstrip('\r\n').split('\t')
            if len(parts) <= last_col:
                continue
            name = parts[name_col].strip()
            if not name:
                continue
            try:
                # Tolerate thousands separators such as "1,234".
                lengths[name] = int(parts[len_col].strip().replace(',', ''))
            except ValueError:
                bad_rows += 1
    if bad_rows:
        print(f"WARNING: {tsv_file}: skipped {bad_rows} row(s) with a non-integer Length",
              file=sys.stderr)
    return lengths

def parse_water(filename, lengths):
    """Parse one EMBOSS water (srspair-format) alignment file.

    Args:
        filename: path to the .water output file.
        lengths: dict mapping sequence name -> full sequence length. It is
            used for the coverage columns.

    Returns:
        dict with these keys:
        - 'id1', 'id2', 'score', 'identity_pct', 'similarity_pct', 'gaps':
          each present only when found in the '#' header.
        - 'total_indels', 'coverage1', 'coverage2': always present.
          Coverage values are strings like '87.3%', or 'N/A' when the full
          length is unknown or no alignment rows were found.

    NOTE(review): assumes the file holds a single alignment. With several,
    later header values overwrite earlier ones and all rows are pooled.
    """
    with open(filename, encoding='utf-8', errors='replace') as f:
        lines = f.read().splitlines()

    result = _parse_water_header(lines)
    name1 = result.get('id1', '')
    name2 = result.get('id2', '')

    seq1, coords1, seq2, coords2 = _parse_water_alignment(lines)
    # Rows are joined before counting, so a gap run that continues across a
    # line wrap is counted as one indel.
    result['total_indels'] = count_indels(seq1) + count_indels(seq2)

    result['coverage1'] = _coverage(coords1, lengths.get(name1))
    result['coverage2'] = _coverage(coords2, lengths.get(name2))
    return result


def _parse_water_header(lines):
    """Extract the '# key: value' summary fields from water output lines."""
    result = {}
    for line in lines:
        if line.startswith('# 1:'):
            result['id1'] = line.split(':', 1)[1].strip()
        elif line.startswith('# 2:'):
            result['id2'] = line.split(':', 1)[1].strip()
        elif line.startswith('# Score:'):
            result['score'] = line.split(':', 1)[1].strip()
        elif line.startswith('# Identity:'):
            m = re.search(r'\(\s*([\d.]+)%\)', line)
            if m:
                result['identity_pct'] = m.group(1) + '%'
        elif line.startswith('# Similarity:'):
            m = re.search(r'\(\s*([\d.]+)%\)', line)
            if m:
                result['similarity_pct'] = m.group(1) + '%'
        elif line.startswith('# Gaps:'):
            m = re.search(r'(\d+)/\d+', line)
            if m:
                result['gaps'] = m.group(1)
    return result


def _parse_water_alignment(lines):
    """Collect aligned residues and (start, end) coordinates for both sequences.

    Returns (seq1, coords1, seq2, coords2). Each seq is the concatenated
    aligned text, gaps included. Each coords is a list of per-row
    (start, end) tuples.

    Sequence rows are assigned by alternation: 1st, 3rd, ... rows belong to
    sequence 1 and 2nd, 4th, ... rows to sequence 2. The match-markup row
    between them starts with whitespace, so the row regex does not match
    it. Alternation is used instead of matching the row name against the
    header, so that two sequences with the same name (e.g. a
    self-alignment) are still separated correctly.
    """
    row_re = re.compile(r'^(\S+)\s+(\d+)\s+([A-Za-z*\-]+)\s+(\d+)')
    residues = ([], [])
    coords = ([], [])
    row_count = 0
    for line in lines:
        if line.startswith('#'):
            continue
        m = row_re.match(line)
        if not m:
            continue
        which = row_count % 2
        residues[which].append(m.group(3))
        coords[which].append((int(m.group(2)), int(m.group(4))))
        row_count += 1
    return ''.join(residues[0]), coords[0], ''.join(residues[1]), coords[1]


def _coverage(coords, full_len):
    """Return aligned span / full length as a 'NN.N%' string, or 'N/A'.

    The span is first row's start to last row's end, inclusive. 'N/A' is
    returned if there are no rows or the length is unknown or zero.
    """
    if not coords or not full_len:
        return 'N/A'
    span = coords[-1][1] - coords[0][0] + 1
    return f"{span / full_len * 100:.1f}%"

def main(argv=None):
    """Print a summary table with one row per water alignment file.

    Args:
        argv: full argument vector including the program name. Defaults to
            sys.argv. Expected layout:
            [prog, ecoli.txt, bacsu.txt, file1.water, file2.water, ...]

    Exits with status 1 on a usage error or an unreadable length table.
    Unreadable water files are reported on stderr and skipped. If any were
    skipped, the process exits with status 1 after printing the table.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 4:
        print("Usage: python parse_water.py ecoli.txt bacsu.txt file1.water [file2.water ...]",
              file=sys.stderr)
        sys.exit(1)

    # Both tables are merged into one lookup. On a duplicate entry name,
    # the second file wins.
    lengths = {}
    for tsv_file in argv[1:3]:
        try:
            lengths.update(load_lengths(tsv_file))
        except OSError as err:
            print(f"ERROR: cannot read {tsv_file}: {err}", file=sys.stderr)
            sys.exit(1)

    header = (f"{'ID1':<20} {'ID2':<20} {'Score':>8} {'%Ident':>8} "
              f"{'%Simil':>8} {'Gaps':>6} {'Indels':>7} {'Cov1':>8} {'Cov2':>8}")
    print(header)
    print("-" * len(header))

    failed = 0
    for filename in argv[3:]:
        try:
            r = parse_water(filename, lengths)
        except OSError as err:
            # Keep going so one missing file does not hide the other results.
            print(f"ERROR: cannot read {filename}: {err}", file=sys.stderr)
            failed += 1
            continue
        print(f"{r.get('id1','?'):<20} {r.get('id2','?'):<20} "
              f"{r.get('score','?'):>8} {r.get('identity_pct','?'):>8} "
              f"{r.get('similarity_pct','?'):>8} {r.get('gaps','?'):>6} "
              f"{r.get('total_indels','?'):>7} "
              f"{r.get('coverage1','N/A'):>8} {r.get('coverage2','N/A'):>8}")

    if failed:
        sys.exit(1)

# Run the command-line entry point only when executed as a script, so the
# parsing functions can be imported without side effects.
if __name__ == "__main__":
    main()
