"""File reading and writing utilities for Neuroevolution fitting."""importpicklefrompathlibimportPathfromtypingimportAnnotatedasAnfromcommon.optim.ne.agentimportBaseAgentfromcommon.optim.ne.utils.typeimportGeneration_results_typefromcommon.utils.beartypeimportgefromcommon.utils.mpi4pyimportget_mpi_variables
def find_existing_save_points(output_dir: str) -> list[int]:
    """Returns a sorted list of existing save points.

    A save point is a sub-directory of ``output_dir`` whose name is made
    only of decimal digits (the generation number) and which contains a
    ``state.pkl`` file.

    Args:
        output_dir: See :paramref:`~.BaseSubtaskConfig.output_dir`.

    Returns:
        The existing save points, in ascending order. Empty when none
        are found.
    """
    return sorted(
        int(save_path.name)
        for save_path in Path(output_dir).glob(pattern="*")
        # `str.isdecimal` rather than `str.isdigit`: the latter accepts
        # characters such as "²" that `int` rejects with `ValueError`.
        # The cheap name check also runs before the filesystem checks.
        if save_path.name.isdecimal()
        and save_path.is_dir()
        and (save_path / "state.pkl").exists()
    )
def load_state(
    prev_num_gens: An[int, ge(0)],
    len_agents_batch: An[int, ge(1)],
    output_dir: str,
) -> tuple[
    list[list[BaseAgent]],  # agents_batch
    Generation_results_type | None,  # generation_results
    An[int, ge(0)] | None,  # total_num_env_steps
]:
    """Load a previous experiment state from disk.

    Must be called by every MPI process, since it performs collective
    communication. The process with rank 0 reads
    ``{output_dir}/{prev_num_gens}/state.pkl`` and scatters the saved
    agent list across processes, ``len_agents_batch`` entries each.

    Args:
        prev_num_gens: See
            :paramref:`~.NeuroevolutionSubtaskConfig.prev_num_gens`.
        len_agents_batch: See
            :paramref:`~.initialize_agents.len_agents_batch`.
        output_dir: See :paramref:`~.BaseSubtaskConfig.output_dir`.

    Returns:
        * See :paramref:`~.compute_generation_results.agents_batch`.
        * See :paramref:`~.compute_generation_results.generation_results`.
          ``None`` on processes whose rank is not 0.
        * See :paramref:`~.compute_total_num_env_steps_and_process_fitnesses.total_num_env_steps`.
          ``None`` on processes whose rank is not 0.

    Raises:
        FileNotFoundError: If no saved state exists for
            ``prev_num_gens``. Raised on every process.
    """
    comm, rank, size = get_mpi_variables()
    path = Path(f"{output_dir}/{prev_num_gens}/state.pkl")
    # Only rank 0 inspects the filesystem, but the result is broadcast
    # so that every process raises together. Raising on rank 0 alone
    # would leave the other processes blocked forever in the collective
    # `comm.scatter` below.
    state_exists: bool = comm.bcast(
        obj=path.exists() if rank == 0 else None,
        root=0,
    )
    if not state_exists:
        error_msg = f"No saved state found at {path}."
        raise FileNotFoundError(error_msg)
    if rank == 0:
        # NOTE: `pickle.load` can execute arbitrary code; only load state
        # files produced by a trusted run of `save_state`.
        with path.open(mode="rb") as f:
            state = pickle.load(file=f)
        agents: list[list[BaseAgent]] = state[0]
        generation_results: Generation_results_type = state[1]
        total_num_env_steps: int = state[2]
        # Slice `i` goes to the process with rank `i`.
        # NOTE(review): the slice sizes are not validated against
        # `len(agents)`. If the saved run used a different number of
        # processes, trailing agents are dropped or some ranks receive
        # short/empty batches.
        batched_agents: list[list[list[BaseAgent]]] = [
            agents[i * len_agents_batch : (i + 1) * len_agents_batch]
            for i in range(size)
        ]
    # `comm.scatter` argument `sendobj` is wrongly typed. `[]` is the
    # workaround for not being able to set it to `None`.
    # See https://github.com/mpi4py/mpi4py/issues/434
    agents_batch = comm.scatter(sendobj=[] if rank != 0 else batched_agents)
    return (
        agents_batch,
        None if rank != 0 else generation_results,
        None if rank != 0 else total_num_env_steps,
    )
def save_state(
    agents_batch: list[list[BaseAgent]],
    generation_results: Generation_results_type | None,
    total_num_env_steps: An[int, ge(0)] | None,
    curr_gen: An[int, ge(1)],
    output_dir: str,
) -> None:
    """Dump the current experiment state to disk.

    Must be called by every MPI process: the agent batches of all
    processes are gathered on rank 0, which alone writes
    ``{output_dir}/{curr_gen}/state.pkl``. The file is first written to a
    temporary sibling and then renamed into place. An interrupted dump
    therefore never leaves a truncated ``state.pkl``, which
    :func:`find_existing_save_points` would otherwise report as a valid
    save point.

    Args:
        agents_batch: See
            :paramref:`~.compute_generation_results.agents_batch`.
        generation_results: See
            :paramref:`~.compute_generation_results.generation_results`.
        total_num_env_steps: See
            :paramref:`~.compute_total_num_env_steps_and_process_fitnesses.total_num_env_steps`.
        curr_gen: See :paramref:`~.BaseSpace.curr_gen`.
        output_dir: See :paramref:`~.BaseSubtaskConfig.output_dir`.
    """
    comm, rank, _ = get_mpi_variables()
    batched_agents: list[list[list[BaseAgent]]] | None = comm.gather(
        sendobj=agents_batch,
    )
    if rank != 0:
        return
    # `batched_agents`, `generation_results`, and `total_num_env_steps`
    # are only `None` when `rank != 0`. The following `assert`
    # statements are for static type checking reasons and have no
    # execution purposes.
    assert batched_agents is not None  # noqa: S101
    assert generation_results is not None  # noqa: S101
    assert total_num_env_steps is not None  # noqa: S101
    # Flatten in rank order with a single pass; repeated `+`
    # concatenation copied the growing list on every iteration.
    agents: list[list[BaseAgent]] = [
        agent for agent_batch in batched_agents for agent in agent_batch
    ]
    path = Path(f"{output_dir}/{curr_gen}/state.pkl")
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(f"{path.name}.tmp")
    with tmp_path.open(mode="wb") as f:
        pickle.dump(
            obj=[agents, generation_results, total_num_env_steps],
            file=f,
        )
    # Both paths share a directory, so `Path.replace` is a same-filesystem
    # rename. It is atomic on POSIX and overwrites any existing file.
    tmp_path.replace(path)