Anomalib code analysis 2: train.py (continued)

4. About get_experiment_logger

What is inside get_experiment_logger? Here is the source:

def get_experiment_logger(
    config: DictConfig | ListConfig,
) -> Logger | Iterable[Logger] | bool:
    """Return a logger based on the choice of logger in the config file.

    Args:
        config (DictConfig): config.yaml file for the corresponding anomalib model.

    Raises:
        UnknownLogger: for any logger type apart from false and the available loggers

    Returns:
        Logger | Iterable[Logger] | bool: Logger
    """
    logger.info("Loading the experiment logger(s)")

    # TODO remove when logger is deprecated from project
    if "logger" in config.project.keys():
        warnings.warn(
            "'logger' key will be deprecated from 'project' section of the config file."
            " Please use the logging section in config file.",
            DeprecationWarning,
        )
        if "logging" not in config:
            config.logging = {"logger": config.project.logger, "log_graph": False}
        else:
            config.logging.logger = config.project.logger

    # Note: an empty list [] is neither None nor False, so it slips past
    # this check and the function ends up returning an empty logger_list.
    if config.logging.logger in (None, False):
        return False

    print("-------------------hahahhahhah")
    print(config.logging.logger)
    print("-------------------hahahhahhah-end")

    logger_list: list[Logger] = []
    if isinstance(config.logging.logger, str):
        config.logging.logger = [config.logging.logger]

    print("------------------------gao1")
    for experiment_logger in config.logging.logger:
        print("-------------------------gao2")
        print(experiment_logger)
        if experiment_logger == "tensorboard":
            logger_list.append(
                AnomalibTensorBoardLogger(
                    name="Tensorboard Logs",
                    save_dir=os.path.join(config.project.path, "logs"),
                    log_graph=config.logging.log_graph,
                )
            )
        elif experiment_logger == "wandb":
            wandb_logdir = os.path.join(config.project.path, "logs")
            Path(wandb_logdir).mkdir(parents=True, exist_ok=True)
            name = (
                config.model.name
                if "category" not in config.dataset.keys()
                else f"{config.dataset.category} {config.model.name}"
            )
            logger_list.append(
                AnomalibWandbLogger(
                    project=config.dataset.name,
                    name=name,
                    save_dir=wandb_logdir,
                )
            )
        elif experiment_logger == "comet":
            comet_logdir = os.path.join(config.project.path, "logs")
            Path(comet_logdir).mkdir(parents=True, exist_ok=True)
            run_name = (
                config.model.name
                if "category" not in config.dataset.keys()
                else f"{config.dataset.category} {config.model.name}"
            )
            logger_list.append(
                AnomalibCometLogger(project_name=config.dataset.name, experiment_name=run_name, save_dir=comet_logdir)
            )
        elif experiment_logger == "csv":
            logger_list.append(CSVLogger(save_dir=os.path.join(config.project.path, "logs")))
        else:
            raise UnknownLogger(
                f"Unknown logger type: {config.logging.logger}. "
                f"Available loggers are: {AVAILABLE_LOGGERS}.\n"
                f"To enable the logger, set `project.logger` to `true` or use one of available loggers in config.yaml\n"
                f"To disable the logger, set `project.logger` to `false`."
            )
    print("-------------------------gao3")
    print(logger_list)
    print("-------------------------gao3-end")
    return logger_list

Look at the last few printed lines: the logger_list that comes back is just an empty list, [].

So all that machinery runs and nothing gets logged, heh.
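
Why is it empty? In the stock anomalib configs, logging.logger defaults to an empty list, which slips past the (None, False) check but leaves the for-loop with nothing to iterate over. A minimal sketch of that behavior (the config values here are assumptions based on the default config.yaml):

from anomalib.utils.loggers import get_experiment_logger
from omegaconf import OmegaConf

# A minimal config mimicking the defaults: logging.logger is [].
config = OmegaConf.create(
    {
        "project": {"path": "./results"},
        "logging": {"logger": [], "log_graph": False},
    }
)

# [] is neither None nor False, so the early `return False` is skipped;
# the loop body never runs and an empty list comes back.
print(get_experiment_logger(config))  # -> []

# To get a real logger, set e.g. logging.logger to ["tensorboard"].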

5. About get_callbacks

callbacks = get_callbacks(config)
Check out what's inside:
def get_callbacks(config: DictConfig | ListConfig) -> list[Callback]:
    """Return base callbacks for all the lightning models.

    Args:
        config (DictConfig): Model config

    Return:
        (list[Callback]): List of callbacks.
    """
    logger.info("Loading the callbacks")

    callbacks: list[Callback] = []

    # If early stopping is not configured, monitor stays None, so
    # ModelCheckpoint simply saves the weights from the last epoch.
    monitor_metric = None if "early_stopping" not in config.model.keys() else config.model.early_stopping.metric
    monitor_mode = "max" if "early_stopping" not in config.model.keys() else config.model.early_stopping.mode

    checkpoint = ModelCheckpoint(
        dirpath=os.path.join(config.project.path, "weights"),
        filename="model",
        monitor=monitor_metric,
        mode=monitor_mode,
        auto_insert_metric_name=False,
    )

    callbacks.extend([checkpoint, TimerCallback()])

    if "resume_from_checkpoint" in config.trainer.keys() and config.trainer.resume_from_checkpoint is not None:
        load_model = LoadModelCallback(config.trainer.resume_from_checkpoint)
        callbacks.append(load_model)

    # Add post-processing configurations to AnomalyModule.
    image_threshold = (
        config.metrics.threshold.manual_image if "manual_image" in config.metrics.threshold.keys() else None
    )
    pixel_threshold = (
        config.metrics.threshold.manual_pixel if "manual_pixel" in config.metrics.threshold.keys() else None
    )
    post_processing_callback = PostProcessingConfigurationCallback(
        threshold_method=config.metrics.threshold.method,
        manual_image_threshold=image_threshold,
        manual_pixel_threshold=pixel_threshold,
    )
    callbacks.append(post_processing_callback)

    # Add metric configuration to the model via MetricsConfigurationCallback
    metrics_callback = MetricsConfigurationCallback(
        config.dataset.task,
        config.metrics.get("image", None),
        config.metrics.get("pixel", None),
    )
    callbacks.append(metrics_callback)

    if "normalization_method" in config.model.keys() and not config.model.normalization_method == "none":
        if config.model.normalization_method == "cdf":
            if config.model.name in ("padim", "stfpm"):
                if "nncf" in config.optimization and config.optimization.nncf.apply:
                    raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.")
                callbacks.append(CdfNormalizationCallback())
            else:
                raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.")
        elif config.model.normalization_method == "min_max":
            callbacks.append(MinMaxNormalizationCallback())
        else:
            raise ValueError(f"Normalization method not recognized: {config.model.normalization_method}")

    add_visualizer_callback(callbacks, config)

    if "optimization" in config.keys():
        if "nncf" in config.optimization and config.optimization.nncf.apply:
            # NNCF wraps torch's jit which conflicts with kornia's jit calls.
            # Hence, nncf is imported only when required
            nncf_module = import_module("anomalib.utils.callbacks.nncf.callback")
            nncf_callback = getattr(nncf_module, "NNCFCallback")
            nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf))
            callbacks.append(
                nncf_callback(
                    config=nncf_config,
                    export_dir=os.path.join(config.project.path, "compressed"),
                )
            )
        if config.optimization.export_mode is not None:
            from .export import (  # pylint: disable=import-outside-toplevel
                ExportCallback,
            )

            logger.info("Setting model export to %s", config.optimization.export_mode)
            callbacks.append(
                ExportCallback(
                    input_size=config.model.input_size,
                    dirpath=config.project.path,
                    filename="model",
                    export_mode=ExportMode(config.optimization.export_mode),
                )
            )
        else:
            warnings.warn(f"Export option: {config.optimization.export_mode} not found. Defaulting to no model export")

    # Add callback to log graph to loggers
    if config.logging.log_graph not in (None, False):
        callbacks.append(GraphLogger())

    print("-------------gao callbacks")
    print(callbacks)
    print("-------------gao callbacks-end")

    return callbacks

Look at the callbacks list printed at the end. What does it contain?
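
Tracing through the function above, the printed list is assembled roughly as follows (a sketch of the append order, not verbatim output; which entries appear depends on the config):

# Rough shape of the printed callbacks list, in append order:
# [
#     ModelCheckpoint,                      # always
#     TimerCallback,                        # always
#     LoadModelCallback,                    # only if trainer.resume_from_checkpoint is set
#     PostProcessingConfigurationCallback,  # always
#     MetricsConfigurationCallback,         # always
#     CdfNormalizationCallback or MinMaxNormalizationCallback,  # per model.normalization_method
#     <visualizer callbacks>,               # whatever add_visualizer_callback appends
#     NNCFCallback,                         # only if optimization.nncf.apply
#     ExportCallback,                       # only if optimization.export_mode is set
#     GraphLogger,                          # only if logging.log_graph is enabled
# ]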

6. About Trainer

Finally we come to the most important part: constructing the Trainer.

The first argument to Trainer is the trainer section of the config, unpacked as keyword arguments.

The second argument, the logger, is the [] we saw earlier.

The third argument is the bunch of callbacks built above.
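
Putting the three pieces together, the call in train.py looks like this (a sketch based on anomalib v0.x; model and datamodule are the objects built earlier in the script):

from pytorch_lightning import Trainer

# config.trainer is unpacked into Trainer's keyword arguments;
# experiment_logger is the (empty) list from get_experiment_logger,
# and callbacks is the list from get_callbacks.
trainer = Trainer(**config.trainer, logger=experiment_logger, callbacks=callbacks)

# Training then kicks off with the model and datamodule built earlier.
trainer.fit(model=model, datamodule=datamodule)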

 

Origin blog.csdn.net/gaoenyang760525/article/details/129849276