feat: 使算法能被自动注册

This commit is contained in:
2026-05-04 13:55:57 +08:00
parent e2b9fb94f3
commit ca86b2d8e9
11 changed files with 101 additions and 139 deletions

View File

@@ -12,8 +12,8 @@ python 代指您使用的解释器, 在某些发行版中可能是 python3, 而
尽管项目保留了 requirements.txt, 我们仍不推荐使用系统 python 和原始 venv 进行开发.
项目的推荐开发环境工具是 uv.
如果你的环境已经安装了 uv:
先运行 uv sync 同步环境, 此命令只需要执行一遍, uv 会自动处理依赖.
然后通过运行 uv run tui 启动内置基本用户界面.
先运行 uv sync --all-extras 同步环境, 此命令只需要执行一遍, uv 会自动处理依赖.
然后通过运行 uv run heurams-tui 启动内置基本用户界面.
此时您的解释器在项目目录里的 .venv/bin 中, 使用 IDE 开发前, 务必切换解释器!
注意: 一个常见的误区是, 执行 interface 下的 __main__.py 运行基本用户界面, 这会导致 Python 上下文环境异常, 请不要这样做."""
print(prompt)

View File

@@ -1,21 +1,14 @@
import importlib
import pkgutil
from pathlib import Path
from .base import BaseAlgorithm
from .sm2 import SM2Algorithm
from .sm15m import SM15MAlgorithm
from .nsp0 import NSP0Algorithm
from .fsrs import FSRSAlgorithm
__all__ = [
"SM2Algorithm",
"BaseAlgorithm",
"SM15MAlgorithm",
"NSP0Algorithm",
"FSRSAlgorithm",
]
__path__ = [str(Path(__file__).parent)]
algorithms = {
"SM-2": SM2Algorithm,
"NSP-0": NSP0Algorithm,
"SM-15M": SM15MAlgorithm,
"FSRS": FSRSAlgorithm,
"Base": BaseAlgorithm,
}
for _finder, _name, _ispkg in pkgutil.iter_modules(__path__):
if _name == "base":
continue
importlib.import_module(f".{_name}", __package__)
algorithms = BaseAlgorithm.get_registry()

View File

@@ -5,11 +5,20 @@ from heurams.services.logger import get_logger
logger = get_logger(__name__)
_registry: dict[str, type["BaseAlgorithm"]] = {}
class BaseAlgorithm:
algo_name = "BaseAlgorithm"
desc = "算法基类"
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
_registry[cls.algo_name] = cls
@classmethod
def get_registry(cls) -> dict[str, type["BaseAlgorithm"]]:
return dict(_registry)
class AlgodataDict(TypedDict):
real_rept: int
rept: int

View File

@@ -10,7 +10,7 @@ logger = get_logger(__name__)
class NSP0Algorithm(BaseAlgorithm):
algo_name = "NSP-0"
desc = "快速筛选用特殊调度器"
desc = "快速筛选用非间隔重复调度器"
class AlgodataDict(TypedDict):
real_rept: int

View File

@@ -10,7 +10,7 @@ logger = get_logger(__name__)
class SM2Algorithm(BaseAlgorithm):
algo_name = "SM-2"
desc = "经典间隔重复算法"
desc = "SuperMemo2 (1987) 简单间隔重复调度器"
class AlgodataDict(TypedDict):
efactor: float