| 
80
 | 
     1 #!/usr/bin/env python
 | 
| 
 | 
     2 #Greg Von Kuster
 | 
| 
 | 
     3 """
 | 
| 
 | 
     4 Calculate correlations between numeric columns in a tab delim file.
 | 
| 
 | 
     5 usage: %prog infile output.txt columns method
 | 
| 
 | 
     6 """
 | 
| 
 | 
     7 
 | 
| 
 | 
     8 import sys
 | 
| 
 | 
     9 #from rpy import *
 | 
| 
 | 
    10 import rpy2.robjects as robjects
 | 
| 
 | 
    11 r = robjects.r
 | 
| 
 | 
    12 
 | 
| 
 | 
    13 def stop_err(msg):
 | 
| 
 | 
    14     sys.stderr.write(msg)
 | 
| 
 | 
    15     sys.exit()
 | 
| 
 | 
    16     
 | 
| 
 | 
    17 def main():
 | 
| 
 | 
    18     method = sys.argv[4]
 | 
| 
 | 
    19     assert method in ( "pearson", "kendall", "spearman" )
 | 
| 
 | 
    20 
 | 
| 
 | 
    21     try:
 | 
| 
 | 
    22         columns = map( int, sys.argv[3].split( ',' ) )
 | 
| 
 | 
    23     except:
 | 
| 
 | 
    24         stop_err( "Problem determining columns, perhaps your query does not contain a column of numerical data." )
 | 
| 
 | 
    25     
 | 
| 
 | 
    26     matrix = []
 | 
| 
 | 
    27     skipped_lines = 0
 | 
| 
 | 
    28     first_invalid_line = 0
 | 
| 
 | 
    29     invalid_value = ''
 | 
| 
 | 
    30     invalid_column = 0
 | 
| 
 | 
    31 
 | 
| 
 | 
    32     for i, line in enumerate( file( sys.argv[1] ) ):
 | 
| 
 | 
    33         valid = True
 | 
| 
 | 
    34         line = line.rstrip('\n\r')
 | 
| 
 | 
    35 
 | 
| 
 | 
    36         if line and not line.startswith( '#' ): 
 | 
| 
 | 
    37             # Extract values and convert to floats
 | 
| 
 | 
    38             row = []
 | 
| 
 | 
    39             for column in columns:
 | 
| 
 | 
    40                 column -= 1
 | 
| 
 | 
    41                 fields = line.split( "\t" )
 | 
| 
 | 
    42                 if len( fields ) <= column:
 | 
| 
 | 
    43                     valid = False
 | 
| 
 | 
    44                 else:
 | 
| 
 | 
    45                     val = fields[column]
 | 
| 
 | 
    46                     if val.lower() == "na": 
 | 
| 
 | 
    47                         row.append( float( "nan" ) )
 | 
| 
 | 
    48                     else:
 | 
| 
 | 
    49                         try:
 | 
| 
 | 
    50                             row.append( float( fields[column] ) )
 | 
| 
 | 
    51                         except:
 | 
| 
 | 
    52                             valid = False
 | 
| 
 | 
    53                             skipped_lines += 1
 | 
| 
 | 
    54                             if not first_invalid_line:
 | 
| 
 | 
    55                                 first_invalid_line = i+1
 | 
| 
 | 
    56                                 invalid_value = fields[column]
 | 
| 
 | 
    57                                 invalid_column = column+1
 | 
| 
 | 
    58         else:
 | 
| 
 | 
    59             valid = False
 | 
| 
 | 
    60             skipped_lines += 1
 | 
| 
 | 
    61             if not first_invalid_line:
 | 
| 
 | 
    62                 first_invalid_line = i+1
 | 
| 
 | 
    63 
 | 
| 
 | 
    64         if valid:
 | 
| 
 | 
    65             # matrix.append( row )
 | 
| 
 | 
    66             matrix += row 
 | 
| 
 | 
    67 
 | 
| 
 | 
    68     if skipped_lines < i:
 | 
| 
 | 
    69         try:
 | 
| 
 | 
    70             out = open( sys.argv[2], "w" )
 | 
| 
 | 
    71         except:
 | 
| 
 | 
    72             stop_err( "Unable to open output file" )
 | 
| 
 | 
    73 
 | 
| 
 | 
    74         # Run correlation
 | 
| 
 | 
    75         # print >> sys.stderr, "matrix: %s" % matrix
 | 
| 
 | 
    76         # print >> sys.stderr, "array: %s" % array( matrix )
 | 
| 
 | 
    77         try:
 | 
| 
 | 
    78             # value = r.cor( array( matrix ), use="pairwise.complete.obs", method=method )
 | 
| 
 | 
    79             fv = robjects.FloatVector(matrix)
 | 
| 
 | 
    80             m = r['matrix'](fv, ncol=len(columns),byrow=True)
 | 
| 
 | 
    81             rslt_mat = r.cor(m, use="pairwise.complete.obs", method=method )
 | 
| 
 | 
    82             value = []
 | 
| 
 | 
    83             for ri in range(1, rslt_mat.nrow + 1):
 | 
| 
 | 
    84                 row = []
 | 
| 
 | 
    85                 for ci in range(1, rslt_mat.ncol + 1):
 | 
| 
 | 
    86 		    row.append(rslt_mat.rx(ri,ci)[0])
 | 
| 
 | 
    87                 value.append(row)
 | 
| 
 | 
    88         except Exception, exc:
 | 
| 
 | 
    89             out.close()
 | 
| 
 | 
    90             stop_err("%s" %str( exc ))
 | 
| 
 | 
    91         for row in value:
 | 
| 
 | 
    92             print >> out, "\t".join( map( str, row ) )
 | 
| 
 | 
    93         out.close()
 | 
| 
 | 
    94 
 | 
| 
 | 
    95     if skipped_lines > 0:
 | 
| 
 | 
    96         msg = "..Skipped %d lines starting with line #%d. " %( skipped_lines, first_invalid_line )
 | 
| 
 | 
    97         if invalid_value and invalid_column > 0:
 | 
| 
 | 
    98             msg += "Value '%s' in column %d is not numeric." % ( invalid_value, invalid_column )
 | 
| 
 | 
    99         print msg
 | 
| 
 | 
   100 
 | 
| 
 | 
   101 if __name__ == "__main__":
 | 
| 
 | 
   102     main()
 |