// =============================================================================
// CD-HIT
// http://bioinformatics.burnham-inst.org/cd-hi
// 
// program written by
//                                      Weizhong Li
//                                      UCSD, San Diego Supercomputer Center
//                                      La Jolla, CA, 92093
//                                      Email liwz@sdsc.edu
//                 at
//                                      Adam Godzik's lab
//                                      The Burnham Institute
//                                      La Jolla, CA, 92037
//                                      Email adam@burnham-inst.org
// =============================================================================

#include "cd-hi.h"
#include "cd-hi-init.h"
#include <cstring>
#include <new>

////////////////////////////////////  MAIN /////////////////////////////////////
// ---- command-line helpers used only by main() ------------------------------

// Return the argument that follows the option at argv[i] and advance i past
// it. Every option main() accepts takes a value, so an option in the last
// position is a usage error instead of a read of argv[argc] (a NULL pointer).
static char *next_option_value(int argc, char **argv, int &i) {
  if (i + 1 >= argc) {
    print_usage(argv[0]);
    exit(1);  // only reached if print_usage() returns
  }
  return argv[++i];
}

// Copy a command-line file name into a MAX_FILE_NAME buffer. Over-long names
// are rejected rather than silently truncated, and dst is always
// NUL-terminated (strncpy alone does not guarantee that).
static void copy_file_name(char *dst, const char *src) {
  if (strlen(src) >= (size_t) MAX_FILE_NAME) bomb_error("file name too long");
  strncpy(dst, src, MAX_FILE_NAME-1);
  dst[MAX_FILE_NAME-1] = 0;
}

// Convert a fractional identity cutoff (e.g. 0.57) to an integer percent.
// The small epsilon absorbs binary floating-point error: 0.57*100 evaluates
// to 56.999..., which a bare (int) cast would truncate to 56.
static int cutoff_percent(double cutoff) {
  return (int) (cutoff * 100 + 1e-6);
}

// main: cluster the sequences of the -i database at the -c identity cutoff.
//
// Flow, as implemented below:
//   1. parse options into the global parameters and derive the
//      <out>.clstr and <out>.bak.clstr names from -o;
//   2. with -u, read an existing cluster file so that old representatives
//      and old cluster members can be reused (IS_OLD_REP / IS_OLD_REDUNDANT);
//   3. read the database and split it into segments (sort_seqs_divide_segs
//      fills NR_idx and the SEG_b/SEG_e ranges). For every segment, each
//      still non-redundant sequence is compared first against representatives
//      of earlier segments (their word tables are re-read from db_swap when
//      not already in memory), then against representatives of its own
//      segment. A sequence that matches nothing becomes a new representative
//      and its words are added to word_table;
//   4. write the -o database (db_read_and_write), a backup cluster listing,
//      and the final cluster listing.
// NR_clstr_no[i] >= 0 is the cluster number of representative i; a value
// of -c-1 marks sequence i as a member of cluster c.
int main(int argc, char **argv) {
  int i, j, k, j1, sggi, sgj;
  char db_in[MAX_FILE_NAME]        = "";
  char db_out[MAX_FILE_NAME]       = "";
  char db_clstr[MAX_FILE_NAME]     = "";
  char db_clstr_bak[MAX_FILE_NAME] = "";
  char db_clstr_old[MAX_FILE_NAME] = "";

  // ***********************************    parse command line and open file
  if (argc < 5) print_usage(argv[0]);

  for (i=1; i<argc; i++) {
    if      (strcmp(argv[i], "-i") == 0)
      copy_file_name(db_in, next_option_value(argc, argv, i));
    else if (strcmp(argv[i], "-o") == 0)
      copy_file_name(db_out, next_option_value(argc, argv, i));
    else if (strcmp(argv[i], "-u") == 0) {
      copy_file_name(db_clstr_old, next_option_value(argc, argv, i));
      old_clstr_file = 1;
    }
    else if (strcmp(argv[i], "-M") == 0) {
      // NOTE(review): int arithmetic; overflows above ~2147 if int is 32-bit
      mem_limit = 1000000 * atoi(next_option_value(argc, argv, i));
    }
    else if (strcmp(argv[i], "-l") == 0)
      length_of_throw = atoi(next_option_value(argc, argv, i));
    else if (strcmp(argv[i], "-c") == 0) {
      NR_clstr = atof(next_option_value(argc, argv, i));
      if ((NR_clstr > 1.0) || (NR_clstr < 0.4)) bomb_error("invalid clstr");
      NR_clstr100 = cutoff_percent(NR_clstr);
    }
    else if (strcmp(argv[i], "-L") == 0) {
      NR_cov   = atof(next_option_value(argc, argv, i));
      if ((NR_cov > 1.0)   || (NR_cov < 0.0)) bomb_error("invalid coverage cutoff");
    }
    else if (strcmp(argv[i], "-b") == 0) {
      BAND_width = atoi(next_option_value(argc, argv, i));
      if (BAND_width < 0 ) bomb_error("invalid band width");
    }
    else if (strcmp(argv[i], "-n") == 0) {
      NAA = atoi(next_option_value(argc, argv, i));
      if ( NAA < 2 || NAA > 5 ) bomb_error("invalid word length");
    }
    else if (strcmp(argv[i], "-d") == 0) {
      des_len = atoi(next_option_value(argc, argv, i));
      if ( des_len < 15 ) 
        bomb_error("too short description, not enough to identify sequences");
    }
    else if (strcmp(argv[i], "-t") == 0) {
      tolerance = atoi(next_option_value(argc, argv, i));
      if ( tolerance < 0 || tolerance > 5 ) bomb_error("invalid tolerance");
    }
    else 
      print_usage(argv[0]);
  }

  // -i and -o are mandatory; without them the name buffers would stay empty.
  if ( db_in[0] == 0 || db_out[0] == 0 ) {
    print_usage(argv[0]);
    exit(1);
  }

  // Both derived names must fit, including the longer ".bak.clstr" suffix.
  if ( strlen(db_out) + strlen(".bak.clstr") >= (size_t) MAX_FILE_NAME )
    bomb_error("output file name too long");
  strcpy(db_clstr, db_out);     strcat(db_clstr, ".clstr");
  strcpy(db_clstr_bak, db_out); strcat(db_clstr_bak, ".bak.clstr");

  if      ( NAA == 2 ) { NAAN = NAA2; }
  else if ( NAA == 3 ) { NAAN = NAA3; }
  else if ( NAA == 4 ) { NAAN = NAA4; }
  else if ( NAA == 5 ) { NAAN = NAA5; }
  else bomb_error("invalid -n parameter!");

  word_table.init(NAA, NAAN);

  // Advise on the word length: from the naa_stat table when a tolerance is
  // given, otherwise from fixed identity thresholds.
  if ( tolerance ) {
    int clstr_idx = cutoff_percent(NR_clstr) - naa_stat_start_percent;
    int tcutoff = naa_stat[tolerance-1][clstr_idx][5-NAA];

    if (tcutoff < 5 )
      bomb_error("Too short word length, increase it or the tolerance");
    for ( i=5; i>NAA; i--) {
      if ( naa_stat[tolerance-1][clstr_idx][5-i] > 10 ) {
        cout << "Your word length is " << NAA << ", using "
             << i << " may be faster!" <<endl;
        break;
      }
    }
  }
  else {
    if      ( NR_clstr > 0.85 && NAA < 5)
      cout << "Your word length is " << NAA 
           << ", using 5 may be faster!" <<endl;
    else if ( NR_clstr > 0.80 && NAA < 4 )
      cout << "Your word length is " << NAA 
           << ", using 4 may be faster!" <<endl;
    else if ( NR_clstr > 0.75 && NAA < 3 )
      cout << "Your word length is " << NAA 
           << ", using 3 may be faster!" <<endl;
  }

  if ( length_of_throw <= NAA ) bomb_error("Too short -l, redefine it");

  ifstream in1(db_in);
  if ( ! in1 ) { cout << "Can not open file " << db_in << endl; exit(1); }
  ofstream out1(db_out);
  if ( ! out1) { cout << "Can not open file " << db_out << endl; exit(1); }
  ofstream out2(db_clstr);
  if ( ! out2) { cout << "Can not open file " << db_clstr << endl; exit(1); }
  ofstream out2_bak(db_clstr_bak);
  if ( ! out2_bak) { cout << "Can not open file " << db_clstr_bak << endl; exit(1); }

  // Count the sequences first, then reopen the input to read it from the
  // start. NOTE(review): assumes db_seq_no_test leaves in1 closed.
  DB_no = db_seq_no_test(in1); in1.open(db_in);

  // Plain new throws instead of returning NULL; nothrow makes the checks real.
  if ((NR_len      = new (std::nothrow) int   [DB_no]) == NULL) bomb_error("Memory");
  if ((NR_idx      = new (std::nothrow) int   [DB_no]) == NULL) bomb_error("Memory");
  if ((NR90_idx    = new (std::nothrow) int   [DB_no]) == NULL) bomb_error("Memory");
  if ((NR_clstr_no = new (std::nothrow) int   [DB_no]) == NULL) bomb_error("Memory");
  if ((NR_iden     = new (std::nothrow) char  [DB_no]) == NULL) bomb_error("Memory");
  if ((NR_coverage = new (std::nothrow) char  [DB_no]) == NULL) bomb_error("Memory");
  if ((NR_flag     = new (std::nothrow) char  [DB_no]) == NULL) bomb_error("Memory");
  if ((NR_seq      = new (std::nothrow) char *[DB_no]) == NULL) bomb_error("Memory");
  // Cluster bookkeeping is allocated up front (sized by DB_no, since NR90_no
  // is not known yet) so no large allocation is needed at the end of the run.
  int *Clstr_no, *(*Clstr_list);
  if ((Clstr_no    = new (std::nothrow) int   [DB_no]) == NULL) bomb_error("Memory");
  if ((Clstr_list  = new (std::nothrow) int  *[DB_no]) == NULL) bomb_error("Memory");
  if ((NR90f_idx   = new (std::nothrow) int   [DB_no]) == NULL) bomb_error("Memory");

  if ( old_clstr_file ) {
    ifstream in_clstr(db_clstr_old);
    if ( ! in_clstr) { 
      cout << "Can not open file " << db_clstr_old << endl; 
      exit(1); 
    }

    //number of seq in old clstr file
    int clstr_seq_no = old_clstr_seq_no_test(in_clstr);
    in_clstr.open(db_clstr_old);

    if ((NRo_idx       = new (std::nothrow) int [clstr_seq_no]) == NULL) bomb_error("Memory");
    if ((NRo_id1       = new (std::nothrow) int [clstr_seq_no]) == NULL) bomb_error("Memory");
    if ((NRo_id2       = new (std::nothrow) int [clstr_seq_no]) == NULL) bomb_error("Memory");
    if ((NRo_clstr_no  = new (std::nothrow) int [clstr_seq_no]) == NULL) bomb_error("Memory");
    if ((NRo_NR_idx    = new (std::nothrow) int [clstr_seq_no]) == NULL) bomb_error("Memory");
    if ((NRo_iden      = new (std::nothrow) char[clstr_seq_no]) == NULL) bomb_error("Memory");
    old_clstr_read_in(in_clstr, NRo_no, NRo90_no, NRo_idx, NRo_id1, NRo_id2,
                      NRo_iden, NRo_clstr_no, NRo_NR_idx);
    in_clstr.close();
  }

  if ( NRo_no > 0 )
    db_read_in2(in1, length_of_throw, NR_no, NR_seq, NR_len,
                NRo_no, NRo_idx, NRo_id1, NRo_id2, NRo_NR_idx);
  else
    db_read_in(in1, length_of_throw, NR_no, NR_seq, NR_len);
  in1.close();
  cout << "total seq: " << NR_no << endl;

  // ********************************************* init NR_flag
  for (i=0; i<NR_no; i++) NR_flag[i] = 0;
  if ( old_clstr_file ) {
    for (i=0; i<NRo_no; i++) {
      if ( (j = NRo_NR_idx[i]) == -1 ) continue ;
      if ( NRo_clstr_no[i] == i ) NR_flag[j] |= IS_OLD_REP;

      if ( (k = NRo_NR_idx[ NRo_clstr_no[i] ]) == -1 ) continue ;
      if ( NRo_clstr_no[i] != i ) NR_flag[j] |= IS_OLD_REDUNDANT;
      NR_iden[j] = NRo_iden[i];
      NR_clstr_no[j] = k; // note, later it need be changed to NR90_no
    }
    delete [] NRo_idx;
    delete [] NRo_id1;
    delete [] NRo_id2;
    delete [] NRo_iden;
    delete [] NRo_clstr_no;
    delete [] NRo_NR_idx;
  }

  sort_seqs_divide_segs(NR_no, NR_len, NR_idx, NR_seq, mem_limit, NAAN,
                        SEG_no, SEG_b, SEG_e, db_swap);

  // *********************************************                Main loop
  char *seqi;
  double aa1_cutoff = NR_clstr;
  double aa2_cutoff = 1 - (1-NR_clstr)*2;
  double aan_cutoff = 1 - (1-NR_clstr)*NAA;
  int len, hit_no, has_aa2, iden_no, aan_no, segb;
  int aan_list[MAX_SEQ];
  INTs aan_list_no[MAX_SEQ];
  int frg1, frg2, segfb;
  int aan_list_backup[MAX_SEQ];
  INTs *look_and_count;
  NR_frag_no = 0; 
  for (i=0; i<NR_no; i++) NR_frag_no += (NR_len[i] - NAA ) / Frag_size + 1;
  if ((look_and_count= new (std::nothrow) INTs[NR_frag_no]) == NULL) bomb_error("Memory");

  if ( tolerance ) {
    int clstr_idx = cutoff_percent(NR_clstr) - naa_stat_start_percent;
    double d2  = ((double) (naa_stat[tolerance-1][clstr_idx][3]     )) / 100;
    double dn  = ((double) (naa_stat[tolerance-1][clstr_idx][5-NAA] )) / 100;
    aa2_cutoff = d2 > aa2_cutoff ? d2 : aa2_cutoff;
    aan_cutoff = dn > aan_cutoff ? dn : aan_cutoff;
  }

  NR90_no = 0; NR90f_no = 0;
  for (sggi=0; sggi<SEG_no; sggi++) {
    if (SEG_no >1)
      cout << "SEG " << sggi << " " << SEG_b[sggi] << " " << SEG_e[sggi] <<endl;

    // Compare this segment against representatives of every earlier segment.
    // The table of segment sggi-1 is still in memory; older ones come from swap.
    for (sgj=sggi-1; sgj>=0; sgj--) {
      cout << "Reading swap" << endl;
      if ( sgj != sggi-1) word_table.read_tbl(db_swap[sgj]);    // reading old segment
      cout << "Comparing with SEG " << sgj << endl;
      for (i1=SEG_b[sggi]; i1<=SEG_e[sggi]; i1++) {
        i = NR_idx[i1];
        if (NR_flag[i] & IS_REDUNDANT ) continue;

        // old cluster member whose old representative is a representative again
        if ( (NR_flag[i] & IS_OLD_REDUNDANT) &&
             (NR_flag[ NR_clstr_no[i] ] & IS_REP) ) {
          NR_clstr_no[i] = - (NR_clstr_no[ NR_clstr_no[i] ]) - 1;
          NR_flag[i] |= IS_REDUNDANT ;
          delete [] NR_seq[i];
          continue;
        }

        len = NR_len[i]; seqi = NR_seq[i];
        frg1 = (len - NAA ) / Frag_size + 1;
        frg2 = (len - NAA + BAND_width ) / Frag_size + 1;
        has_aa2 = 0;

        int flag = check_this_short(len, seqi, has_aa2,
               NAA, aan_no, aan_list, aan_list_no,
                            aan_list_backup, look_and_count,
               hit_no, SEG90_b[sgj], SEG90_e[sgj],
               frg2, SEG90f_b[sgj], SEG90f_e[sgj], iden_no,
               aa1_cutoff, aa2_cutoff, aan_cutoff,
               NR_flag[i], NR_flag) ;

        if ( flag == 1) {       // if similar to old one delete it
          delete [] NR_seq[i];
          NR_clstr_no[i] = -hit_no-1;  // (-hit_no-1) for non representatives
          NR_iden[i] = iden_no * 100 / len;
          NR_flag[i] |= IS_REDUNDANT ;
        }
      } //for (i1=SEG_b[sggi]; i1<=SEG_e[sggi]; i1++)
    } // for (sgj=sggi-1; sgj>=0; sgj--)

    if (SEG_no >1) cout << "Refresh Memory" << endl;
    word_table.clean();

    if (SEG_no >1) cout << "Self comparing" << endl;
    segb = NR90_no;
    segfb = NR90f_no;
    for (i1=SEG_b[sggi]; i1<=SEG_e[sggi]; i1++) {
      i = NR_idx[i1];

      if ( ! (NR_flag[i] & IS_REDUNDANT) ) {
        if ( (NR_flag[i] & IS_OLD_REDUNDANT) &&
             (NR_flag[ NR_clstr_no[i] ] & IS_REP) ) {
          NR_clstr_no[i] = - (NR_clstr_no[ NR_clstr_no[i] ]) - 1;
          NR_flag[i] |= IS_REDUNDANT ;
          delete [] NR_seq[i];
        }
        else {
          len = NR_len[i]; seqi = NR_seq[i];
          frg1 = (len - NAA ) / Frag_size + 1;
          frg2 = (len - NAA + BAND_width ) / Frag_size + 1;
          has_aa2 = 0;
    
          int flag = check_this_short(len, seqi, has_aa2,
                 NAA, aan_no, aan_list, aan_list_no,
                              aan_list_backup, look_and_count,
                 hit_no, segb, NR90_no-1, frg2, segfb, NR90f_no-1, iden_no,
                 aa1_cutoff, aa2_cutoff, aan_cutoff,
                 NR_flag[i], NR_flag);
    
          if ( flag == 1) {       // if similar to old one delete it
            delete [] NR_seq[i];
            NR_clstr_no[i] = -hit_no-1;  // (-hit_no-1) for non representatives
            NR_iden[i] = iden_no * 100 / len;
          }
          else {                  // else add to NR90 db
            NR90_idx[NR90_no] = i;
            NR_clstr_no[i] = NR90_no; // positive value for representatives
            NR_iden[i] = 0;
            NR_flag[i] |= IS_REP;
            add_in_lookup_table_short(aan_no, frg1, aan_list_backup,
                                      aan_list_no);
            NR90f_idx[NR90_no] = NR90f_no;
            NR90f_no += frg1;
            NR90_no++;
          } // else
        } // else
      } // if ( ! (NR_flag[i] & IS_REDUNDANT) )
  
      if ( (i1+1) % 100 == 0 ) { 
        cerr << ".";
        if ( (i1+1) % 1000 == 0 )
          cout << i1+1 << " finished\t" << NR90_no << " clusters" << endl;
      }
    } // for (i1=SEG_b[sggi]; i1<=SEG_e[sggi]; i1++) {

    SEG90_b[sggi] = segb;  SEG90_e[sggi] = NR90_no-1;
    SEG90f_b[sggi] = segfb; SEG90f_e[sggi] = NR90f_no-1;
    // Only tables of segments 0..SEG_no-3 are ever read back from swap: the
    // table of segment SEG_no-2 is still in memory when the last segment
    // starts, and the last segment's own table is never needed again.
    if ( sggi < SEG_no-2 ) word_table.write_tbl( db_swap[sggi] );
  } // for (sggi=0; sggi<SEG_no; sggi++) {
  cout << NR_no << " finished\t" << NR90_no << " clusters" << endl;

  // Redundant sequences were freed as they were found; free the representatives.
  for (i=0; i<NR90_no; i++)  delete [] NR_seq[ NR90_idx[i] ]; 

  cout << "writing new database" << endl;
  in1.open(db_in);
  db_read_and_write(in1, out1, length_of_throw, des_len, NR_seq, NR_clstr_no);
  in1.close(); out1.close();

  // write a backup clstr file in case next step crashes
  for (i=0; i<NR_no; i++) {
    j1 = NR_clstr_no[i];
    if ( j1 < 0 ) j1 =-j1-1;
    out2_bak << j1 << "\t" << NR_len[i] << "aa, "<< NR_seq[i] << "...";
    if ( NR_iden[i]>0 ) out2_bak << " at " << int(NR_iden[i]) << "%" << endl;
    else                out2_bak << " *" << endl;
  }
  out2_bak.close();

  cout << "writing clustering information" << endl;
  // write clstr information: count members per cluster, then bucket them
  // (Clstr_no / Clstr_list were allocated at start-up with DB_no entries)
  for (i=0; i<NR90_no; i++) Clstr_no[i]=0;
  for (i=0; i<NR_no; i++) {
    j1 = NR_clstr_no[i];
    if ( j1 < 0 ) j1 =-j1-1;
    Clstr_no[j1]++;
  }
  for (i=0; i<NR90_no; i++) {
    if((Clstr_list[i] = new (std::nothrow) int[ Clstr_no[i] ]) == NULL) bomb_error("Memory");
    Clstr_no[i]=0;
  }

  for (i=0; i<NR_no; i++) {
    j1 = NR_clstr_no[i];
    if ( j1 < 0 ) j1 =-j1-1;
    Clstr_list[j1][ Clstr_no[j1]++ ] = i;
  }

  for (i=0; i<NR90_no; i++) {
    out2 << ">Cluster " << i << endl;
    for (k=0; k<Clstr_no[i]; k++) {
      j = Clstr_list[i][k];
      out2 << k << "\t" << NR_len[j] << "aa, "<< NR_seq[j] << "...";
      if ( NR_iden[j]>0 ) out2 << " at " << int(NR_iden[j]) << "%" << endl;
      else                  out2 << " *" << endl;
    }
  }
  out2.close();
  cout << "program completed !" << endl << endl;
  return 0;
} // END int main

///////////////////////FUNCTION of common tools////////////////////////////

int check_this_short(int len, char *seqi, int &has_aa2,
               int NAA, int& aan_no, int *aan_list, INTs *aan_list_no,
                                     int *aan_list_backup,
               INTs *look_and_count,
               int &hit_no, int libb, int libe, 
               int frg2, int libfb, int libfe, int &iden_no,
               double aa1_cutoff, double aa2_cutoff, double aan_cutoff,
               char this_flag, char *NR_flag) {

  static int  taap[MAX_UAA*MAX_UAA];
  static INTs aap_list[MAX_SEQ];
  static INTs aap_begin[MAX_UAA*MAX_UAA];

  int i, j, k, i1, j1, k1, i0, j0, k0, c22, sk, mm, fn;
  int required_aa1 = int (aa1_cutoff* (double) len);
  int required_aa2 = int (aa2_cutoff* (double) len);
  int required_aan = int (aan_cutoff* (double) len);

  aan_no = len - NAA + 1;
  if      ( NAA == 2)
    for (j=0; j<aan_no; j++)
      aan_list_backup[j] = aan_list[j] = seqi[j]*NAA1 + seqi[j+1];
  else if ( NAA == 3)
    for (j=0; j<aan_no; j++)
      aan_list_backup[j] = aan_list[j] = 
        seqi[j]*NAA2 + seqi[j+1]*NAA1 + seqi[j+2];
 else if ( NAA == 4)
    for (j=0; j<aan_no; j++)
      aan_list_backup[j] = aan_list[j] =
        seqi[j]*NAA3+seqi[j+1]*NAA2 + seqi[j+2]*NAA1 + seqi[j+3];
  else if ( NAA == 5)
    for (j=0; j<aan_no; j++)
      aan_list_backup[j] = aan_list[j] =
        seqi[j]*NAA4+seqi[j+1]*NAA3+seqi[j+2]*NAA2+seqi[j+3]*NAA1+seqi[j+4];

  else return FAILED_FUNC;
  
  quick_sort(aan_list,0,aan_no-1);
  for(j=0; j<aan_no; j++) aan_list_no[j]=1;
  for(j=aan_no-1; j; j--) {
    if (aan_list[j] == aan_list[j-1]) {
      aan_list_no[j-1] += aan_list_no[j];
      aan_list_no[j]=0;
    }
  }
  // END check_aan_list


  // lookup_aan
  for (j=libfe; j>=libfb; j--) look_and_count[j]=0;
  word_table.count_word_no(aan_no, aan_list, aan_list_no, look_and_count);


  // contained_in_old_lib()
  int band_left, band_right, best_score, band_width1, best_sum, len2, best1,sum;
  int len1 = len - 1;
  INTs *lookptr;

  char *seqj;
  int flag = 0;      // compare to old lib
  for (j=libe; j>=libb; j--) {
    if ( (this_flag & IS_OLD_REP ) &&
         (NR_flag[NR90_idx[j]] & IS_OLD_REP) ) continue;
    len2 = NR_len[NR90_idx[j]];

    k = (len2 - NAA) / Frag_size + 1;
    lookptr = &look_and_count[ NR90f_idx[j] ];

    if ( frg2 >= k ) {
      best1=0;
      for (j1=0; j1<k; j1++) best1 += lookptr[j1];
    }
    else {
      sum = 0;
      for (j1=0; j1<frg2; j1++) sum += lookptr[j1];
      best1 = sum;
      for (j1=frg2; j1<k; j1++) {
        sum += lookptr[j1] - lookptr[j1-frg2];
        if (sum > best1) best1 = sum;
      }
    }

    if ( best1 < required_aan ) continue;

    seqj = NR_seq[NR90_idx[j]];
    
    if ( has_aa2 == 0 )  { // calculate AAP array
      for (sk=0; sk<NAA2; sk++) taap[sk] = 0;
      for (j1=0; j1<len1; j1++) {
        c22= seqi[j1]*NAA1 + seqi[j1+1]; 
        taap[c22]++;
      }
      for (sk=0,mm=0; sk<NAA2; sk++) { 
        aap_begin[sk] = mm; mm+=taap[sk]; taap[sk] = 0;
      }
      for (j1=0; j1<len1; j1++) {
        c22= seqi[j1]*NAA1 + seqi[j1+1]; 
        aap_list[aap_begin[c22]+taap[c22]++] =j1;
      }
      has_aa2 = 1;
    }

    band_width1 = (BAND_width < len+len2-2 ) ? BAND_width : len+len2-2;
    diag_test_aapn(seqj, len, len2, taap, aap_begin, 
                           aap_list, best_sum,
                           band_width1, band_left, band_right, required_aa1);
    if ( best_sum < required_aa2 ) continue;
    
    local_band_align(seqi, seqj, len, len2, mat,
                             best_score, iden_no, band_left, band_right);
    if ( iden_no < required_aa1 ) continue;
    if ( (iden_no * 100 / len ) < NR_clstr100 ) continue;

    flag = 1; break; // else flag = 1, and break loop
  }
  hit_no = j;
  return flag;
  // END contained_in_old_lib()
} // END check_this_short


// add_in_lookup_table_short: add a new representative's words to the global
// word_table, one Frag_size-long fragment at a time; fragment f of this
// sequence is registered under index NR90f_no + f.
//
//   aan_no      number of word codes in aan_list
//   frg1        number of fragments those words span
//   aan_list    word codes in sequence order; each fragment is sorted in place
//   aan_list_no per-word counts, rebuilt here. Incoming values are ignored:
//               main() passes the counts computed by check_this_short for the
//               *globally* sorted list, which do not line up with this
//               per-fragment ordering once there is more than one fragment.
// Returns 0.
int add_in_lookup_table_short(int aan_no, int frg1,
                              int *aan_list, INTs *aan_list_no) {
  int i, j, k;

  // Sort each fragment on its own so equal words within it become adjacent.
  for (i=0; i<frg1; i++) {
    k = (i+1)*Frag_size < aan_no ? (i+1)*Frag_size-1: aan_no-1;
    quick_sort(aan_list, i*Frag_size, k);
  }

  // Recount: each run of equal codes inside a fragment keeps its length on
  // the first slot and 0 on the rest. Runs never extend across a fragment
  // boundary, so no word is credited to a neighbouring fragment.
  for (j=0; j<aan_no; j++) aan_list_no[j] = 1;
  for (j=aan_no-1; j > 0; j--) {
    if (j % Frag_size == 0) continue;   // first word of a fragment
    if (aan_list[j] == aan_list[j-1]) {
      aan_list_no[j-1] += aan_list_no[j];
      aan_list_no[j]=0;
    }
  }
  // END check_aan_list

  for (i=0; i<aan_no; i+=Frag_size) {
    k = Frag_size < (aan_no-i) ? Frag_size : (aan_no -i);
    word_table.add_word_list(k, aan_list+i, aan_list_no+i, NR90f_no + i/Frag_size);
  }

  return 0;
}  // END add_in_lookup_table


/////////////////////////// END ALL ////////////////////////