"""`W&B <https://wandb.ai/>`_ utilities for Neuroevolution fitting."""fromcollections.abcimportCallablefromtypingimportAnyimportnumpyasnpimportwandbfromomegaconfimportOmegaConffromcommon.utils.mpi4pyimportget_mpi_variables
def setup_wandb(logger: Callable[..., Any], output_dir: str) -> None:
    """Starts `W&B <https://wandb.ai/>`_ logging from the rank 0 process.

    Every MPI process may call this function, but only the process with
    rank 0 invokes ``logger``; all other ranks do nothing.

    Args:
        logger: See :func:`wandb.init`. Called with the loaded config
            passed as its ``config`` keyword argument.
        output_dir: See :paramref:`~.BaseSubtaskConfig.output_dir`. The
            config is read from ``{output_dir}/.hydra/config.yaml``.
    """
    _, rank, _ = get_mpi_variables()
    if rank == 0:
        # Load the config file found under `output_dir` and hand it to
        # the logger as a plain container.
        hydra_config = OmegaConf.load(f"{output_dir}/.hydra/config.yaml")
        logger(
            config=OmegaConf.to_container(
                hydra_config,
                resolve=True,
                throw_on_missing=True,
            ),
        )
def gather(
    logged_score: float | None,
    curr_gen: int,
    agent_total_num_steps: int,
) -> None:
    """Gathers logged scores from all MPI processes and logs their means.

    Every process contributes its values through ``comm.gather``; only
    the process with rank 0 goes on to call :func:`wandb.log` with the
    mean score (``"score"``), the mean number of steps (``"num_steps"``)
    and the current generation (``"gen"``).

    Args:
        logged_score: A value logged during evaluation. If ``None``,
            then no value was logged during evaluation.
        curr_gen: See :paramref:`~.BaseSpace.curr_gen`.
        agent_total_num_steps: See :attr:`~.BaseAgent.total_num_steps`.
    """
    comm, rank, _ = get_mpi_variables()
    logged_scores: list[float | None] | None = comm.gather(
        sendobj=logged_score,
    )
    logged_agent_total_num_steps: list[int] | None = comm.gather(
        sendobj=agent_total_num_steps,
    )
    if rank != 0:
        return
    # `logged_scores` & `logged_agent_total_num_steps` are only `None`
    # when `rank != 0`. The following `assert` statements are for static
    # type checking reasons and have no execution purposes.
    assert logged_scores is not None  # noqa: S101
    assert logged_agent_total_num_steps is not None  # noqa: S101
    # Compare against `None` explicitly: `filter(None, ...)` would also
    # discard legitimate falsy scores such as `0.0` and bias the mean.
    non_none_logged_scores: list[float] = [
        score for score in logged_scores if score is not None
    ]
    # `np.mean` of an empty sequence emits a `RuntimeWarning` and yields
    # NaN; produce the same NaN directly when no process logged a score.
    non_none_logged_scores_mean = (
        np.mean(a=non_none_logged_scores)
        if non_none_logged_scores
        else float("nan")
    )
    logged_agent_total_num_steps_mean = np.mean(
        a=logged_agent_total_num_steps,
    )
    wandb.log(
        data={
            "score": non_none_logged_scores_mean,
            "num_steps": logged_agent_total_num_steps_mean,
            "gen": curr_gen,
        },
    )