from auto_keyworder_modules import datamodel, search


class AiServiceConfiguration(object):
    """Common configuration shared by all AI keywording service configurations.

    Raw configuration values live in ``self.variables`` and are read through
    the typed ``get_*`` accessors. Field names referenced by variables can be
    validated against the datamodel (``update_*_field_info``); the resulting
    field infos are cached per instance in ``self.field_infos``.
    Service specific subclasses override ``parse_specific_variables`` and
    ``update_mapping_fields_info``.
    """

    # variables that should be the same in all configurations
    variables: dict

    # loaded field infos by field name; created per instance in __init__ so
    # that different configurations do not share (and overwrite) entries
    field_infos: dict[str, datamodel.FieldInfo]

    # sentinel meaning "no default given" for get_array / get_object, so each
    # call returns a fresh empty container instead of one shared mutable default
    _NO_DEFAULT = object()

    def __init__(self) -> None:
        self.field_infos = {}
        self.variables = {
            # shared variables (upper part of baseconfig)
            'easydb_password': None,
            'easydb_user': None,
            'enabled': None,
            'request_status_delay': None,
            'request_status_repititions': None,
            'start_now': None,
            # specific common variables (same for the different service configurations)
            'api_key': None,
            'api_url': None,
            'api_secret': None,
            'asset_field': None,
            'asset_version': None,
            'config_enabled': None,
            'min_age_days': None,
            'name': None,
            'objecttype': None,
            'tagfilter': None,
            'timestamp_field': None,
            'language': None,
        }

    # --------------------------------------------------------------------

    def parse_specific_variables(self, config_js: dict) -> None:
        """Read service specific variables from ``config_js``.

        No-op in the base class; subclasses override it.
        """

    def parse_base_config_element(
        self,
        shared_config: dict,
        specific_config: dict,
    ) -> bool:
        """Copy known variables from the given config dicts into ``self.variables``.

        Only keys already present in ``self.variables`` are considered. For each
        key a non-None value in ``shared_config`` takes precedence over one in
        ``specific_config``; keys that are missing/None in both keep their
        current value.

        :return: True if at least one variable was set, False otherwise
        """
        valid = False

        for key in self.variables:
            for source in (shared_config, specific_config):
                value = source.get(key)
                if value is not None:
                    self.variables[key] = value
                    valid = True
                    break

        return valid

    # --------------------------------------------------------------------

    def get_value(self, variable: str):
        """Return the raw value of ``variable`` (None if unknown)."""
        return self.variables.get(variable)

    def get_string(self, variable: str, default: str = None) -> str:
        """Return the value as a stripped string.

        ``default`` is returned when the value is missing, falsy, or consists
        only of whitespace.
        """
        v = self.get_value(variable)
        if not v:
            return default
        v = str(v).strip()
        if v == '':
            return default
        return v

    def get_int(self, variable: str, default: int = 0) -> int:
        """Return the value if it is an int, else ``default``.

        ``bool`` is rejected explicitly: it subclasses ``int`` but a flag is
        not a meaningful number here.
        """
        v = self.get_value(variable)
        if isinstance(v, bool) or not isinstance(v, int):
            return default
        return v

    def get_bool(self, variable: str, default: bool = False) -> bool:
        """Interpret the value as a boolean.

        bool values are returned as-is, ints are True only for 1, strings are
        True only for "true" (case-insensitive, surrounding whitespace ignored).
        Anything else yields ``default``.
        """
        v = self.get_value(variable)
        if isinstance(v, bool):
            return v
        if isinstance(v, int):
            return v == 1
        if isinstance(v, str):
            return v.strip().lower() == "true"
        return default

    def get_array(self, variable: str, default: object = _NO_DEFAULT) -> list:
        """Return the value if it is a list, else ``default``.

        Without an explicit ``default`` a new empty list is returned.
        """
        v = self.get_value(variable)
        if isinstance(v, list):
            return v
        if default is self._NO_DEFAULT:
            return []
        return default

    def get_object(self, variable: str, default: object = _NO_DEFAULT) -> dict:
        """Return the value if it is a dict, else ``default``.

        Without an explicit ``default`` a new empty dict is returned.
        """
        v = self.get_value(variable)
        if isinstance(v, dict):
            return v
        if default is self._NO_DEFAULT:
            return {}
        return default

    # --------------------------------------------------------------------

    def is_enabled(self) -> bool:
        return self.get_bool('config_enabled', default=False)

    def do_start_now(self) -> bool:
        return self.get_bool('start_now', default=False)

    # --------------------------------------------------------------------

    def get_api_url(self) -> str:
        return self.get_string('api_url', default=None)

    def get_api_key(self) -> str:
        return self.get_string('api_key', default=None)

    def get_api_secret(self) -> str:
        return self.get_string('api_secret', default=None)

    def get_easydb_login(self) -> str:
        return self.get_string('easydb_user', default=None)

    def get_easydb_password(self) -> str:
        return self.get_string('easydb_password', default=None)

    # --------------------------------------------------------------------

    def get_tagfilter(self) -> search.TagFilter:
        """Wrap the ``tagfilter`` object ({} if unset) in a ``search.TagFilter``."""
        return search.TagFilter(self.get_object('tagfilter'))

    # --------------------------------------------------------------------

    def get_api_language(self) -> str:
        return self.get_string('language', default='en-US')

    def get_objecttype(self) -> str:
        return self.get_string('objecttype', default=None)

    def check_objecttype_valid(self, datamodel_js: dict) -> None:
        """Validate the configured objecttype against the datamodel.

        :raises Exception: if no objecttype is configured, or (via
            ``datamodel.check_objecttype_valid``) if it is invalid
        """
        ot = self.get_objecttype()
        if not ot:
            raise Exception('invalid objecttype in the configuration')

        # raises an exception if anything is wrong
        datamodel.check_objecttype_valid(datamodel_js, ot)

    def get_field_info(
        self,
        field_variable: str,
        expected: bool = False,
    ) -> datamodel.FieldInfo:
        """Return the cached field info for the field named by ``field_variable``.

        The field name is read with ``get_string`` so it matches the key used
        by ``update_field_info``.

        :param expected: if True, raise instead of returning None
        :raises Exception: when ``expected`` and the variable is unset or no
            field info was loaded for the field
        """
        field = self.get_string(field_variable, default=None)
        if not field:
            if not expected:
                return None
            raise Exception(f'variable {field_variable} not set')
        info = self.field_infos.get(field)
        if not info:
            if not expected:
                return None
            raise Exception(f'field {field} not set')
        return info

    def update_field_info(
        self,
        datamodel_js: dict,
        field_variable: str,
        fieldtypes: list,
        allow_in_nested: bool = False,
        optional: bool = False,
    ) -> None:
        """Look up the field named by ``field_variable`` in the datamodel and cache it.

        :param fieldtypes: accepted field types, passed to ``datamodel.get_field_info``
        :param allow_in_nested: passed to ``datamodel.get_field_info``
        :param optional: if True, an unset variable is silently skipped
        :raises Exception: if no objecttype is configured, a required variable
            is unset, or the field is not found
        """
        ot = self.get_objecttype()
        if not ot:
            raise Exception('invalid objecttype in the configuration')
        field = self.get_string(variable=field_variable, default=None)
        if not field:
            if optional:
                return None
            raise Exception(
                f'no field from variable {field_variable} in the configuration, field must exist'
            )

        # raises an exception if anything is wrong
        info = datamodel.get_field_info(
            datamodel=datamodel_js,
            objecttype=ot,
            fieldname=field,
            fieldtypes=fieldtypes,
            allow_in_nested=allow_in_nested,
        )
        if not info:
            raise Exception(f'field {field} not found in datamodel')

        self.field_infos[field] = info

    def update_asset_field_info(self, datamodel_js: dict) -> None:
        self.update_field_info(
            datamodel_js,
            'asset_field',
            ['eas'],
            optional=False,
        )

    def get_asset_field_info(self) -> datamodel.FieldInfo:
        return self.get_field_info(
            'asset_field',
            expected=True,
        )

    def update_timestamp_field_info(self, datamodel_js: dict) -> None:
        self.update_field_info(
            datamodel_js,
            'timestamp_field',
            ['datetime'],
            optional=False,
        )

    def get_timestamp_field_info(self) -> datamodel.FieldInfo:
        return self.get_field_info(
            'timestamp_field',
            expected=True,
        )

    def update_mapping_fields_info(self, datamodel_js: dict) -> None:
        """Validate and cache the service specific target fields (subclass hook)."""
        raise NotImplementedError(
            'update_mapping_fields_info must be implemented in the specific configuration class for each service'
        )


class CloudsightConfiguration(AiServiceConfiguration):
    """Configuration for the Cloudsight service.

    Maps several Cloudsight result categories to target fields. Only the
    ``subject_field`` target is mandatory; all others are optional.
    """

    # datamodel field types accepted for every Cloudsight target field
    _TARGET_FIELDTYPES = (
        'text',
        'text_oneline',
        'text_l10n',
        'text_l10n_oneline',
    )

    def __init__(self) -> None:
        super().__init__()
        # config variables whose values name the target fields
        self.field_variables = [
            'subject_field',
            'keyword_gender_field',
            'keyword_quantity_field',
            'keyword_material_field',
            'keyword_color_field',
            'keyword_categories_field',
            'keyword_similar_objects_field',
        ]

    def parse_specific_variables(self, config_js: dict) -> None:
        """Store each target field variable; empty or missing entries become None."""
        self.variables.update(
            {name: config_js.get(name) or None for name in self.field_variables}
        )

    def update_mapping_fields_info(self, datamodel_js: dict) -> None:
        """Validate all configured target fields against the datamodel."""
        for name in self.field_variables:
            # all target fields except for the subject are optional
            is_optional = name != 'subject_field'
            self.update_field_info(
                datamodel_js=datamodel_js,
                field_variable=name,
                fieldtypes=list(self._TARGET_FIELDTYPES),
                allow_in_nested=True,
                optional=is_optional,
            )


class DeepvaConfiguration(AiServiceConfiguration):
    """Configuration for the DeepVA service.

    Besides the common variables it reads the target field, the number of
    labels and the enabled modules (``module__<name>``) together with their
    optional models (``model__<name>``).
    """

    # datamodel field types accepted for the target field
    _TARGET_FIELDTYPES = (
        'text',
        'text_oneline',
        'text_l10n',
        'text_l10n_oneline',
    )

    def __init__(self) -> None:
        super().__init__()

    def parse_specific_variables(self, config_js: dict) -> None:
        """Read DeepVA specific variables; empty or missing entries become None."""
        for name in ('target_field', 'num_labels'):
            self.variables[name] = config_js.get(name) or None

        self.variables['modules'] = self._collect_modules(config_js)

    @staticmethod
    def _collect_modules(config_js: dict) -> dict:
        """Return ``{module_name: {...}}`` for every enabled module.

        A module is enabled by a ``module__<name>`` key whose value is exactly
        the boolean True. A non-empty string under ``model__<name>`` is stored
        as that module's ``'model'`` entry; models of disabled modules are ignored.
        """
        module_prefix = 'module__'
        model_prefix = 'model__'

        collected = {}
        for key, flag in config_js.items():
            if not key.startswith(module_prefix) or flag is not True:
                continue
            name = key[len(module_prefix):]
            if name:
                collected[name] = {}

        for key, model in config_js.items():
            if not key.startswith(model_prefix):
                continue
            name = key[len(model_prefix):]
            if name in collected and isinstance(model, str) and model:
                collected[name]['model'] = model

        return collected

    def update_mapping_fields_info(self, datamodel_js: dict) -> None:
        """Validate the (required) target field against the datamodel."""
        self.update_field_info(
            datamodel_js=datamodel_js,
            field_variable='target_field',
            fieldtypes=list(self._TARGET_FIELDTYPES),
            allow_in_nested=True,
            optional=False,
        )


class ImaggaConfiguration(AiServiceConfiguration):
    """Configuration for the Imagga service.

    Besides the common variables it reads the target field, the number of
    keywords and the minimum confidence.
    """

    # datamodel field types accepted for the target field
    _TARGET_FIELDTYPES = (
        'text',
        'text_oneline',
        'text_l10n',
        'text_l10n_oneline',
    )

    def __init__(self) -> None:
        super().__init__()

    def parse_specific_variables(self, config_js: dict) -> None:
        """Read Imagga specific variables; empty or missing entries become None."""
        wanted = ('target_field', 'num_keywords', 'min_confidence')
        self.variables.update({name: config_js.get(name) or None for name in wanted})

    def update_mapping_fields_info(self, datamodel_js: dict) -> None:
        """Validate the (required) target field against the datamodel."""
        self.update_field_info(
            datamodel_js=datamodel_js,
            field_variable='target_field',
            fieldtypes=list(self._TARGET_FIELDTYPES),
            allow_in_nested=True,
            optional=False,
        )
