#!/usr/bin/python
# -*- coding: utf-8 -*-
import sys, codecs
import sympy, math
import progressbar
#from helper import bag
import conv_table, io
import random

__doc__ = """Functions for extracting lexicon and tag examples from corpus dump."""

def from_dump(fname):
	"""Yield io.Token objects (KIPI tokens) read from the given dump file.

	Each line of the dump is split on tabs into three parts: a space
	marker, the orth and one or more "lemma tag" interpretations (the
	interpretations are themselves tab-separated).  A marker equal to 's'
	(case-insensitive) means the token is preceded by white space; the
	token is constructed with the negation of that flag.
	Unlike items_full/items_kipi the orth is NOT lower-cased here.

	The first interpretation becomes tok.sel_d_lexem, the remaining ones
	stay in tok.disamb_lexems; each is stored as a [tag, lemma] list.
	Tokens are yielded as a flat stream -- there will be no chunks/sents.

	Raises ValueError when a line does not have the expected shape.
	"""
	# 'with' closes the file also when the consumer abandons the generator
	# early or when a malformed line raises.
	with codecs.open(fname, 'r', 'utf-8') as corp:
		for line in corp:
			sp, orth, lts = line.strip().split('\t', 2)
			space = (sp.lower() == 's')
			pairs = [lt.split() for lt in lts.split('\t')]
			tok = io.Token(orth, not space)
			for lemma, tag in pairs:
				tok.disamb_lexems.append([tag, lemma])
			# just take the first disamb as the selected one
			# (for gathering stats it doesn't matter which one is selected)
			tok.sel_d_lexem = tok.disamb_lexems[0]
			tok.disamb_lexems = tok.disamb_lexems[1:]
			yield tok

def to_dump(fname, verbose = True):
	"""Create a handler function that stores items from a source in a
	dump file (for converted tags).

	The returned handler takes an iterable of items and writes every
	io.Token among them to `fname` (UTF-8), one token per line, with
	tab-separated fields: 'N' when item.no_space is true, otherwise 'S';
	the orth; then one "lemma tag" field per lexem returned by
	item._without_duplicates(item.foreach_d_lexem()).
	Items that are not io.Token instances are skipped.
	When `verbose` is true a message is printed after every 10 million
	tokens written.  The handler returns None.
	"""

	def handler(item_source):
		total = 0
		# 'with' guarantees the dump gets closed even when the item source
		# or a write raises midway.
		with codecs.open(fname, 'w', 'utf-8') as outcorp:
			for item in item_source:
				if not isinstance(item, io.Token):
					continue
				fields = ['N' if item.no_space else 'S', item.orth]
				for lex in item._without_duplicates(item.foreach_d_lexem()):
					# lex[0] is the tag, lex[1] the lemma; stored as "lemma tag"
					fields.append(lex[1] + ' ' + lex[0])
				outcorp.write('\t'.join(fields))
				outcorp.write('\n')
				total += 1
				if verbose and total % 10000000 == 0:
					print('%d million tokens written' % (total // 1000000))
	return handler
# TODO: lower orth when saving

def conv_dump(in_fname, out_fname):
	"""Convert the dump `in_fname` from KIPI to MTE, writing `out_fname`.

	Tokens read with from_dump are passed through a
	conv_table.Converter and stored with the handler made by to_dump;
	afterwards the converter's dump_unks() is called.
	"""
	print('Converting the corpus dump from KIPI to MTE')
	converter = conv_table.Converter()
	store = to_dump(out_fname)
	store(converter.process(from_dump(in_fname)))
	converter.dump_unks()

def items_full(fname):
	"""Yield (space, orth, lemma, tag, count) tuples read from a dump file.

	One tuple is yielded for every "lemma tag" interpretation of a token.
	`space` denotes if the token is preceded by any white space (the first
	column equals 's', case-insensitive).  `orth` and `lemma` are
	lower-cased, `tag` is left as is.  `count` is 1 divided by the number
	of interpretations of the token, so with more disambs it is less than
	one.

	Raises ValueError when a line does not have the expected shape.
	"""
	# 'with' closes the file also when the consumer stops iterating early
	# (get_tag_examples breaks out of this generator) or a line is malformed.
	with codecs.open(fname, 'r', 'utf-8') as corp:
		for line in corp:
			sp, orth, lts = line.strip().split('\t', 2)
			space = (sp.lower() == 's')
			orth = orth.lower()
			lts = lts.split('\t')
			# every interpretation gets an equal share of the token
			share = 1.0 / len(lts)
			for lt in lts:
				lemma, tag = lt.split()
				yield (space, orth, lemma.lower(), tag, share)

def items_kipi(fname):
	"""Yield (space, orth, interps) items read from a dump file.

	`space` denotes if the token is preceded by white space (the first
	column equals 's', case-insensitive), `orth` is lower-cased and
	`interps` is a tuple with one tuple per tab-separated interpretation,
	each holding its whitespace-separated parts (lemma and tag).
	Tags are not converted and lemmas are not lower-cased.

	Raises ValueError when a line has fewer than three tab-separated parts.
	"""
	# 'with' closes the file also when the consumer stops iterating early
	# or a malformed line raises.
	with codecs.open(fname, 'r', 'utf-8') as corp:
		for line in corp:
			sp, orth, lts = line.strip().split('\t', 2)
			space = (sp.lower() == 's')
			interps = tuple(tuple(lt.split()) for lt in lts.split('\t'))
			yield (space, orth.lower(), interps)

def get_lemmas(fname, how_many):
	"""Get a list of (at most) how_many most frequent lemmas.

	The dump `fname` is read with items_full and the fractional counts of
	every lemma are summed up.

	Returns a pair (lemmas, total): `lemmas` is the list of lemmas ordered
	from the most frequent one (counts are discarded), `total` is the
	number of items yielded by items_full, i.e. one per interpretation.
	"""
	print('Getting %d most frequent lemmas' % how_many)
	lemmas = {}
	total = 0
	for sp, orth, lemma, tag, c in items_full(fname):
		lemmas[lemma] = lemmas.get(lemma, 0.0) + c
		total += 1
		if total % 10000000 == 0:
			print('%d lemmas analysed, %d gathered' % (total, len(lemmas)))
	print('All %d lemmas analysed. Sorting...' % total)
	# most frequent first; the sort is stable, so ties keep the dict order
	ranked = sorted(lemmas.items(), key=lambda lc: lc[1], reverse=True)
	print('%d lemmas total, taking %d' % (len(ranked), how_many))
	# leave how_many, discard the rest
	ranked = ranked[:how_many]
	# an empty dump (or how_many == 0) leaves nothing to report on
	if ranked:
		print('Most frequent occurs %.2f times, most rare -- %.2f times' % (ranked[0][1], ranked[-1][1]))
	# discarding counts
	return [lemma for lemma, count in ranked], total

def get_oltc(fname, wanted_lemmas, total):
	"""Get the mapping (orth,lemma,tag) -> count for the wanted lemmas.

	fname -- dump file, read with items_full
	wanted_lemmas -- iterable of lemmas to keep; items with any other
		lemma are ignored
	total -- expected number of items in the dump; used only to turn the
		number of items read so far into a percentage for the progress bar

	Returns a dict mapping (orth, lemma, tag) tuples to summed counts.
	"""
	print('Getting the mapping (orth,lemma,tag) -> count')
	oltc = {}
	current = 0
	# a set makes the per-item membership test cheap
	wanted_lemmas = set(wanted_lemmas)
	p = progressbar.ProgressBar(widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()])
	p.start()
	for sp, orth, lemma, tag, c in items_full(fname):
		if lemma in wanted_lemmas:
			t = (orth, lemma, tag)
			oltc[t] = oltc.get(t, 0.0) + c
		current += 1
		# integer percentage of the items read so far
		p.update(100 * current // total)
	# tell the bar it is complete so that it terminates its output line
	p.finish()
	return oltc

def get_tag_counts(fname, total):
	"""Get the mapping tag -> count.

	fname -- dump file, read with items_full
	total -- expected number of items in the dump; used only to turn the
		number of items read so far into a percentage for the progress bar
		(and echoed in the final message)

	Returns a dict mapping tags to their summed (fractional) counts.
	"""
	print('Gathering tag counts')
	tag_count = {}
	current = 0
	p = progressbar.ProgressBar(widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()])
	p.start()
	for sp, orth, lemma, tag, c in items_full(fname):
		tag_count[tag] = tag_count.get(tag, 0.0) + c
		current += 1
		# integer percentage of the items read so far
		p.update(100 * current // total)
	# tell the bar it is complete so that it terminates its output line
	# before the summary below is printed
	p.finish()
	print('All %d tag occs analysed, %d gathered' % (total, len(tag_count)))
	return tag_count

def _gather_tag_forms(fname, tag_count, total):
	"""Collect the mapping tag -> {(orth, lemma): count} from the dump.

	Helper of get_tag_examples.  Every 10000 items the tags are checked;
	a tag with enough frequent forms is "closed": only its frequent forms
	are kept and no further forms are collected for it.  Reading stops as
	soon as all tags of tag_count are closed.
	"""
	MIN_FORMS = 10 # at least this many forms per tag to consider closing
	ENOUGH = 3 # at least this many forms occuring enough times to consider closing
	MIN_ENOUGH = 10 # at least this many occurences make it enough
	# NOTE(review): the closing test below uses ENOUGH as the occurrence
	# threshold (count > 3) and MIN_ENOUGH as the required number of such
	# forms (>= 10), i.e. the other way round than the comments above say.
	# The original behaviour is kept as is -- confirm which one is intended.

	still_wanted = set(tag_count.keys())
	forms_of_tag = {} # tag -> (orth,lemma) -> count
	current = 0

	p = progressbar.ProgressBar(widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()])
	p.start()
	for sp, orth, lemma, tag, c in items_full(fname):
		if tag in still_wanted:
			fl_count = forms_of_tag.setdefault(tag, {})
			form = (orth, lemma)
			fl_count[form] = fl_count.get(form, 0.0) + c
		current += 1
		if current % 10000 == 0:
			# consider closing tags; iterate over a copy, the set is modified
			for wanted in list(still_wanted):
				fl_count = forms_of_tag.get(wanted)
				if fl_count is None or len(fl_count) < MIN_FORMS:
					continue
				frequent = dict((fl, n) for fl, n in fl_count.items() if n > ENOUGH)
				if len(frequent) >= MIN_ENOUGH:
					# leave only the forms appearing enough times
					forms_of_tag[wanted] = frequent
					still_wanted.remove(wanted)
			# maybe all covered?
			if not still_wanted:
				print('All tags covered, terminating now.')
				break
		p.update(100 * current // total)
	# tell the bar it is complete so that it terminates its output line
	p.finish()
	return forms_of_tag

def _select_tag_examples(tag_count, forms_of_tag):
	"""Pick three example forms per tag.

	Helper of get_tag_examples.  Returns (tag_examples, num_unspec) where
	num_unspec is the number of tags having fewer than three forms.
	"""
	used_forms = set()
	tag_examples = {} # tag -> ((f1,l1),(f2,l2),(f3,l3),tag count)
	num_unspec = 0
	# iterate over tags beginning from the less frequent ones,
	# try selecting forms not used as an example of an earlier tag
	for tag, count in sorted(tag_count.items(), key=lambda tc: tc[1]):
		# the orth/lemma pairs of this tag, beginning with most frequent ones;
		# a tag for which no forms were collected simply has no examples
		forms_listed = sorted(forms_of_tag.get(tag, {}).items(), key=lambda fc: fc[1], reverse=True)
		if len(forms_listed) <= 3:
			# no choice here: take all of them and mark them as used
			used_forms.update(form for form, fc in forms_listed)
			if len(forms_listed) < 3:
				num_unspec += 1
				# less than 3 examples; extend the list with nulls
				forms_listed.extend([(None, 0)] * (3 - len(forms_listed)))
		else:
			# more examples than needed (3):
			# try to get the most frequent unused ones
			chosen = []
			rejected = []
			for form, fc in forms_listed:
				if form in used_forms:
					# only forms that were really skipped may serve as a
					# fallback, otherwise an example could be listed twice
					rejected.append((form, fc))
					continue
				used_forms.add(form)
				chosen.append((form, fc))
				if len(chosen) == 3:
					break
			if len(chosen) < 3:
				# ok, just take what there is
				chosen.extend(rejected[:3 - len(chosen)])
			forms_listed = chosen
		assert len(forms_listed) == 3
		tag_examples[tag] = (forms_listed[0][0], forms_listed[1][0], forms_listed[2][0], count)
	return tag_examples, num_unspec

def get_tag_examples(fname, tag_count, total):
	"""Get three example (orth, lemma) forms for every tag.

	fname -- dump file, read with items_full
	tag_count -- dict tag -> count; its keys are the tags to find examples
		for, its values decide the processing order (rarest tag first) and
		are copied to the result
	total -- expected number of items in the dump; used only to turn the
		number of items read so far into a percentage for the progress bar

	Returns a dict tag -> (fl1, fl2, fl3, tag count) where each fl is an
	(orth, lemma) tuple or None when the tag has fewer than three forms.
	The most frequent forms not yet used as an example of another tag are
	preferred; already used ones are taken only when nothing else is left.
	"""
	print('Gathering tag examples')
	forms_of_tag = _gather_tag_forms(fname, tag_count, total)
	### select representative examples (3 per tag)
	tag_examples, num_unspec = _select_tag_examples(tag_count, forms_of_tag)
	print('Done. %d tags partially uncovered' % num_unspec)
	return tag_examples # tag -> ((f1,l1),(f2,l2),(f3,l3),tag count)

def write_oltc(oltc, outfname):
	"""Write the (orth,lemma,tag) -> count mapping to a UTF-8 text file.

	One line per entry: orth, lemma, tag and the count rounded up to an
	integer, tab-separated.  Lines are sorted by (orth, lemma, tag).
	"""
	# 'with' closes the file also when formatting or writing raises
	with codecs.open(outfname, 'w', 'utf-8') as out:
		for (orth, lemma, tag), count in sorted(oltc.items()):
			out.write(u'%s\t%s\t%s\t%d\n' % (orth, lemma, tag, int(math.ceil(count))))

def write_tagc(tagc, outfname):
	"""Write the tag -> count mapping to a UTF-8 text file.

	One line per tag: the tag and its count with two decimal places,
	tab-separated.  Lines are ordered from the most frequent tag.
	"""
	# most frequent first; the sort is stable, so ties keep the dict order
	listed = sorted(tagc.items(), key=lambda tc: tc[1], reverse=True)
	# 'with' closes the file also when formatting or writing raises
	with codecs.open(outfname, 'w', 'utf-8') as out:
		for tag, count in listed:
			out.write(u'%s\t%.2f\n' % (tag, count))

def write_tage(tage, outfname):
	"""Write the tag -> examples mapping to a UTF-8 text file.

	`tage` maps a tag to (fl1, fl2, fl3, count) where each fl is an
	(orth, lemma) pair or a false value for a missing example.
	One line per tag, sorted by tag: the tag, three "orth/lemma" fields
	('****/****' for a missing example) and the count rounded up to an
	integer, tab-separated.
	"""
	missing = ('****', '****')
	# 'with' closes the file also when formatting or writing raises
	with codecs.open(outfname, 'w', 'utf-8') as out:
		for tag, (fl1, fl2, fl3, c) in sorted(tage.items()):
			examples = [fl or missing for fl in (fl1, fl2, fl3)]
			fields = [u'%s' % tag]
			fields.extend(u'%s/%s' % (fl[0], fl[1]) for fl in examples)
			fields.append(u'%d' % int(math.ceil(c)))
			out.write(u'\t'.join(fields) + u'\n')

def now(kipi_fname, mte_fname, num_lemmas = 15000,
		lexicon_fname = 'x_lexicon.txt', tagc_fname = 'x_tagc.txt',
		tage_fname = 'x_tage.txt'):
	"""Run the whole pipeline on a KIPI dump.

	kipi_fname -- input dump, converted with conv_dump
	mte_fname -- where the converted dump is written; all the statistics
		are gathered from this file
	num_lemmas -- how many most frequent lemmas go to the lexicon
	lexicon_fname -- output file of write_oltc
	tagc_fname -- output file of write_tagc
	tage_fname -- output file of write_tage

	The defaults of the optional parameters are the values that used to
	be hard-coded here.
	"""
	conv_dump(kipi_fname, mte_fname)
	lemmas, total = get_lemmas(mte_fname, num_lemmas)
	oltc = get_oltc(mte_fname, lemmas, total)
	write_oltc(oltc, lexicon_fname)
	tagc = get_tag_counts(mte_fname, total)
	write_tagc(tagc, tagc_fname)
	tage = get_tag_examples(mte_fname, tagc, total)
	write_tage(tage, tage_fname)

if __name__ == '__main__':
	# expects two arguments: the KIPI dump to read and the MTE dump to write
	if len(sys.argv) < 3:
		# substitute the program name; the placeholder used to be printed verbatim
		print('usage: %s KIPI_DUMP MTE_OUT_FNAME' % sys.argv[0])
	else:
		now(sys.argv[1], sys.argv[2])

