feat: 使算法能被自动注册

This commit is contained in:
2026-05-04 13:55:57 +08:00
parent e2b9fb94f3
commit ca86b2d8e9
11 changed files with 101 additions and 139 deletions

View File

@@ -1,21 +1,14 @@
import importlib
import pkgutil
from pathlib import Path
from .base import BaseAlgorithm
from .sm2 import SM2Algorithm
from .sm15m import SM15MAlgorithm
from .nsp0 import NSP0Algorithm
from .fsrs import FSRSAlgorithm
__all__ = [
"SM2Algorithm",
"BaseAlgorithm",
"SM15MAlgorithm",
"NSP0Algorithm",
"FSRSAlgorithm",
]
__path__ = [str(Path(__file__).parent)]
algorithms = {
"SM-2": SM2Algorithm,
"NSP-0": NSP0Algorithm,
"SM-15M": SM15MAlgorithm,
"FSRS": FSRSAlgorithm,
"Base": BaseAlgorithm,
}
for _finder, _name, _ispkg in pkgutil.iter_modules(__path__):
if _name == "base":
continue
importlib.import_module(f".{_name}", __package__)
algorithms = BaseAlgorithm.get_registry()

View File

@@ -5,11 +5,20 @@ from heurams.services.logger import get_logger
logger = get_logger(__name__)
_registry: dict[str, type["BaseAlgorithm"]] = {}
class BaseAlgorithm:
algo_name = "BaseAlgorithm"
desc = "算法基类"
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
_registry[cls.algo_name] = cls
@classmethod
def get_registry(cls) -> dict[str, type["BaseAlgorithm"]]:
return dict(_registry)
class AlgodataDict(TypedDict):
real_rept: int
rept: int

View File

@@ -10,7 +10,7 @@ logger = get_logger(__name__)
class NSP0Algorithm(BaseAlgorithm):
algo_name = "NSP-0"
desc = "快速筛选用特殊调度器"
desc = "快速筛选用非间隔重复调度器"
class AlgodataDict(TypedDict):
real_rept: int

View File

@@ -10,7 +10,7 @@ logger = get_logger(__name__)
class SM2Algorithm(BaseAlgorithm):
algo_name = "SM-2"
desc = "经典间隔重复算法"
desc = "SuperMemo2 (1987) 简单间隔重复调度器"
class AlgodataDict(TypedDict):
efactor: float