[docs] 36classStoppingCriterion: 37""" 38 State machines (SM) which decide when to stop iterative solvers by examining their mathematical state. 39 40 SM decisions are always accompanied by at least one numerical statistic. These stats may be queried by solvers via 41 :py:meth:`~pyxu.abc.StoppingCriterion.info` to provide diagnostic information to users. 42 43 Composite stopping criteria can be implemented via the overloaded (and[``&``], or[``|``]) operators. 44 """ 45
[docs] 46defstop(self,state:cabc.Mapping[str])->bool: 47""" 48 Compute a stop signal based on the current mathematical state. 49 50 Parameters 51 ---------- 52 state: ~collections.abc.Mapping 53 Full mathematical state of solver at some iteration, i.e. :py:attr:`~pyxu.abc.Solver._mstate`. 54 55 Values from `state` may be cached inside the instance to form complex stopping conditions. 56 57 Returns 58 ------- 59 s: bool 60 True if no further iterations should be performed, False otherwise. 61 """ 62raiseNotImplementedError
63
[docs] 64definfo(self)->cabc.Mapping[str,float]: 65""" 66 Get statistics associated with the last call to :py:meth:`~pyxu.abc.StoppingCriterion.stop`. 67 68 Returns 69 ------- 70 data: ~collections.abc.Mapping 71 """ 72raiseNotImplementedError
73
[docs] 74defclear(self): 75""" 76 Clear SM state (if any). 77 78 This method is useful when a :py:class:`~pyxu.abc.StoppingCriterion` instance must be reused in another call to 79 :py:meth:`~pyxu.abc.Solver.fit`. 80 """ 81pass
[docs]112classSolver:113r"""114 Iterative solver for minimization problems of the form :math:`\hat{x} = \arg\min_{x \in \mathbb{R}^{M_{1}115 \times\cdots\times M_{D}}} \mathcal{F}(x)`, where the form of :math:`\mathcal{F}` is solver-dependent.116117 Solver provides a versatile API for solving optimisation problems, with the following features:118119 * manual/automatic/background execution of solver iterations via parameters provided to120 :py:meth:`~pyxu.abc.Solver.fit`. (See below.)121 * automatic checkpointing of solver progress, providing a safe restore point in case of faulty numerical code. Each122 solver instance backs its state and final output to a folder on disk for post-analysis. In particular123 :py:meth:`~pyxu.abc.Solver.fit` should never crash: detailed exception information will always be available in a124 logfile for post-analysis.125 * arbitrary specification of complex stopping criteria via the :py:class:`~pyxu.abc.StoppingCriterion` class.126 * solve for multiple initial points in parallel. (Not always supported by all solvers.)127128 To implement a new iterative solver, users need to sub-class :py:class:`~pyxu.abc.Solver` and overwrite the methods129 below:130131 * :py:meth:`~pyxu.abc.Solver.__init__`132 * :py:meth:`~pyxu.abc.Solver.m_init` [i.e. math-init()]133 * :py:meth:`~pyxu.abc.Solver.m_step` [i.e. math-step()]134 * :py:meth:`~pyxu.abc.Solver.default_stop_crit` [optional; see method definition for details]135 * :py:meth:`~pyxu.abc.Solver.objective_func` [optional; see method definition for details]136137 Advanced functionalities of :py:class:`~pyxu.abc.Solver` are automatically inherited by sub-classes.138139140 Examples141 --------142 Here are examples on how to solve minimization problems with this class:143144 .. code-block:: python3145146 slvr = Solver()147148 ### 1. Blocking mode: .fit() does not return until solver has stopped.149 >>> slvr.fit(mode=SolverMode.BLOCK, ...)150 >>> data, hist = slvr.stats() # final output of solver.151152 ### 2. Async mode: solver iterations run in the background.153 >>> slvr.fit(mode=SolverMode.ASYNC, ...)154 >>> print('test') # you can do something in between.155 >>> slvr.busy() # or check whether the solver already stopped.156 >>> slvr.stop() # and preemptively force it to stop.157 >>> data, hist = slvr.stats() # then query the result after a (potential) force-stop.158159 ### 3. Manual mode: fine-grain control of solver data per iteration.160 >>> slvr.fit(mode=SolverMode.MANUAL, ...)161 >>> for data in slvr.steps():162 ... # Do something with the logged variables after each iteration.163 ... pass # solver has stopped after the loop.164 >>> data, hist = slvr.stats() # final output of solver.165 """166167_mstate:dict[str,typ.Any]#: Mathematical state.168_astate:dict[str,typ.Any]#: Book-keeping (non-math) state.169
[docs]170def__init__(171self,172*,173folder:pxt.Path=None,174exist_ok:bool=False,175stop_rate:pxt.Integer=1,176writeback_rate:pxt.Integer=None,177verbosity:pxt.Integer=None,178show_progress:bool=True,179log_var:pxt.VarName=frozenset(),180):181"""182 Parameters183 ----------184 folder: Path185 Directory on disk where instance data should be stored. A location will be automatically chosen if186 unspecified. (Default: OS-dependent tempdir.)187 exist_ok: bool188 If `folder` is specified and `exist_ok` is false (default), :py:class:`FileExistsError` is raised if the189 target directory already exists.190 stop_rate: Integer191 Rate at which solver evaluates stopping criteria.192 writeback_rate: Integer193 Rate at which solver checkpoints are written to disk:194195 - If `None` (default), all checkpoints are disabled: the final solver output is only stored in memory.196 - If `0`, intermediate checkpoints are disabled: only the final solver output will be written to disk.197 - Any other integer: checkpoint to disk at provided interval. Must be a multiple of `stop_rate`.198 verbosity: Integer199 Rate at which stopping criteria statistics are logged. Must be a multiple of `stop_rate`. Defaults to200 `stop_rate` if unspecified.201 show_progress: bool202 If True (default) and :py:meth:`~pyxu.abc.Solver.fit` is run with mode=BLOCK, then statistics are also203 logged to stdout.204 log_var: VarName205 Variables from the solver's math-state (:py:attr:`~pyxu.abc.Solver._mstate`) to be logged per iteration.206 These are the variables made available when calling :py:meth:`~pyxu.abc.Solver.stats`.207208 Notes209 -----210 * Partial device<>CPU synchronization takes place when stopping-criteria are evaluated. Increasing `stop_rate`211 is advised to reduce the effect of such transfers when applicable.212 * Full device<>CPU synchronization takes place at checkpoint-time. Increasing `writeback_rate` is advised to213 reduce the effect of such transfers when applicable.214 """215self._mstate=dict()216self._astate=dict(217history=None,# stopping criteria values per iteration218idx=0,# iteration index219log_rate=None,220log_var=None,221logger=None,222stdout=None,223stop_crit=None,224stop_rate=None,225track_objective=None,226wb_rate=None,227workdir=None,228# Execution-mode related -----------229mode=None,230active=None,231worker=None,232)233234try:235iffolderisNone:236folder=plib.Path(tempfile.mkdtemp(prefix="pyxu_"))237elif(folder:=plib.Path(folder).expanduser().resolve()).exists()and(notexist_ok):238raiseFileExistsError(f"{folder} already exists.")239else:240shutil.rmtree(folder,ignore_errors=True)241folder.mkdir(parents=True)242self._astate["workdir"]=folder243exceptException:244raiseException(f"folder: expected path-like, got {type(folder)}.")245246try:247assertstop_rate>=1248self._astate["stop_rate"]=int(stop_rate)249exceptException:250raiseValueError(f"stop_rate must be positive, got {stop_rate}.")251252try:253ifwriteback_rateisNone:# no checkpoints254pass255elifwriteback_rate==0:# final checkpoint only256pass257else:# regular checkpoints258assertwriteback_rate>0259assertwriteback_rate%self._astate["stop_rate"]==0260writeback_rate=int(writeback_rate)261self._astate["wb_rate"]=writeback_rate262exceptException:263raiseValueError(f"writeback_rate must be (None, 0, <multiple of stop_rate>), got {writeback_rate}.")264265try:266ifverbosityisNone:267verbosity=self._astate["stop_rate"]268assertverbosity%self._astate["stop_rate"]==0269self._astate["log_rate"]=int(verbosity)270self._astate["stdout"]=bool(show_progress)271exceptException:272raiseValueError(f"verbosity must be a multiple of stop_rate({stop_rate}), got {verbosity}.")273274try:275ifisinstance(log_var,str):276log_var=(log_var,)277self._astate["log_var"]=frozenset(log_var)278exceptException:279raiseValueError(f"log_var: expected collection, got {type(log_var)}.")
280
[docs]281deffit(self,**kwargs):282r"""283 Solve minimization problem(s) defined in :py:meth:`~pyxu.abc.Solver.__init__`, with the provided run-specifc284 parameters.285286 Parameters287 ----------288 kwargs289 See class-level docstring for class-specific keyword parameters.290 stop_crit: StoppingCriterion291 Stopping criterion to end solver iterations. If unspecified, defaults to292 :py:meth:`~pyxu.abc.Solver.default_stop_crit`.293 mode: SolverMode294 Execution mode.295 See :py:class:`~pyxu.abc.Solver` for usage examples.296297 Useful method pairs depending on the execution mode:298299 * BLOCK: :py:meth:`~pyxu.abc.Solver.fit`300 * ASYNC: :py:meth:`~pyxu.abc.Solver.fit`, :py:meth:`~pyxu.abc.Solver.busy`, :py:meth:`~pyxu.abc.Solver.stop`301 * MANUAL: :py:meth:`~pyxu.abc.Solver.fit`, :py:meth:`~pyxu.abc.Solver.steps`302 track_objective: bool303 Auto-compute objective function every time stopping criterion is evaluated.304 """305self._fit_init(306mode=kwargs.pop("mode",SolverMode.BLOCK),307stop_crit=kwargs.pop("stop_crit",None),308track_objective=kwargs.pop("track_objective",False),309)310self.m_init(**kwargs)311self._fit_run()
312
[docs]313defm_init(self,**kwargs):314"""315 Set solver's initial mathematical state based on kwargs provided to :py:meth:`~pyxu.abc.Solver.fit`.316317 This method must only manipulate :py:attr:`~pyxu.abc.Solver._mstate`.318319 After calling this method, the solver must be able to complete its 1st iteration via a call to320 :py:meth:`~pyxu.abc.Solver.m_step`.321 """322raiseNotImplementedError
323
[docs]324defm_step(self):325"""326 Perform one (mathematical) step.327328 This method must only manipulate :py:attr:`~pyxu.abc.Solver._mstate`.329 """330raiseNotImplementedError
331
[docs]332defsteps(self,n:pxt.Integer=None)->cabc.Generator:333"""334 Generator of logged variables after each iteration.335336 The i-th call to :py:func:`next` on this object returns the logged variables after the i-th solver iteration.337338 This method is only usable after calling :py:meth:`~pyxu.abc.Solver.fit` with mode=MANUAL. See339 :py:class:`~pyxu.abc.Solver` for usage examples.340341 There is no guarantee that a checkpoint on disk exists when the generator is exhausted. (Reason: potential342 exceptions raised during solver's progress.) Users should invoke :py:meth:`~pyxu.abc.Solver.writeback`343 afterwards if needed.344345 Parameters346 ----------347 n: Integer348 Maximum number of :py:func:`next` calls allowed before exhausting the generator. Defaults to infinity if349 unspecified.350351 The generator will terminate prematurely if the solver naturally stops before `n` calls to :py:func:`next`352 are made.353 """354self._check_mode(SolverMode.MANUAL)355i=0356while(nisNone)or(i<n):357ifself._step():358data,_=self.stats()359yielddata360i+=1361else:362self._astate["mode"]=None# force steps() to be call-once when exhausted.363self._cleanup_logger()364return
[docs]369defstats(self)->tuple[_stats_data_spec,_stats_history_spec]:370"""371 Query solver state.372373 Returns374 -------375 data: ~collections.abc.Mapping376 Value(s) of ``log_var`` (s) after last iteration.377 history: numpy.ndarray, None378 (N_iter,) records of stopping-criteria values sampled every `stop_rate` iteration.379380 Notes381 -----382 If any of the ``log_var`` (s) and/or ``history`` are not (yet) known at query time, ``None`` is returned.383 """384history=self._astate["history"]385ifhistoryisnotNone:386iflen(history)>0:387history=np.concatenate(history,dtype=history[0].dtype,axis=0)388else:389history=None390data={k:self._mstate.get(k)forkinself._astate["log_var"]}391returndata,history
392393@property394defworkdir(self)->pxt.Path:395"""396 Returns397 -------398 wd: Path399 Absolute path to the directory on disk where instance data is stored.400 """401returnself._astate["workdir"]402403@property404deflogfile(self)->pxt.Path:405"""406 Returns407 -------408 lf: Path409 Absolute path to the log file on disk where stopping criteria statistics are logged.410 """411returnself.workdir/"solver.log"412413@property414defdatafile(self)->pxt.Path:415"""416 Returns417 -------418 df: Path419 Absolute path to the file on disk where ``log_var`` (s) are stored during checkpointing or after solver has420 stopped.421 """422returnself.workdir/"data.zarr"423
[docs]424defbusy(self)->bool:425"""426 Test if an async-running solver has stopped.427428 This method is only usable after calling :py:meth:`~pyxu.abc.Solver.fit` with mode=ASYNC. See429 :py:class:`~pyxu.abc.Solver` for usage examples.430431 Returns432 -------433 b: bool434 True if solver has stopped, False otherwise.435 """436self._check_mode(SolverMode.ASYNC,SolverMode.BLOCK)437returnself._astate["active"].is_set()
438
[docs]439defsolution(self):440"""441 Output the "solution" of the optimization problem.442443 This is a helper method intended for novice users. The return type is sub-class dependent, so don't write an444 API using this: use :py:meth:`~pyxu.abc.Solver.stats` instead.445 """446raiseNotImplementedError
447
[docs]448defstop(self):449"""450 Stop an async-running solver.451452 This method is only usable after calling :py:meth:`~pyxu.abc.Solver.fit` with mode=ASYNC. See453 :py:class:`~pyxu.abc.Solver` for usage examples.454455 This method will block until the solver has stopped.456457 There is no guarantee that a checkpoint on disk exists once halted. (Reason: potential exceptions raised during458 solver's progress.) Users should invoke :py:meth:`~pyxu.abc.Solver.writeback` afterwards if needed.459460 Users must call this method to terminate an async-solver, even if :py:meth:`~pyxu.abc.Solver.busy` is False.461 """462self._check_mode(SolverMode.ASYNC,SolverMode.BLOCK)463self._astate["active"].clear()464self._astate["worker"].join()465self._astate.update(466mode=None,# forces stop() to be call-once.467active=None,468worker=None,469)470self._cleanup_logger()
471472def_fit_init(473self,474mode:SolverMode,475stop_crit:StoppingCriterion,476track_objective:bool,477):478def_init_logger():479log_name=str(self.workdir)480logger=logging.getLogger(log_name)481logger.handlers.clear()482logger.setLevel("DEBUG")483484fmt=logging.Formatter(fmt="{levelname} -- {message}",style="{")485handler=[logging.FileHandler(self.logfile,mode="w")]486if(modeisSolverMode.BLOCK)andself._astate["stdout"]:487handler.append(logging.StreamHandler(sys.stdout))488forhinhandler:489h.setLevel("DEBUG")490h.setFormatter(fmt)491logger.addHandler(h)492493returnlogger494495self._mstate.clear()496497ifstop_critisNone:498stop_crit=self.default_stop_crit()499stop_crit.clear()500501iftrack_objective:502frompyxu.opt.stopimportMemorize503504stop_crit|=Memorize(var="objective_func")505506self._astate.update(# suitable state for a new call to fit().507history=[],508idx=0,509logger=_init_logger(),510stop_crit=stop_crit,511track_objective=track_objective,512mode=mode,513active=None,514worker=None,515)516517def_fit_run(self):518self._m_persist()519520mode=self._astate["mode"]521ifmodeisSolverMode.MANUAL:522# User controls execution via steps().523pass524else:# BLOCK / ASYNC525self._astate.update(526active=threading.Event(),527worker=Solver._Worker(self),528)529self._astate["active"].set()530self._astate["worker"].start()531ifmodeisSolverMode.BLOCK:532self._astate["worker"].join()533self.stop()# state clean-up534else:535# User controls execution via busy() + stop().536pass537
[docs]538defwriteback(self):539"""540 Checkpoint state to disk.541 """542data,history=self.stats()543544pxu.save_zarr(self.datafile,{"history":history,**data})
545546def_check_mode(self,*modes:SolverMode):547m=self._astate["mode"]548ifminmodes:549pass# ok550else:551ifmisNone:552msg="Illegal method call: invoke Solver.fit() first."553else:554msg=" ".join(555[556"Illegal method call: can only be used if Solver.fit() invoked with",557"mode=Any["+", ".join(map(lambda_:str(_.name),modes))+"]",558]559)560raiseValueError(msg)561562def_step(self)->bool:563ast=self._astate# shorthand564565must_stop=lambda:ast["idx"]%ast["stop_rate"]==0566must_log=lambda:ast["idx"]%ast["log_rate"]==0567checkpoint_enabled=ast["wb_rate"]notin(None,0)# performing regular checkpoints?568must_writeback=lambda:checkpoint_enabledand(ast["idx"]%ast["wb_rate"]==0)569570def_log(msg:str=None):571ifmsgisNone:# report stopping-criterion values572h=ast["history"][-1][0]573msg=[f"[{dt.datetime.now()}] Iteration {ast['idx']:>_d}"]574forfield,valueinzip(h.dtype.names,h):575msg.append(f"\t{field}: {value}")576msg="\n".join(msg)577ast["logger"].info(msg)578579def_update_history():580def_as_struct(data:dict[str,float])->np.ndarray:581ftype=type(x)ifisinstance(x:=next(iter(data.values())),(int,float))elsex.dtype582583spec_data=[(k,ftype)forkindata]584585itype=np.int64586spec_iter=[("iteration",itype)]587588dtype=np.dtype(spec_iter+spec_data)589590utype=np.uint8591s=np.concatenate(# to allow mixed int/float fields:592[# (1) cast to uint, then (2) to compound dtype.593np.array([ast["idx"]],dtype=itype).view(utype),594np.array(list(data.values()),dtype=ftype).view(utype),595]596).view(dtype)597returns598599h=_as_struct(ast["stop_crit"].info())600ast["history"].append(h)601602# [Sepand] Important603# stop_crit.stop(), _update_history(), _log() must always be called in this order.604605try:606_ms,_ml,_mw=must_stop(),must_log(),must_writeback()607608if_msandast["track_objective"]:609self._mstate["objective_func"]=self.objective_func().reshape(-1)610611if_msandast["stop_crit"].stop(self._mstate):612_update_history()613_log()614_log(msg=f"[{dt.datetime.now()}] Stopping Criterion satisfied -> END")615ifast["wb_rate"]isnotNone:616self.writeback()617returnFalse618else:619if_ms:620_update_history()621if_ml:622_log()623if_mw:624self.writeback()625ast["idx"]+=1626self.m_step()627if_ms:628self._m_persist()629returnTrue630exceptExceptionase:631msg=f"[{dt.datetime.now()}] Something went wrong -> EXCEPTION RAISED"632msg_xtra=f"More information: {self.logfile}."633print("\n".join([msg,msg_xtra]),file=sys.stderr)634ifcheckpoint_enabled:# checkpointing enabled635_,r=divmod(ast["idx"],ast["wb_rate"])636idx_valid=ast["idx"]-r637msg_idx=f"Last valid checkpoint done at iteration={idx_valid}."638msg="\n".join([msg,msg_idx])639ast["logger"].exception(msg,exc_info=e)640returnFalse641642def_m_persist(self):643# Persist math state to avoid re-eval overhead.644k,v=zip(*self._mstate.items())645v=pxu.compute(*v,mode="persist",traverse=False)646self._mstate.update(zip(k,v))647# [Sepand] Note:648# The above evaluation strategy with `traverse=False` chosen since _mstate can hold any type649# of object.650651def_cleanup_logger(self):652# Close file-handlers653log_name=str(self.workdir)654logger=logging.getLogger(log_name)655forhandlerinlogger.handlers:656handler.close()657
[docs]658defdefault_stop_crit(self)->StoppingCriterion:659"""660 Default stopping criterion for solver if unspecified in :py:meth:`~pyxu.abc.Solver.fit` calls.661662 Sub-classes are expected to overwrite this method. If not overridden, then omitting the `stop_crit` parameter663 in :py:meth:`~pyxu.abc.Solver.fit` is forbidden.664 """665raiseNotImplementedError("No default stopping criterion defined.")
666
[docs]667defobjective_func(self)->pxt.NDArray:668"""669 Evaluate objective function given current math state.670671 The output array must have shape:672673 * (1,) if evaluated at 1 point,674 * (N, 1) if evaluated at N different points.675676 Sub-classes are expected to overwrite this method. If not overridden, then enabling `track_objective` in677 :py:meth:`~pyxu.abc.Solver.fit` is forbidden.678 """679raiseNotImplementedError("No objective function defined.")