# User:The Anome/Naive Bayes WikiProject classifier/naive bayes.py
#
# From Wikipedia, the free encyclopedia
# Parse a TSV generated by a Quarry query, build count tables
# Quarry query: https://quarry.wmcloud.org/query/77172

import sys

def inc_dict(dic, key, count):
    """Add ``count`` to the tally stored under ``key`` in ``dic``, in place.

    A key that is not yet present is treated as having a tally of zero.
    Nothing is returned; ``dic`` itself is modified.
    """
    previous = dic.get(key, 0)
    dic[key] = previous + count

def count_dict(dic):
    """Return the sum of all values held in ``dic`` (0 for an empty dict)."""
    tallies = dic.values()
    return sum(tallies)

def norm_dict(dic):
    """Return a new dict with every value of ``dic`` divided by the grand total.

    The total is obtained from ``count_dict``. ``dic`` is left unmodified.
    An empty ``dic`` yields an empty result; a non-empty ``dic`` whose
    values sum to zero raises ``ZeroDivisionError``.
    """
    total = count_dict(dic)
    normalised = {}
    for key, value in dic.items():
        normalised[key] = value / total
    return normalised

def gen_conditional(prob_XC, prob_X):
    """Divide each joint value by the marginal of its key's second element.

    ``prob_XC`` is keyed by indexable pairs; for each pair ``k`` the value
    is divided by ``prob_X[k[1]]``. The result is a new dict with the same
    keys as ``prob_XC``.

    Raises ``KeyError`` if the second element of a key is absent from
    ``prob_X``.
    """
    conditional = {}
    for pair, joint in prob_XC.items():
        marginal = prob_X[pair[1]]
        conditional[pair] = joint / marginal
    return conditional

def main():
    """Read the Quarry TSV from standard input and print conditional probabilities.

    Input format: the first line must be the header
    ``cl_to<TAB>cl_to_2<TAB>my_count``; every following non-blank line is a
    record of exactly three tab-separated fields ``X``, ``C`` and an integer
    count. (NOTE(review): X and C look like the feature category and the
    class category respectively -- confirm against the Quarry query.)

    Output: first the sum of the normalised joint table (a sanity check that
    should be close to 1), then one ``((C, X), probability)`` tuple per line,
    where the probability is the joint value for ``(C, X)`` divided by the
    marginal for ``X``.

    Raises:
        Exception: if the header line is wrong, or a record does not have
            exactly three tab-separated fields.
        ValueError: if a record's count field is not an integer.
    """
    datafile = sys.stdin

    first_line = datafile.readline().strip()

    if first_line != "cl_to\tcl_to_2\tmy_count":
        raise Exception("file is not in correct query data format")

    count_X = {}
    count_C = {}
    count_X_and_C = {}
    count_C_and_X = {}

    # The header is line 1, so data records are numbered from 2.
    for line_no, line in enumerate(datafile, start=2):
        record = line.rstrip("\r\n")
        if not record.strip():
            # Tolerate blank lines (e.g. a trailing one) instead of aborting.
            continue
        # The header check above requires tabs, so split on tabs only;
        # a bare split() would also break fields on any other whitespace.
        fields = record.split("\t")
        if len(fields) != 3:
            raise Exception(
                "malformed input line {}: expected 3 tab-separated fields, "
                "got {}".format(line_no, len(fields))
            )
        X, C, raw_count = fields
        try:
            count = int(raw_count)
        except ValueError as err:
            raise ValueError(
                "malformed input line {}: count {!r} is not an "
                "integer".format(line_no, raw_count)
            ) from err
        inc_dict(count_X, X, count)
        inc_dict(count_C, C, count)
        inc_dict(count_X_and_C, (X, C), count)
        inc_dict(count_C_and_X, (C, X), count)

    # Now normalise to generate probabilities
    prob_X = norm_dict(count_X)
    # Marginal over C; computed alongside the other tables but not consumed
    # further in this function.
    prob_C = norm_dict(count_C)
    prob_X_and_C = norm_dict(count_X_and_C)
    prob_C_and_X = norm_dict(count_C_and_X)

    # And check these add up to 1
    print(count_dict(prob_X_and_C))

    # Now generate conditional probabilities: keys are (C, X), and each joint
    # value is divided by the marginal of X (the key's second element).
    prob_C_given_X = gen_conditional(prob_C_and_X, prob_X)
    for item in prob_C_given_X.items():
        print(item)

# Run the analysis only when this file is executed as a script, so that
# importing the module does not start reading from standard input.
if __name__ == "__main__":
    main()