# -*- coding: utf-8 -*-
import io
import sys, os, codecs, re

# The documentation string is assigned to __doc__ explicitly: the import
# statements above come first, so a bare string literal placed here would not
# be picked up as the module docstring.
__doc__ = """The converter functionality. Main conversion table (KIPI-tag, MTE-tag, subtable),
lemma-based conversion table (lemma, tag or lemma, tag parts), regex matcher.

Adam Radziszewski, 01.09.2009
"""

def get_path(fname):
	"""Return fname joined onto the directory that contains this module.

	The directory is taken from this module's __file__; no check is made
	that the resulting path exists.
	"""
	module_dir = os.path.dirname(__file__)
	return os.path.join(module_dir, fname)

def warn(msg):
	"""Write msg to standard error, prefixed with 'WARNING: ' and followed by a newline."""
	# The three pieces are written one by one, in order.
	for chunk in ('WARNING: ', msg, '\n'):
		sys.stderr.write(chunk)

class Converter:
	"""The converter itself. Initialised with conversion tables, is ready to
	process input.

	Input items flow through four generator stages (see process()):
	punctuation conversion, table-based tag conversion, repair of
	agglutinative pronoun tags and joining of past-verb segments.
	"""

	MAIN_TABLE = 'main' # base name of the main conversion table
	PUNCT_TAG = 'interp' # input tag that marks punctuation
	UNKNOWN = 'X' # output tag used when no conversion is known

	pron_aglut_p = re.compile(r'Pp.......a') # agglutinated pronoun (-ń)
	ppron3_p = re.compile(r'Pp.3') # 3rd person pronoun
	ppron3_lemma = {'m': u'on', 'f': u'ona', 'n': u'ono'} # lemma restoration mapping for pprons
	ger_pref = 'Ng' # prefix of gerund tags
	ger_orth_p = re.compile(r'ni\w{1,2}$') # pattern for gerund form endings -- if matched, change the lemma
	ger_lemma_end = 'nie' # desired ending of gerund lemmas
	praet_p = re.compile(r'Vm.(is|cp)') # praet (past verb) or praet+by (conjunctive present)
	aglt_p = re.compile(r'Va.i......a') # aglt (agglutinated tense-bearing segment like -em)

	def __init__(self):
		"""Loads the main conversion table (and, through it, its subtables)."""
		self.table = MainTable(Converter.MAIN_TABLE)

	def process(self, item_source):
		"""Processes the iterable item_source, generating output items.

		Once the input is exhausted, the unknown tags gathered by the main
		table are dumped (see MainTable.dump_unks)."""
		staged = self._stage1(item_source)
		staged = self._stage2(staged)
		staged = self._stage3(staged)
		staged = self._stage4(staged)
		for item in staged:
			yield item

		self.table.dump_unks()

	def convert_single_tag(self, tag, lemma, orth):
		"""Converts single tag (lemma and orth should also be provided).
		If punctuation tag is provided, will return None (as punctuation
		marks are not treated as tokens in MTE). Otherwise will return
		the converted lexem, i.e. a two-element list [MTE tag, lemma]
		(the lemma may have been repaired, see convert_lexem).
		NOTE: as only one tag is available, agglutination will not be
		reverted from non-agglutinated tokens (although this will probably
		not be a problem as Morfeusz does not seem to produce tags problematic
		for the current conversion tables)."""
		if tag == Converter.PUNCT_TAG:
			return None
		lex = [tag, lemma]
		self.convert_lexem(lex, orth)
		return lex

	def convert_lexem(self, lex, orth):
		"""Converts given lexem, i.e. a two-element list [tag, lemma].
		Also provide orth for repairing lemmas.

		The list is modified in place: lex[0] becomes the tag returned by
		the main table, lex[1] is replaced for 3rd person pronouns and
		gerunds. Nothing is returned."""
		lex[0] = self.table.get_tag(lex[0], lex[1], None)
		# repair lemma: personal pronouns
		if Converter.ppron3_p.match(lex[0]):
			# replace the lemma, selecting it by the character at index 4 of the new tag
			lex[1] = Converter.ppron3_lemma[lex[0][4]]
		# repair lemma: gerunds (infinitive -> nominative)
		elif lex[0].startswith(Converter.ger_pref):
			lex[1] = Converter.ger_orth_p.sub(Converter.ger_lemma_end, orth.lower())

	def _stage1(self, item_source):
		"""Preprocess: convert punctuation.

		Tokens whose selected disamb lexem carries PUNCT_TAG are replaced
		with io.Punctuation built from their orth; other items pass through."""
		for item in item_source:
			# convert punctuation
			if isinstance(item, io.Token) and item.sel_d_lexem[0] == Converter.PUNCT_TAG:
				yield io.Punctuation(item.orth)
			else:
				yield item

	def _stage2(self, item_source):
		"""The main conversion stage -- using the tables."""
		for item in item_source:
			if isinstance(item, io.Token):
				# convert each lexem (in place)
				for lex in item.foreach_lexem():
					self.convert_lexem(lex, item.orth)

			yield item

	def _stage3(self, item_source):
		"""Post-processing: adding agglutination to personal pronouns."""
		for item in item_source:
			self._repair_agglut_pron(item)
			yield item

	def _stage4(self, item_source):
		"""Post-processing: joining praet + aglt or praet + Q(by) + aglt.

		Accepts any iterable. A praet token directly followed by an aglt
		token, by a 'by' token, or by 'by' and then aglt (each written
		without a preceding space) absorbs them: their orths are appended
		and the praet tags are adjusted. Absorbed tokens are not yielded."""

		def all_d_tags_match(item, pattern):
			# every disamb tag of the token satisfies the pattern
			for lex in item.foreach_d_lexem():
				if not pattern.match(lex[0]):
					return False
			return True

		def is_praet(item):
			# must be token with all disamb tags satisfying the praet pattern
			return isinstance(item, io.Token) and all_d_tags_match(item, Converter.praet_p)

		def is_by(item):
			# must be token after no space
			if not isinstance(item, io.Token) or not item.no_space: return False
			# form == by
			return item.orth.lower() == 'by'

		def is_aglt(item):
			# must be token after no space
			if not isinstance(item, io.Token) or not item.no_space: return False
			# all disamb tags satisfy the aglt pattern
			return all_d_tags_match(item, Converter.aglt_p)

		def padded(tag):
			# tag as a mutable list of 11 characters, filled up with unspecifieds
			chars = list(tag)
			chars.extend(['-'] * (11 - len(chars)))
			return chars

		def combine_aglt(praet, aglt):
			pers = aglt.sel_d_lexem[0][5]
			for lex in praet.foreach_lexem():
				# change the tags that match the praet pattern
				if Converter.praet_p.match(lex[0]):
					mutable = padded(lex[0])
					mutable[5] = pers # take person from aglt
					mutable[9] = '-' # clear definitenes/form length
					mutable[10] = 'y' # clitic = y (as we are combining)
					lex[0] = ''.join(mutable)
			praet.orth += aglt.orth

		def combine_by(praet, by):
			for lex in praet.foreach_lexem():
				# change the tags that match the praet pattern
				if Converter.praet_p.match(lex[0]):
					mutable = padded(lex[0])
					mutable[3] = 'c' # conjunctive
					mutable[4] = 'p' # present
					if mutable[5] == '-': # if person unspecified,
						mutable[5] = '3' # insert 3rd person (may be overridden by aglt)
					mutable[10] = 'y' # clitic = y (we are combining)
					lex[0] = ''.join(mutable)
			praet.orth += by.orth

		# iter() + next() work for any iterable and on both Python 2 and 3
		source = iter(item_source)
		# baseitem holds the item that has been read but not yet yielded;
		# it is reset to None right after yielding so that the StopIteration
		# handler below never emits the same item twice
		baseitem = None
		try:
			baseitem = next(source)
			while True:
				if not is_praet(baseitem):
					# not praet, no combining
					yield baseitem
					baseitem = None
					baseitem = next(source)
					continue
				# praet, maybe combine
				following = next(source)
				if is_aglt(following):
					# combine with aglt (no possibility of combining it further)
					combine_aglt(baseitem, following)
				elif is_by(following):
					# combine with by and check for optional aglt
					combine_by(baseitem, following)
					following = next(source)
					if is_aglt(following):
						combine_aglt(baseitem, following)
					else:
						# flush the combined baseitem, use following as incoming baseitem
						yield baseitem
						baseitem = following
						continue
				else:
					# flush the uncombined baseitem, use following as incoming baseitem
					yield baseitem
					baseitem = following
					continue
				# fully combined: yield it now and read a fresh baseitem
				yield baseitem
				baseitem = None
				baseitem = next(source)
		except StopIteration:
			# input exhausted; flush the pending item, if any
			if baseitem is not None:
				yield baseitem

	def _repair_agglut_pron(self, item):
		"""Repairs agglutinative pronoun tags of a token preceded by a space.

		Returns True if any tag of the item has been changed, False
		otherwise (also for non-tokens and tokens with no_space set)."""
		if not isinstance(item, io.Token) or item.no_space:
			return False
		chg = False
		for lex in item.foreach_lexem():
			tag = lex[0]
			if Converter.pron_aglut_p.match(tag):
				# the tag is agglutinative personal pronoun although a space precedes it
				# repair clitic=agglut -> clitic=yes, def:any -> unspec
				# (the slice keeps whatever follows index 10 and, unlike tag[11],
				# does not fail on tags shorter than 12 characters)
				lex[0] = tag[:9] + 'y-' + tag[11:]
				chg = True
		return chg

	def dump_unks(self):
		"""Dumps the unknown tags gathered by the main table."""
		self.table.dump_unks()

class Table:
	"""Base for all conversion tables.

	Subclasses provide load(path); the constructor calls it with the path
	of the table file (fname + EXT located next to this module)."""
	EXT = '.tab' # extension appended to table names

	def __init__(self, fname, verbose = True):
		"""Resolves fname to a path via get_path and loads the table from
		that path with EXT appended. The verbose flag is stored before loading."""
		path = get_path(fname)
		self.verbose = verbose
		self.load(path + Table.EXT)

	def load(self, path):
		"""Loads the table from the given file; to be overridden by subclasses."""
		raise NotImplementedError('load() must be provided by a Table subclass')

	def _entries(self, path):
		"""Reads the file, yielding its rows, each represented as a list of entries.

		The file is read as UTF-8. Lines starting with '#' and lines with no
		entries are skipped; entries are separated by whitespace."""
		# the with-block closes the file also when the consumer stops early
		# or raises while iterating
		with codecs.open(path, 'r', 'utf-8') as f:
			for line in f:
				if line.startswith('#'):
					continue
				elems = line.split()
				if elems:
					yield elems

class LemmaTable(Table):
	"""Conversion subtable -- for decisions based on lemmas.
	The table contains either full tags (is_tag_table == True)
	or tag parts (then idcs are the indices of the attributes
	stored; the list may be empty -- then the table contains
	lemmas only -- sort of a lexicon)."""

	def __init__(self, fname, is_tag_table, idcs = None):
		"""Loads the table named fname (non-verbose).

		idcs may be any iterable of integers; it is copied into a list as
		it is needed more than once and may be a one-shot iterator."""
		self.is_tag_table = is_tag_table
		if idcs is None:
			self.idcs = None
		else:
			self.idcs = list(idcs)
		Table.__init__(self, fname, False)

	def load(self, path):
		"""Reads the rows into lemma_to_tagpart (lemma -> full tag or tag part).

		The first field of a row is a comma-separated list of lemmas; the
		remaining fields are concatenated into the tag (part).
		Raises IOError if a row has an unexpected number of fields."""
		if self.is_tag_table:
			num_elems = 2
		else:
			# one field for the lemmas plus one per stored attribute
			# (no indices given is treated like an empty list)
			num_elems = len(self.idcs or ()) + 1
		self.lemma_to_tagpart = {}
		self.name = os.path.basename(path) # for generating debug info
		for elems in self._entries(path):
			if len(elems) != num_elems:
				raise IOError('Unexpected format in lemma table %s: %s (expected %d fields)' % (path, ' '.join(elems), num_elems))
			# get the tag parts as one string (if it's the full tag, there is only one element -> no change)
			tagpart = ''.join(elems[1:])
			# add the tagpart for each lemma in comma-separated list
			for lemma in elems[0].split(','):
				self.lemma_to_tagpart[lemma] = tagpart

	def _matches(self, mte_tag, part):
		"Returns if given tag matches the tag part description."
		for pos, tag_idx in enumerate(self.idcs or ()):
			if mte_tag[tag_idx] != part[pos]:
				return False
		return True

	def get_tag(self, tag, lemma, candidate_tags):
		"""Returns the proper MSD-tag if the necessary
		conditions are satisfied or returns None elsewhere.
		Candidate tags is the collection of tags assigned to this
		subtable and this KIPI-tag in the main table.
		The collection should contain exactly one MTE-tag,
		containing question marks if this is a tag-part
		table (the question marks are filled, left to right,
		with the characters of the tag part stored for the lemma)."""

		assert len(candidate_tags) == 1,('%s: the following set of candidates supplied for %s/%s: {%s} (expected exactly 1 tag/pattern)' % (self.name, lemma, tag, ', '.join(candidate_tags)))
		if lemma not in self.lemma_to_tagpart:
			return None
		part = self.lemma_to_tagpart[lemma]
		if self.is_tag_table:
			# this is the whole mte-tag
			return part
		# list() allows for any collection of candidates, not only sequences
		pattern = list(candidate_tags)[0]
		result = list(pattern) # list is mutable
		pos = -1
		for val in part:
			pos = pattern.find('?', pos + 1)
			assert pos >= 0, 'Not enough question marks in pattern %s (trying to fill with %s)' % (pattern, part)
			result[pos] = val
		return ''.join(result) # get the string back

	def _info(self, msg):
		"Display debug info."
		if self.is_tag_table:
			kind = 'full tag'
		else:
			kind = 'tag part'
		sys.stderr.write('LT %s (%s): %s\n' % (self.name, kind, msg))

class RegexTable(Table):
	"""A regular expression acting as a conversion table (same get_tag
	interface as the file-based tables, but nothing is read from disk)."""

	def __init__(self, regex_text):
		# Table.__init__ is not called here, so no table file is looked up
		self.pattern = re.compile(regex_text)

	def load(self, nevermind):
		"""Does nothing; present only so that a call to load() is harmless."""
		return None

	def get_tag(self, tag, lemma, candidate_tags):
		"""Returns the only candidate tag if the pattern matches the
		beginning of the lemma; returns None otherwise."""
		assert len(candidate_tags) == 1, 'Regex with more tags to choose'
		if not self.pattern.match(lemma):
			return None
		candidates = list(candidate_tags)
		return candidates[0]

class MainTable(Table):
	"""Main conversion table.
	The data is read from a file with entries (KIPI-tag, MTE-tag, optional subtable name).
	The data is stored internally (tag_to_conv) as a mapping from KIPI-tags to functions,
	i.e. KIPI-tag -> lemma -> MTE-tag."""
	ELSE = 'else' # third-field value marking the fallback entry (no subtable)

	# subtable types recognised in cross-references (type:name)
	TAB_REGEX = 'regex'
	TAB_LEMMA = 'lemma'
	TAB_TAG = 'tag'

	def _ensure_subtable_loaded(self, xref):
		"""Resolves sub-table cross-reference and loads it if not loaded.
		Returns table name or None if cannot load.

		A cross-reference has the form type:name, where type is 'regex'
		(name is the regular expression itself), 'lemma,tag' (lemma -> full
		tag table) or 'lemma' optionally followed by comma-separated integer
		attribute indices (lemma -> tag part table). Anything else -- also a
		cross-reference with no colon -- results in a warning and None."""
		if xref is None:
			return None
		tabtype, colon, tabname = xref.partition(':')
		if colon:
			if tabname in self.subtables:
				return tabname

			if tabtype == MainTable.TAB_REGEX:
				self.subtables[tabname] = RegexTable(tabname)
				return tabname
			if tabtype.startswith(MainTable.TAB_LEMMA):
				elems = tabtype.split(',')
				if len(elems) > 1 and elems[1] == MainTable.TAB_TAG:
					# load lemma-tag table
					self.subtables[tabname] = LemmaTable(tabname, True)
				else:
					# load lemma-tagpart table; a real list is built as the
					# indices are measured and traversed by the subtable
					idcs = [int(elem) for elem in elems[1:]]
					self.subtables[tabname] = LemmaTable(tabname, False, idcs)
				return tabname
		warn('%s: unknown cross-reference in tag table' % xref)
		return None

	def get_tag(self, tag, lemma, candidate_tags):
		"""Returns the MTE-tag for given KIPI-tag and lemma.

		Tags absent from the table are gathered (see dump_unks) and
		converted to Converter.UNKNOWN. candidate_tags is not used here;
		it is kept for compatibility with the subtable interface."""
		if tag in self.tag_to_conv:
			return self.tag_to_conv[tag](lemma)
		self._gather_unknown(tag)
		return Converter.UNKNOWN

	def _gather_unknown(self, tag):
		"""Remembers the unknown tag (only if verbose)."""
		if self.verbose:
			self.unks.add(tag)

	def dump_unks(self):
		"""Writes the unknown tags gathered so far to stderr (only if verbose)."""
		if self.verbose and len(self.unks) > 0:
			sys.stderr.write('%d unknown tags occurred\n' % len(self.unks))
			for tag in self.unks:
				sys.stderr.write('\t')
				sys.stderr.write(tag)
				sys.stderr.write('\n')

	def load(self, path):
		"""Reads the table file and builds tag_to_conv.

		Each row consists of a KIPI-tag, an MTE-tag and, optionally, either
		a subtable cross-reference or 'else'. Referenced subtables are
		loaded on the way. Raises IOError on rows with one field or with
		more than three fields."""
		self.unks = set()
		self.subtables = {}
		tag_to_pairs = {}

		# gather the mapping: tag -> (mte_tag, tabname) list, preload subtables
		for elems in self._entries(path):
			if len(elems) == 1 or len(elems) > 3:
				raise IOError('Unexpected format: %s' % ' '.join(elems))
			tag = elems[0]
			mte_tag = elems[1]
			# 'else' and a missing third field both mean: no cross-reference
			if len(elems) > 2 and elems[2] != MainTable.ELSE:
				xref = elems[2]
			else:
				xref = None
			tabname = self._ensure_subtable_loaded(xref)
			# ignore the entries with unhandled xrefs
			if xref is None or tabname is not None:
				tag_to_pairs.setdefault(tag, []).append((mte_tag, tabname))
		# clean it up
		self.tag_to_conv = {}
		for tag in tag_to_pairs:
			pairs = tag_to_pairs[tag]
			assert len(pairs) > 0, 'Internal error, zero elements at a tag'
			if len(pairs) == 1:
				# we've got only one entry assigned to the tag
				# it is not expected to be provided with a cross-ref
				mte_tag, tabname = pairs[0]
				if tabname is not None:
					warn('tag %s converted to %s only, yet supplied with cross-ref %s' % (tag, mte_tag, tabname))
				self.tag_to_conv[tag] = self._make_const(mte_tag)
				continue

			# there should be at most one non-cross-ref entry (fallback)
			fallbacks = [pair[0] for pair in pairs if pair[1] is None]
			if len(fallbacks) > 1:
				warn('tag %s conversion ambiguity: %s' % (tag, ', '.join(fallbacks)))

			# check if there are xrefs
			if len(fallbacks) == len(pairs):
				# no xrefs, just use the (first) fallback
				self.tag_to_conv[tag] = self._make_const(fallbacks[0])
				continue

			# there are xrefs; there should be a fallback
			if len(fallbacks) == 0:
				warn('no default value (else) for tag %s' % tag)
				fallback = Converter.UNKNOWN
			else:
				fallback = fallbacks[0]
			# gather the MTE tags assigned to each subtable
			xref_to_mtes = {}
			for mte_tag, tabname in pairs:
				if tabname is not None:
					xref_to_mtes.setdefault(tabname, []).append(mte_tag)
			self.tag_to_conv[tag] = self._make_rule(tag, xref_to_mtes, fallback)

	def _make_rule(self, tag, xref_to_mtes, fallback):
		"""Prepares a conversion rule for given xrefs and fallback.
		The rule will try to apply xref subtables and if unsuccessful,
		will mark the fallback tag.

		xref_to_mtes maps subtable names to lists of candidate MTE-tags;
		the returned rule is a function taking a lemma."""

		def rule(lemma):
			for tabname in xref_to_mtes:
				mte_tag = self.subtables[tabname].get_tag(tag, lemma, xref_to_mtes[tabname])
				if mte_tag is not None:
					return mte_tag
			return fallback
		return rule

	def _make_const(self, mte_tag):
		"""Prepares a conversion rule returning mte_tag whatever the lemma."""
		return lambda whatever: mte_tag
