# Copyright 2013 Google, Inc. All Rights Reserved.
#
# Google Author(s): Behdad Esfahbod, Roozbeh Pournader

from fontTools import ttLib
from fontTools.ttLib.tables.DefaultTable import DefaultTable
from fontTools.ttLib.tables import otTables
from fontTools.merge.base import add_method, mergeObjects
from fontTools.merge.util import *
import logging


log = logging.getLogger("fontTools.merge")


def mergeLookupLists(lst):
    # TODO Do smarter merge.
    return sumLists(lst)


def mergeFeatures(lst):
    assert lst
    self = otTables.Feature()
    self.FeatureParams = None
    self.LookupListIndex = mergeLookupLists(
        [l.LookupListIndex for l in lst if l.LookupListIndex]
    )
    self.LookupCount = len(self.LookupListIndex)
    return self


def mergeFeatureLists(lst):
    d = {}
    for l in lst:
        for f in l:
            tag = f.FeatureTag
            if tag not in d:
                d[tag] = []
            d[tag].append(f.Feature)
    ret = []
    for tag in sorted(d.keys()):
        rec = otTables.FeatureRecord()
        rec.FeatureTag = tag
        rec.Feature = mergeFeatures(d[tag])
        ret.append(rec)
    return ret


def mergeLangSyses(lst):
    assert lst

    # TODO Support merging ReqFeatureIndex
    assert all(l.ReqFeatureIndex == 0xFFFF for l in lst)

    self = otTables.LangSys()
    self.LookupOrder = None
    self.ReqFeatureIndex = 0xFFFF
    self.FeatureIndex = mergeFeatureLists(
        [l.FeatureIndex for l in lst if l.FeatureIndex]
    )
    self.FeatureCount = len(self.FeatureIndex)
    return self


def mergeScripts(lst):
    assert lst

    if len(lst) == 1:
        return lst[0]
    langSyses = {}
    for sr in lst:
        for lsr in sr.LangSysRecord:
            if lsr.LangSysTag not in langSyses:
                langSyses[lsr.LangSysTag] = []
            langSyses[lsr.LangSysTag].append(lsr.LangSys)
    lsrecords = []
    for tag, langSys_list in sorted(langSyses.items()):
        lsr = otTables.LangSysRecord()
        lsr.LangSys = mergeLangSyses(langSys_list)
        lsr.LangSysTag = tag
        lsrecords.append(lsr)

    self = otTables.Script()
    self.LangSysRecord = lsrecords
    self.LangSysCount = len(lsrecords)
    dfltLangSyses = [s.DefaultLangSys for s in lst if s.DefaultLangSys]
    if dfltLangSyses:
        self.DefaultLangSys = mergeLangSyses(dfltLangSyses)
    else:
        self.DefaultLangSys = None
    return self


def mergeScriptRecords(lst):
    d = {}
    for l in lst:
        for s in l:
            tag = s.ScriptTag
            if tag not in d:
                d[tag] = []
            d[tag].append(s.Script)
    ret = []
    for tag in sorted(d.keys()):
        rec = otTables.ScriptRecord()
        rec.ScriptTag = tag
        rec.Script = mergeScripts(d[tag])
        ret.append(rec)
    return ret


otTables.ScriptList.mergeMap = {
    "ScriptCount": lambda lst: None,  # TODO
    "ScriptRecord": mergeScriptRecords,
}
otTables.BaseScriptList.mergeMap = {
    "BaseScriptCount": lambda lst: None,  # TODO
    # TODO: Merge duplicate entries
    "BaseScriptRecord": lambda lst: sorted(
        sumLists(lst), key=lambda s: s.BaseScriptTag
    ),
}

otTables.FeatureList.mergeMap = {
    "FeatureCount": sum,
    "FeatureRecord": lambda lst: sorted(sumLists(lst), key=lambda s: s.FeatureTag),
}

otTables.LookupList.mergeMap = {
    "LookupCount": sum,
    "Lookup": sumLists,
}

otTables.Coverage.mergeMap = {
    "Format": min,
    "glyphs": sumLists,
}

otTables.ClassDef.mergeMap = {
    "Format": min,
    "classDefs": sumDicts,
}

otTables.LigCaretList.mergeMap = {
    "Coverage": mergeObjects,
    "LigGlyphCount": sum,
    "LigGlyph": sumLists,
}

otTables.AttachList.mergeMap = {
    "Coverage": mergeObjects,
    "GlyphCount": sum,
    "AttachPoint": sumLists,
}

# XXX Renumber MarkFilterSets of lookups
otTables.MarkGlyphSetsDef.mergeMap = {
    "MarkSetTableFormat": equal,
    "MarkSetCount": sum,
    "Coverage": sumLists,
}

otTables.Axis.mergeMap = {
    "*": mergeObjects,
}

# XXX Fix BASE table merging
otTables.BaseTagList.mergeMap = {
    "BaseTagCount": sum,
    "BaselineTag": sumLists,
}

otTables.GDEF.mergeMap = (
    otTables.GSUB.mergeMap
) = (
    otTables.GPOS.mergeMap
) = otTables.BASE.mergeMap = otTables.JSTF.mergeMap = otTables.MATH.mergeMap = {
    "*": mergeObjects,
    "Version": max,
}

ttLib.getTableClass("GDEF").mergeMap = ttLib.getTableClass(
    "GSUB"
).mergeMap = ttLib.getTableClass("GPOS").mergeMap = ttLib.getTableClass(
    "BASE"
).mergeMap = ttLib.getTableClass(
    "JSTF"
).mergeMap = ttLib.getTableClass(
    "MATH"
).mergeMap = {
    "tableTag": onlyExisting(equal),  # XXX clean me up
    "table": mergeObjects,
}


@add_method(ttLib.getTableClass("GSUB"))
def merge(self, m, tables):
    assert len(tables) == len(m.duplicateGlyphsPerFont)
    for i, (table, dups) in enumerate(zip(tables, m.duplicateGlyphsPerFont)):
        if not dups:
            continue
        if table is None or table is NotImplemented:
            log.warning(
                "Have non-identical duplicates to resolve for '%s' but no GSUB. Are duplicates intended?: %s",
                m.fonts[i]._merger__name,
                dups,
            )
            continue

        synthFeature = None
        synthLookup = None
        for script in table.table.ScriptList.ScriptRecord:
            if script.ScriptTag == "DFLT":
                continue  # XXX
            for langsys in [script.Script.DefaultLangSys] + [
                l.LangSys for l in script.Script.LangSysRecord
            ]:
                if langsys is None:
                    continue  # XXX Create!
                feature = [v for v in langsys.FeatureIndex if v.FeatureTag == "locl"]
                assert len(feature) <= 1
                if feature:
                    feature = feature[0]
                else:
                    if not synthFeature:
                        synthFeature = otTables.FeatureRecord()
                        synthFeature.FeatureTag = "locl"
                        f = synthFeature.Feature = otTables.Feature()
                        f.FeatureParams = None
                        f.LookupCount = 0
                        f.LookupListIndex = []
                        table.table.FeatureList.FeatureRecord.append(synthFeature)
                        table.table.FeatureList.FeatureCount += 1
                    feature = synthFeature
                    langsys.FeatureIndex.append(feature)
                    langsys.FeatureIndex.sort(key=lambda v: v.FeatureTag)

                if not synthLookup:
                    subtable = otTables.SingleSubst()
                    subtable.mapping = dups
                    synthLookup = otTables.Lookup()
                    synthLookup.LookupFlag = 0
                    synthLookup.LookupType = 1
                    synthLookup.SubTableCount = 1
                    synthLookup.SubTable = [subtable]
                    if table.table.LookupList is None:
                        # mtiLib uses None as default value for LookupList,
                        # while feaLib points to an empty array with count 0
                        # TODO: make them do the same
                        table.table.LookupList = otTables.LookupList()
                        table.table.LookupList.Lookup = []
                        table.table.LookupList.LookupCount = 0
                    table.table.LookupList.Lookup.append(synthLookup)
                    table.table.LookupList.LookupCount += 1

                if feature.Feature.LookupListIndex[:1] != [synthLookup]:
                    feature.Feature.LookupListIndex[:0] = [synthLookup]
                    feature.Feature.LookupCount += 1

    DefaultTable.merge(self, m, tables)
    return self


@add_method(
    otTables.SingleSubst,
    otTables.MultipleSubst,
    otTables.AlternateSubst,
    otTables.LigatureSubst,
    otTables.ReverseChainSingleSubst,
    otTables.SinglePos,
    otTables.PairPos,
    otTables.CursivePos,
    otTables.MarkBasePos,
    otTables.MarkLigPos,
    otTables.MarkMarkPos,
)
def mapLookups(self, lookupMap):
    pass


# Copied and trimmed down from subset.py
@add_method(
    otTables.ContextSubst,
    otTables.ChainContextSubst,
    otTables.ContextPos,
    otTables.ChainContextPos,
)
def __merge_classify_context(self):
    class ContextHelper(object):
        def __init__(self, klass, Format):
            if klass.__name__.endswith("Subst"):
                Typ = "Sub"
                Type = "Subst"
            else:
                Typ = "Pos"
                Type = "Pos"
            if klass.__name__.startswith("Chain"):
                Chain = "Chain"
            else:
                Chain = ""
            ChainTyp = Chain + Typ

            self.Typ = Typ
            self.Type = Type
            self.Chain = Chain
            self.ChainTyp = ChainTyp

            self.LookupRecord = Type + "LookupRecord"

            if Format == 1:
                self.Rule = ChainTyp + "Rule"
                self.RuleSet = ChainTyp + "RuleSet"
            elif Format == 2:
                self.Rule = ChainTyp + "ClassRule"
                self.RuleSet = ChainTyp + "ClassSet"

    if self.Format not in [1, 2, 3]:
        return None  # Don't shoot the messenger; let it go
    if not hasattr(self.__class__, "_merge__ContextHelpers"):
        self.__class__._merge__ContextHelpers = {}
    if self.Format not in self.__class__._merge__ContextHelpers:
        helper = ContextHelper(self.__class__, self.Format)
        self.__class__._merge__ContextHelpers[self.Format] = helper
    return self.__class__._merge__ContextHelpers[self.Format]


@add_method(
    otTables.ContextSubst,
    otTables.ChainContextSubst,
    otTables.ContextPos,
    otTables.ChainContextPos,
)
def mapLookups(self, lookupMap):
    c = self.__merge_classify_context()

    if self.Format in [1, 2]:
        for rs in getattr(self, c.RuleSet):
            if not rs:
                continue
            for r in getattr(rs, c.Rule):
                if not r:
                    continue
                for ll in getattr(r, c.LookupRecord):
                    if not ll:
                        continue
                    ll.LookupListIndex = lookupMap[ll.LookupListIndex]
    elif self.Format == 3:
        for ll in getattr(self, c.LookupRecord):
            if not ll:
                continue
            ll.LookupListIndex = lookupMap[ll.LookupListIndex]
    else:
        assert 0, "unknown format: %s" % self.Format


@add_method(otTables.ExtensionSubst, otTables.ExtensionPos)
def mapLookups(self, lookupMap):
    if self.Format == 1:
        self.ExtSubTable.mapLookups(lookupMap)
    else:
        assert 0, "unknown format: %s" % self.Format


@add_method(otTables.Lookup)
def mapLookups(self, lookupMap):
    for st in self.SubTable:
        if not st:
            continue
        st.mapLookups(lookupMap)


@add_method(otTables.LookupList)
def mapLookups(self, lookupMap):
    for l in self.Lookup:
        if not l:
            continue
        l.mapLookups(lookupMap)


@add_method(otTables.Lookup)
def mapMarkFilteringSets(self, markFilteringSetMap):
    if self.LookupFlag & 0x0010:
        self.MarkFilteringSet = markFilteringSetMap[self.MarkFilteringSet]


@add_method(otTables.LookupList)
def mapMarkFilteringSets(self, markFilteringSetMap):
    for l in self.Lookup:
        if not l:
            continue
        l.mapMarkFilteringSets(markFilteringSetMap)


@add_method(otTables.Feature)
def mapLookups(self, lookupMap):
    self.LookupListIndex = [lookupMap[i] for i in self.LookupListIndex]


@add_method(otTables.FeatureList)
def mapLookups(self, lookupMap):
    for f in self.FeatureRecord:
        if not f or not f.Feature:
            continue
        f.Feature.mapLookups(lookupMap)


@add_method(otTables.DefaultLangSys, otTables.LangSys)
def mapFeatures(self, featureMap):
    self.FeatureIndex = [featureMap[i] for i in self.FeatureIndex]
    if self.ReqFeatureIndex != 65535:
        self.ReqFeatureIndex = featureMap[self.ReqFeatureIndex]


@add_method(otTables.Script)
def mapFeatures(self, featureMap):
    if self.DefaultLangSys:
        self.DefaultLangSys.mapFeatures(featureMap)
    for l in self.LangSysRecord:
        if not l or not l.LangSys:
            continue
        l.LangSys.mapFeatures(featureMap)


@add_method(otTables.ScriptList)
def mapFeatures(self, featureMap):
    for s in self.ScriptRecord:
        if not s or not s.Script:
            continue
        s.Script.mapFeatures(featureMap)


def layoutPreMerge(font):
    # Map indices to references

    GDEF = font.get("GDEF")
    GSUB = font.get("GSUB")
    GPOS = font.get("GPOS")

    for t in [GSUB, GPOS]:
        if not t:
            continue

        if t.table.LookupList:
            lookupMap = {i: v for i, v in enumerate(t.table.LookupList.Lookup)}
            t.table.LookupList.mapLookups(lookupMap)
            t.table.FeatureList.mapLookups(lookupMap)

            if (
                GDEF
                and GDEF.table.Version >= 0x00010002
                and GDEF.table.MarkGlyphSetsDef
            ):
                markFilteringSetMap = {
                    i: v for i, v in enumerate(GDEF.table.MarkGlyphSetsDef.Coverage)
                }
                t.table.LookupList.mapMarkFilteringSets(markFilteringSetMap)

        if t.table.FeatureList and t.table.ScriptList:
            featureMap = {i: v for i, v in enumerate(t.table.FeatureList.FeatureRecord)}
            t.table.ScriptList.mapFeatures(featureMap)

    # TODO FeatureParams nameIDs


def layoutPostMerge(font):
    # Map references back to indices

    GDEF = font.get("GDEF")
    GSUB = font.get("GSUB")
    GPOS = font.get("GPOS")

    for t in [GSUB, GPOS]:
        if not t:
            continue

        if t.table.FeatureList and t.table.ScriptList:
            # Collect unregistered (new) features.
            featureMap = GregariousIdentityDict(t.table.FeatureList.FeatureRecord)
            t.table.ScriptList.mapFeatures(featureMap)

            # Record used features.
            featureMap = AttendanceRecordingIdentityDict(
                t.table.FeatureList.FeatureRecord
            )
            t.table.ScriptList.mapFeatures(featureMap)
            usedIndices = featureMap.s

            # Remove unused features
            t.table.FeatureList.FeatureRecord = [
                f
                for i, f in enumerate(t.table.FeatureList.FeatureRecord)
                if i in usedIndices
            ]

            # Map back to indices.
            featureMap = NonhashableDict(t.table.FeatureList.FeatureRecord)
            t.table.ScriptList.mapFeatures(featureMap)

            t.table.FeatureList.FeatureCount = len(t.table.FeatureList.FeatureRecord)

        if t.table.LookupList:
            # Collect unregistered (new) lookups.
            lookupMap = GregariousIdentityDict(t.table.LookupList.Lookup)
            t.table.FeatureList.mapLookups(lookupMap)
            t.table.LookupList.mapLookups(lookupMap)

            # Record used lookups.
            lookupMap = AttendanceRecordingIdentityDict(t.table.LookupList.Lookup)
            t.table.FeatureList.mapLookups(lookupMap)
            t.table.LookupList.mapLookups(lookupMap)
            usedIndices = lookupMap.s

            # Remove unused lookups
            t.table.LookupList.Lookup = [
                l for i, l in enumerate(t.table.LookupList.Lookup) if i in usedIndices
            ]

            # Map back to indices.
            lookupMap = NonhashableDict(t.table.LookupList.Lookup)
            t.table.FeatureList.mapLookups(lookupMap)
            t.table.LookupList.mapLookups(lookupMap)

            t.table.LookupList.LookupCount = len(t.table.LookupList.Lookup)

            if GDEF and GDEF.table.Version >= 0x00010002:
                markFilteringSetMap = NonhashableDict(
                    GDEF.table.MarkGlyphSetsDef.Coverage
                )
                t.table.LookupList.mapMarkFilteringSets(markFilteringSetMap)

    # TODO FeatureParams nameIDs
