# coding=utf8

from datetime import timedelta

from auto_keyworder_modules import util, easydb_api, eas


class TagFilter(object):
    """Tag filter read from a definition dict.

    Three lists of tags are read from ``tagfilter_def``:

    - ``all``: read from key 'all'
    - ``any``: read from key 'any'
    - ``no``:  read from key 'not' (``not`` is a Python keyword)

    A list stays empty unless the definition provides a non-empty list for
    its key. Reading is best-effort: if one key cannot be read, only that
    list stays empty.
    """

    def __init__(self, tagfilter_def) -> None:
        self.all = self._read_tag_list(tagfilter_def, 'all')
        self.any = self._read_tag_list(tagfilter_def, 'any')
        self.no = self._read_tag_list(tagfilter_def, 'not')

    @staticmethod
    def _read_tag_list(tagfilter_def, key: str) -> list:
        """Return the non-empty list stored under ``key``, else ``[]``."""
        try:
            value = util.get_json_value(tagfilter_def, key)
        except Exception:
            # Best effort: an unreadable entry only leaves this list empty.
            # Unlike a bare `except:`, this does not swallow
            # KeyboardInterrupt / SystemExit.
            return []
        if isinstance(value, list) and len(value) > 0:
            return value
        return []

    def is_valid(self) -> bool:
        """Return True if at least one of the three tag lists is non-empty."""
        return bool(self.all or self.any or self.no)

    def to_dict(self) -> dict:
        """Serialize back to the definition format, omitting empty lists."""
        d = {}
        if self.all:
            d['all'] = self.all
        if self.any:
            d['any'] = self.any
        if self.no:
            d['not'] = self.no
        return d


def get_expired_timestamps(min_age_days: int) -> str:
    """Return the formatted point in time ``min_age_days`` days before now.

    ``util.now()`` supplies the current time and ``util.format_datetime``
    renders the result. For ``min_age_days == 0`` the current time is
    returned unchanged.
    """
    now = util.now()
    cutoff = now if min_age_days == 0 else now - timedelta(days=min_age_days)
    return util.format_datetime(cutoff)


def build_search_top_level(
    query: dict,
    limit: int = 1000,
    offset: int = 0,
) -> dict:
    """Wrap a single search clause into a top-level search request.

    The request uses the 'long' format, disables rights generation, does not
    restrict object types, and pages with ``offset``/``limit``.
    """
    request = dict(
        format='long',
        generate_rights=False,
        offset=offset,
        limit=limit,
        objecttypes=[],
    )
    request['search'] = [query]
    return request


def do_search(
    easydb_server: str,
    easydb_token: str,
    query: dict,
) -> dict:
    """Send ``query`` via ``easydb_api.easydb_search`` and return the result.

    Never raises: a non-200 status code, or any exception raised during the
    request, is reported as ``{'error': <message>}`` instead.
    """
    try:
        status, payload = easydb_api.easydb_search(easydb_server, easydb_token, query)
        if status == 200:
            return payload
        raise Exception(f'[{status}] {payload}')
    except Exception as err:
        return {'error': str(err)}


def _tags_in_clause(bool_op: str, tag_ids: list) -> dict:
    """Return an 'in' clause on '_tags._id' combined with ``bool_op``."""
    return {
        'bool': bool_op,
        'fields': [
            '_tags._id',
        ],
        'in': tag_ids,
        'type': 'in',
    }


def build_query_from_tagfilter(tagfilter: TagFilter) -> dict:
    """Translate a TagFilter into a complex easydb search clause.

    - ``any``: one 'should' clause listing all of these tag ids
    - ``no``:  one 'must_not' clause listing all of these tag ids
    - ``all``: a nested 'must' complex search with one 'must' clause per
      tag, so every tag is required individually

    Returns ``{}`` when the filter contains no tags, so callers can skip it
    with a simple truthiness check.
    """
    if not tagfilter.is_valid():
        return {}

    sub_searches = []

    if tagfilter.any:
        sub_searches.append(_tags_in_clause('should', tagfilter.any))

    if tagfilter.no:
        sub_searches.append(_tags_in_clause('must_not', tagfilter.no))

    if tagfilter.all:
        sub_searches.append(
            {
                'bool': 'must',
                'type': 'complex',
                # One clause per tag: a single 'in' with all tags would
                # match objects carrying any one of them.
                'search': [_tags_in_clause('must', [tag]) for tag in tagfilter.all],
            }
        )

    if not sub_searches:
        # Same "no filter" value as the early return above (never None).
        return {}

    return {
        'bool': 'must',
        'search': sub_searches,
        'type': 'complex',
    }


def build_query_for_objecttype(objecttype: str) -> dict:
    """Return a 'must' clause restricting a search to one object type."""
    wanted_types = [objecttype]
    return {'bool': 'must', 'fields': ['_objecttype'], 'in': wanted_types, 'type': 'in'}


def build_query_by_system_object_ids(sys_ids: list[int]) -> dict:
    """Return a 'must' clause matching objects by system object id.

    ``sys_ids`` is used as the clause's 'in' value as-is (it is not copied).
    """
    id_field = '_system_object_id'
    clause = {'bool': 'must'}
    clause['fields'] = [id_field]
    clause['in'] = sys_ids
    clause['type'] = 'in'
    return clause


def build_query_by_field_values(
    objecttype: str,
    fieldname: str,
    keywords: list,
    language: str = None,
    offset: int = 0,
    limit: int = 1000,
) -> dict:
    """Build a short-format search for objects whose field matches keywords.

    Matches objects of ``objecttype`` where ``<objecttype>.<fieldname>`` is
    one of ``keywords``. Only the object id (key 'id') and the matched field
    (key 'keyword') are requested. A ``language`` other than None is added to
    the request as 'language'.
    """
    path = f'{objecttype}.{fieldname}'

    type_match = build_query_for_objecttype(objecttype)
    value_match = {
        'bool': 'must',
        'fields': [path],
        'in': keywords,
        'type': 'in',
    }
    wanted_fields = [
        {'key': 'id', 'field': f'{objecttype}._id'},
        {'key': 'keyword', 'field': path},
    ]

    query = dict(
        format='short',
        generate_rights=False,
        offset=offset,
        limit=limit,
        fields=wanted_fields,
        objecttypes=[objecttype],
        search=[type_match, value_match],
    )

    if language is not None:
        query['language'] = language

    return query


def build_search_query(
    objecttype: str,
    tagfilter: TagFilter,
    asset_field: str,
    timestamp_field: str,
    min_age_days: int = None,
    offset: int = 0,
    limit: int = 1000,
) -> dict:
    """Build the long-format search request for objects with an image asset.

    The request always matches objects of ``objecttype`` whose
    ``asset_field`` is set and has class 'image'. It is narrowed further by:

    - ``tagfilter``: the clause from ``build_query_from_tagfilter``; skipped
      when that clause is empty or when ``tagfilter`` is None
    - ``min_age_days`` (if not None): a complex clause of two 'should'
      entries: ``timestamp_field`` in a range ending at
      ``get_expired_timestamps(min_age_days)``, or ``timestamp_field`` unset

    A fixed set of system fields is excluded from the response, and
    ``offset``/``limit`` are passed through for paging.
    """
    asset_path = f'{objecttype}.{asset_field}'
    timestamp_path = f'{objecttype}.{timestamp_field}'

    searches = [
        build_query_for_objecttype(objecttype),
        # the asset field must be set ...
        {
            'bool': 'must_not',
            'fields': [
                asset_path,
            ],
            'in': [
                None,
            ],
            'type': 'in',
        },
        # ... and contain an image
        {
            'bool': 'must',
            'fields': [
                f'{asset_path}.class',
            ],
            'in': [
                'image',
            ],
            'type': 'in',
        },
    ]

    tf_search = build_query_from_tagfilter(tagfilter) if tagfilter is not None else {}
    # Truthiness check: skips both {} and None, so an empty or missing
    # clause is never sent to the server.
    if tf_search:
        searches.append(tf_search)

    if min_age_days is not None:
        searches.append(
            {
                'bool': 'must',
                'search': [
                    {
                        'bool': 'should',
                        'field': timestamp_path,
                        'to': get_expired_timestamps(min_age_days),
                        'type': 'range',
                    },
                    {
                        'bool': 'should',
                        'fields': [
                            timestamp_path,
                        ],
                        'in': [
                            None,
                        ],
                        'type': 'in',
                    },
                ],
                'type': 'complex',
            }
        )

    query = {
        'format': 'long',
        'generate_rights': False,
        'offset': offset,
        'limit': limit,
        'exclude_fields': [
            '_standard',
            '_format',
            '_created',
            '_last_modified',
            '_published_count',
            '_collections',
            '_uuid',
            '_global_object_id',
        ],
        'objecttypes': [
            objecttype,
        ],
        'search': searches,
    }

    return query


def map_search_results(
    result: dict,
    objecttype: str,
    object_map: dict[int, eas.AssetStatus],
) -> int:
    """Add the objects of a search response to ``object_map``.

    Each object is keyed by its '_system_object_id'; ids already present in
    ``object_map`` are left untouched. New entries are ``eas.AssetStatus``
    instances built from the object's '_mask', '_tags' and ``objecttype``
    data.

    Returns the number of entries added. Raises Exception if ``result`` has
    no 'objects' list.
    """
    hits = util.get_json_value(result, 'objects')
    if not isinstance(hits, list):
        raise Exception(f'invalid response: {util.dumpjs(result)}')

    added = 0
    for hit in hits:
        sys_id = util.get_json_value(hit, '_system_object_id', True)
        if sys_id not in object_map:
            # Value unused; the lookup is kept as in the original.
            # NOTE(review): the third argument may make get_json_value fail
            # on a missing '_id' — confirm before removing this call.
            util.get_json_value(hit, f'{objecttype}._id', True)
            mask = util.get_json_value(hit, '_mask', True)
            tags = util.get_json_value(hit, '_tags', True)
            body = util.get_json_value(hit, objecttype, True)

            object_map[sys_id] = eas.AssetStatus(
                {
                    '_mask': mask,
                    '_tags': tags,
                    '_objecttype': objecttype,
                    '_system_object_id': sys_id,
                    objecttype: body,
                }
            )
            added += 1

    return added


def collect_linked_objects_by_sys_id(
    easydb_server: str,
    easydb_token: str,
    sys_ids: list[int],
    limit: int = 1000,
) -> dict[str, dict[int, dict]]:
    """Fetch objects by system object id and group them by object type.

    ``sys_ids`` is searched in consecutive batches of ``limit`` ids.

    Returns a mapping objecttype -> system_object_id -> object. Objects
    without '_objecttype' or '_system_object_id' are skipped. Searching stops
    at the first batch whose response lacks an integer 'count' or an
    'objects' list (e.g. when ``do_search`` returns an error dict). Objects
    collected from earlier batches are still returned.
    """
    collected_linked_objects = {}

    offset = 0
    while True:
        sys_id_batch = sys_ids[offset : offset + limit]
        if len(sys_id_batch) == 0:
            break
        offset += limit

        result = do_search(
            easydb_server=easydb_server,
            easydb_token=easydb_token,
            # Page size must match the batch size: build_search_top_level's
            # default (1000) would silently truncate larger batches.
            query=build_search_top_level(
                query=build_query_by_system_object_ids(sys_id_batch),
                limit=limit,
            ),
        )

        if not isinstance(result.get('count'), int):
            break

        objects = result.get('objects')
        if not isinstance(objects, list):
            break

        for obj in objects:
            ot = obj.get('_objecttype')
            if not ot:
                continue
            sys_id = obj.get('_system_object_id')
            if not sys_id:
                continue
            collected_linked_objects.setdefault(ot, {})[sys_id] = obj

    return collected_linked_objects
