# Files: HeurAMS/src/heurams/kernel/algorithms/sm15m.py
# 834 lines, 26 KiB, Python
# Note from the hosting page: this file contains Unicode characters that might be
# confused with other characters. If this is intentional, the warning can be ignored.
"""
SM-15M — 基于 sm.js 的间隔重复算法
基于: https://github.com/slaypni/sm.js
原始 CoffeeScript (c) 2014 Kazuaki Tanida, MIT 许可证
"""
import datetime
import json
import math
import os
import pathlib
from typing import TypedDict
from heurams.context import config_var
from heurams.services.logger import get_logger
from heurams.services.timer import (
get_daystamp,
get_timestamp,
get_timestamp_ms,
daystamp_to_datetime,
datetime_to_daystamp,
get_now_datetime,
)
from .base import BaseAlgorithm
logger = get_logger(__name__)
# ============================================================================
# Constants
# ============================================================================
# Number of A-Factor grid steps (second index of the repetition x A-Factor matrices).
RANGE_AF = 20
# Number of repetition slots (first index of the matrices).
RANGE_REPETITION = 20
# Smallest A-Factor on the grid.
MIN_AF = 1.2
# Spacing between adjacent A-Factor grid values.
NOTCH_AF = 0.3
# Largest A-Factor on the grid.
MAX_AF = MIN_AF + NOTCH_AF * (RANGE_AF - 1) # 6.9
# Highest grade; used to seed the FI-G graph.
MAX_GRADE = 5
# Grades at or above this value count as a successful recall.
THRESHOLD_RECALL = 3
# ============================================================================
# Math Helpers
# ============================================================================
def sum_values(values):
    """Return the total of *values* (thin wrapper around the built-in ``sum``)."""
    total = sum(values)
    return total
def exponential_regression(points):
    """Fit ``y = a * exp(b * x)`` by least squares on ``(x, ln y)``.

    Args:
        points: sequence of ``[x, y]`` pairs. Pairs whose ``y`` is not
            strictly positive are ignored, because ``ln y`` is undefined for
            them and ``math.log`` would raise ``ValueError``.

    Returns:
        A dict with:
            ``"y"``: callable mapping x to the fitted y,
            ``"x"``: callable mapping y back to x (returns 0 when ``b`` is 0;
                raises ``ValueError`` for non-positive y),
            ``"a"``: the fitted scale factor,
            ``"b"``: the fitted exponent rate.
        When the fit is degenerate (fewer than two distinct x values among
        the usable points) both log-space coefficients fall back to 0, which
        gives ``a == 1`` and ``b == 0``.
    """
    # Only strictly positive ordinates can be moved into log space.
    usable = [(p[0], p[1]) for p in points if p[1] > 0]
    n = len(usable)
    xs = [x for x, _ in usable]
    log_ys = [math.log(y) for _, y in usable]

    sum_x = sum(xs)
    sum_log_y = sum(log_ys)
    sum_sq_x = sum(x * x for x in xs)
    sum_x_log_y = sum(x * log_y for x, log_y in zip(xs, log_ys))

    denom = n * sum_sq_x - sum_x * sum_x
    if denom:
        a_coeff = (sum_log_y * sum_sq_x - sum_x * sum_x_log_y) / denom
        b_coeff = (n * sum_x_log_y - sum_x * sum_log_y) / denom
    else:
        a_coeff = 0
        b_coeff = 0
    a = math.exp(a_coeff)

    def y_func(x):
        return a * math.exp(b_coeff * x)

    def x_func(y):
        if b_coeff == 0:
            return 0
        return (-a_coeff + math.log(y)) / b_coeff

    return {"y": y_func, "x": x_func, "a": a, "b": b_coeff}
def linear_regression(points):
    """Ordinary least-squares fit of ``y = a + b * x``.

    Args:
        points: sequence of ``[x, y]`` pairs.

    Returns:
        A dict with the forward model ``"y"``, the inverse model ``"x"``
        (which yields 0 when the slope is 0), the intercept ``"a"`` and the
        slope ``"b"``. Both coefficients are 0 when the system is degenerate.
    """
    count = len(points)
    xs = [pt[0] for pt in points]
    ys = [pt[1] for pt in points]

    total_x = sum(xs)
    total_y = sum(ys)
    total_xy = sum(x * y for x, y in zip(xs, ys))
    total_xx = sum(x * x for x in xs)

    denom = count * total_xx - total_x * total_x
    if denom:
        a = (total_y * total_xx - total_x * total_xy) / denom
        b = (count * total_xy - total_x * total_y) / denom
    else:
        a = 0
        b = 0

    def y_func(x):
        return a + b * x

    def x_func(y):
        return 0 if b == 0 else (y - a) / b

    return {"y": y_func, "x": x_func, "a": a, "b": b}
def power_law_model(a, b):
    """Build the power-law model ``y = a * x^b``.

    Returns:
        A dict with the forward function ``"y"``, the inverse function
        ``"x"`` (which yields 0 when either coefficient is 0) and the
        coefficients themselves under ``"a"`` and ``"b"``.
    """

    def y_func(x):
        return a * (x ** b)

    def x_func(y):
        invertible = not (a == 0 or b == 0)
        if not invertible:
            return 0
        return (y / a) ** (1.0 / b)

    return {"y": y_func, "x": x_func, "a": a, "b": b}
def fixed_point_power_law_regression(points, fixed_point):
    """Fit ``y = q * (x / p)^b`` constrained to pass through ``(p, q)``.

    The exponent ``b`` is the least-squares slope through the origin in
    log-log space relative to the fixed point.

    Args:
        points: sequence of ``[x, y]`` pairs. Pairs with a non-positive
            coordinate are skipped: their logarithm is undefined and
            ``math.log`` would raise ``ValueError``.
        fixed_point: ``(p, q)`` pair the curve must pass through.

    Returns:
        The dict produced by ``power_law_model`` for the fitted curve; the
        exponent is 0 when no usable point constrains it.

    Raises:
        ValueError: if ``q`` is not strictly positive (from ``math.log``).
    """
    p, q = fixed_point
    log_q = math.log(q)

    usable = [pt for pt in points if pt[0] > 0 and pt[1] > 0]
    xs = [math.log(pt[0] / p) for pt in usable]
    ys = [math.log(pt[1]) - log_q for pt in usable]

    sum_xy = sum(x * y for x, y in zip(xs, ys))
    sum_sq_x = sum(x * x for x in xs)
    b = sum_xy / sum_sq_x if sum_sq_x else 0
    return power_law_model(q / (p ** b), b)
# ============================================================================
# FI-G Grade Graph
# ============================================================================
class FI_G:
    """Forgetting Index / Grade graph, modelled by an exponential regression.

    Points are stored as ``[fi, grade + GRADE_OFFSET]`` and the regression is
    rebuilt lazily after every newly registered point.
    """

    MAX_POINTS_COUNT = 5000
    GRADE_OFFSET = 1

    def __init__(self, sm, points=None):
        """Wrap existing *points*, or seed the graph with two anchor points."""
        self.sm = sm
        self._graph = None
        if points is None:
            self.points = []
            self._register_point(0, MAX_GRADE)
            self._register_point(100, 0)
        else:
            self.points = points

    def _register_point(self, fi, grade):
        """Append one observation, keeping only the newest MAX_POINTS_COUNT."""
        self.points.append([fi, grade + self.GRADE_OFFSET])
        too_many = len(self.points) > self.MAX_POINTS_COUNT
        if too_many:
            self.points = self.points[-self.MAX_POINTS_COUNT:]
        # Any cached regression is stale once the point set changes.
        self._graph = None

    def update(self, grade, item, now):
        """Record *grade* against the forgetting index expected for *item*."""
        expected = (item.uf(now) / item.of) * self.sm.requested_fi
        self._register_point(expected, grade)

    def fi(self, grade):
        """Estimate the forgetting index for *grade*, clamped to 0..100.

        Returns 50.0 when no points are available.
        """
        if not self.points:
            return 50.0
        if self._graph is None:
            self._graph = exponential_regression(self.points)
        estimate = self._graph["x"](grade + self.GRADE_OFFSET)
        return max(0.0, min(100.0, estimate))

    def data(self):
        """Return the serialisable state of the graph."""
        return {"points": self.points}

    @classmethod
    def load(cls, sm, data):
        """Rebuild a graph from the mapping produced by ``data()``."""
        stored = data.get("points")
        return cls(sm, stored)
# ============================================================================
# Forgetting Curve (single cell)
# ============================================================================
class ForgettingCurve:
    """One cell of the forgetting-curve matrix.

    Points are ``[uf, outcome]`` pairs where the outcome is ``REMEMBERED``
    or ``FORGOTTEN``; an exponential regression over them is fitted lazily.
    """

    MAX_POINTS_COUNT = 500
    FORGOTTEN = 1
    REMEMBERED = 100 + FORGOTTEN  # 101

    def __init__(self, points):
        self.points = points
        self._curve = None

    def _model(self):
        """Return the cached regression, fitting it first when needed."""
        if self._curve is None:
            self._curve = exponential_regression(self.points)
        return self._curve

    def register_point(self, grade, uf):
        """Record a review outcome observed at U-Factor *uf*."""
        if grade >= THRESHOLD_RECALL:
            outcome = self.REMEMBERED
        else:
            outcome = self.FORGOTTEN
        self.points.append([uf, outcome])
        if len(self.points) > self.MAX_POINTS_COUNT:
            self.points = self.points[-self.MAX_POINTS_COUNT:]
        # Force a refit on the next query.
        self._curve = None

    def retention(self, uf):
        """Predict retention (0..100) at *uf*; 50.0 when there is no data."""
        if not self.points:
            return 50.0
        predicted = self._model()["y"](uf)
        bounded = max(self.FORGOTTEN, min(predicted, self.REMEMBERED))
        return bounded - self.FORGOTTEN

    def uf(self, retention):
        """Return the non-negative U-Factor at which *retention* is predicted.

        Returns 1.0 when there is no data.
        """
        if not self.points:
            return 1.0
        solved = self._model()["x"](retention + self.FORGOTTEN)
        return max(0.0, solved)
# ============================================================================
# Forgetting Curves Matrix (repetition × af)
# ============================================================================
class ForgettingCurves:
    """Matrix of ``ForgettingCurve`` cells indexed by [repetition][A-Factor index]."""

    FORGOTTEN = 1
    REMEMBERED = 100 + FORGOTTEN

    def __init__(self, sm, points=None):
        """Build the matrix from stored *points*, or from synthetic seed curves."""
        self.sm = sm
        self.curves = []
        for r in range(RANGE_REPETITION):
            cells = []
            for a in range(RANGE_AF):
                if points is None:
                    cell_points = self._seed_points(r, a)
                else:
                    cell_points = points[r][a]
                cells.append(ForgettingCurve(cell_points))
            self.curves.append(cells)

    def _seed_points(self, r, a):
        """Generate the initial point set for cell (*r*, *a*).

        The first point pins full retention at U-Factor 0; 21 further points
        follow an exponential decay whose rate depends on the cell position
        (one formula for repetition 0, another for later repetitions).
        """
        seeded = [[0, self.REMEMBERED]]
        for i in range(21):
            v = MIN_AF + NOTCH_AF * i
            if r > 0:
                y = math.exp(
                    -(r + 1) / 200 * (i - a * math.sqrt(2.0 / (r + 1)))
                ) * (self.REMEMBERED - self.sm.requested_fi)
            else:
                y = math.exp(
                    -1.0 / (10 + 1 * (a + 1)) * (i - a ** 0.6)
                ) * (self.REMEMBERED - self.sm.requested_fi)
            seeded.append([v, min(self.REMEMBERED, y)])
        return seeded

    def register_point(self, grade, item, now):
        """Record *grade* in the cell that matches the state of *item*."""
        # Repetition 0 is indexed by lapse count instead of A-Factor.
        if item.repetition > 0:
            column = item.af_index()
        else:
            column = item.lapse
        cell = self.curves[item.repetition][column]
        cell.register_point(grade, item.uf(now))

    def data(self):
        """Return the serialisable point sets of every cell."""
        rows = []
        for r in range(RANGE_REPETITION):
            rows.append([self.curves[r][a].points for a in range(RANGE_AF)])
        return {"points": rows}

    @classmethod
    def load(cls, sm, data):
        """Rebuild the matrix from the mapping produced by ``data()``."""
        stored = data.get("points")
        return cls(sm, stored)
# ============================================================================
# R-Factor Matrix
# ============================================================================
class RFM:
    """R-Factor matrix: a read-only view derived from the forgetting curves."""

    def __init__(self, sm):
        self.sm = sm

    def rf(self, repetition, af_index):
        """Return the U-Factor at which the cell reaches the target retention.

        The target retention is ``100 - requested_fi``.
        """
        target_retention = 100 - self.sm.requested_fi
        cell = self.sm.forgetting_curves.curves[repetition][af_index]
        return cell.uf(target_retention)
# ============================================================================
# O-Factor Matrix
# ============================================================================
class OFM:
    """O-Factor matrix derived from the R-Factor matrix.

    For repetition 0 the O-Factor is an exponential regression over the
    R-Factors of that row; for later repetitions each A-Factor column gets a
    power-law model along the repetition axis.
    """

    INITIAL_REP_VALUE = 1

    def __init__(self, sm):
        self.sm = sm
        self._ofm = None
        self._ofm0 = None
        self.update()

    def update(self):
        """Recompute every O-Factor model from the current R-Factor matrix."""

        def af_from_index(a):
            return a * NOTCH_AF + MIN_AF

        def rep_from_index(r):
            return r + self.INITIAL_REP_VALUE

        # D-factors: power-law decay along the repetition axis.
        dfs = []
        for a in range(RANGE_AF):
            pts = [
                [rep_from_index(r), self.sm.rfm.rf(r, a)]
                for r in range(1, RANGE_REPETITION)
            ]
            fp = [rep_from_index(1), af_from_index(a)]
            model = fixed_point_power_law_regression(pts, fp)
            dfs.append(model["b"])

        # Transform D-factors.
        dfs_t = [af_from_index(a) / (2.0 ** dfs[a]) for a in range(RANGE_AF)]

        # Linear regression of the transformed D-factor by A-Factor index.
        decay_pts = [[a, dfs_t[a]] for a in range(RANGE_AF)]
        decay = linear_regression(decay_pts)

        # Build one O-Factor model per A-Factor column.
        ofm_list = []
        for a in range(RANGE_AF):
            af = af_from_index(a)
            d_y = decay["y"](a)
            # The regression line may predict zero or a negative value; the
            # logarithm is only defined for a positive ratio (af is always
            # positive), so fall back to exponent 0 otherwise instead of
            # letting math.log raise ValueError.
            if d_y > 0:
                b = math.log(af / d_y) / math.log(rep_from_index(1))
            else:
                b = 0
            model = power_law_model(af / (rep_from_index(1) ** b), b)
            ofm_list.append(
                {
                    # Bind the model as a default so each entry keeps its own.
                    "y": lambda r, m=model: m["y"](rep_from_index(r)),
                    "x": lambda y, m=model: m["x"](y) - self.INITIAL_REP_VALUE,
                }
            )
        self._ofm = ofm_list

        # O-Factor for repetition 0.
        ofm0_pts = [[a, self.sm.rfm.rf(0, a)] for a in range(RANGE_AF)]
        ofm0 = exponential_regression(ofm0_pts)
        self._ofm0 = lambda a: ofm0["y"](a)

    def of(self, repetition, af_index):
        """Return the O-Factor for *repetition* and column *af_index*."""
        if repetition == 0:
            return self._ofm0(af_index)
        return self._ofm[af_index]["y"](repetition)

    def af(self, repetition, of_val):
        """Return the grid A-Factor whose O-Factor is closest to *of_val*.

        An (almost) exact match is returned as soon as it is found; otherwise
        the column with the smallest absolute difference wins.
        """
        for a in range(RANGE_AF):
            if abs(self.of(repetition, a) - of_val) < 1e-10:
                return a * NOTCH_AF + MIN_AF
        # No exact match: pick the closest column.
        best, best_a = float("inf"), 0
        for a in range(RANGE_AF):
            d = abs(self.of(repetition, a) - of_val)
            if d < best:
                best, best_a = d, a
        return best_a * NOTCH_AF + MIN_AF
# ============================================================================
# Item (per-card state)
# ============================================================================
class Item:
    """Per-card scheduling state.

    All datetimes held by an item are naive and expressed in local time, to
    match the ``datetime.datetime.now()`` default used by every method.
    Intervals are expressed in milliseconds.
    """

    MAX_AFS_COUNT = 30

    def __init__(self, sm, value=None):
        self.sm = sm
        self.value = value
        self.lapse = 0
        self.repetition = -1
        self.of = 1.0
        self.optimum_interval = sm.interval_base  # ms
        self.due_date = datetime.datetime(1970, 1, 1)
        self.previous_date = None
        self._afs = []
        self._af = None

    def interval(self, now=None):
        """Return the milliseconds elapsed since the previous review.

        Falls back to the scheduler's base interval when the item has no
        previous review date.
        """
        if now is None:
            now = datetime.datetime.now()
        if self.previous_date is None:
            return self.sm.interval_base
        return (now - self.previous_date).total_seconds() * 1000

    def uf(self, now=None):
        """Return the U-Factor: elapsed interval over ``optimum_interval / of``."""
        if now is None:
            now = datetime.datetime.now()
        adjusted = self.optimum_interval / self.of
        return self.interval(now) / adjusted if adjusted else 0

    def af(self, value=None):
        """Get the A-Factor, or set it from *value* snapped onto the grid.

        With no argument the current A-Factor (possibly ``None``) is
        returned. Otherwise *value* is rounded to the nearest grid step,
        clamped to ``MIN_AF..MAX_AF``, stored and returned.
        """
        if value is None:
            return self._af
        a = round((value - MIN_AF) / NOTCH_AF)
        self._af = max(MIN_AF, min(MAX_AF, MIN_AF + a * NOTCH_AF))
        return self._af

    def af_index(self):
        """Return the grid index nearest to the A-Factor (``MIN_AF`` if unset)."""
        target = self.af() if self._af is not None else MIN_AF
        best, best_i = float("inf"), 0
        for i in range(RANGE_AF):
            d = abs(target - (MIN_AF + i * NOTCH_AF))
            if d < best:
                best, best_i = d, i
        return best_i

    def _I(self, now=None):
        """Compute the next optimum interval and due date after a success."""
        if now is None:
            now = datetime.datetime.now()
        # Repetition 0 is indexed by lapse count instead of A-Factor.
        af_idx = self.lapse if self.repetition == 0 else self.af_index()
        of_val = self.sm.ofm.of(self.repetition, af_idx)
        self.of = max(
            1.0, (of_val - 1) * (self.interval(now) / self.optimum_interval) + 1
        )
        self.optimum_interval = round(self.optimum_interval * self.of)
        self.previous_date = now
        self.due_date = now + datetime.timedelta(milliseconds=self.optimum_interval)

    def _weighted_af(self):
        """Return the recency-weighted mean of the A-Factor history.

        The i-th entry (1-based) has weight i. Returns ``None`` when the
        history is empty.
        """
        if not self._afs:
            return None
        wsum = sum(af * (i + 1) for i, af in enumerate(self._afs))
        wtotal = sum(range(1, len(self._afs) + 1))
        return wsum / wtotal

    def _update_af(self, grade, now=None):
        """Fold the A-Factor estimate implied by *grade* into the history."""
        if now is None:
            now = datetime.datetime.now()
        estimated_fi = max(1.0, self.sm.fi_g.fi(grade))
        corrected_uf = self.uf(now) * (self.sm.requested_fi / estimated_fi)
        if self.repetition > 0:
            estimated_af = self.sm.ofm.af(self.repetition, corrected_uf)
        else:
            estimated_af = max(MIN_AF, min(MAX_AF, corrected_uf))
        self._afs.append(estimated_af)
        if len(self._afs) > self.MAX_AFS_COUNT:
            self._afs = self._afs[-self.MAX_AFS_COUNT:]
        self.af(self._weighted_af())

    def answer(self, grade, now=None):
        """Apply a review result to this item.

        A grade at or above ``THRESHOLD_RECALL`` advances the repetition
        (capped at the last matrix row) and schedules the next interval. A
        lower grade counts a lapse (capped at the last matrix column) and
        resets the item so that it is due immediately.
        """
        if now is None:
            now = datetime.datetime.now()
        if self.repetition >= 0:
            self._update_af(grade, now)
        if grade >= THRESHOLD_RECALL:
            if self.repetition < RANGE_REPETITION - 1:
                self.repetition += 1
            self._I(now)
        else:
            if self.lapse < RANGE_AF - 1:
                self.lapse += 1
            self.optimum_interval = self.sm.interval_base
            self.previous_date = None
            self.due_date = now
            self.repetition = -1

    def data(self):
        """Return the JSON-serialisable state of the item."""
        return {
            "value": self.value,
            "repetition": self.repetition,
            "lapse": self.lapse,
            "of": self.of,
            "optimumInterval": self.optimum_interval,
            "dueDate": self.due_date.isoformat(),
            "previousDate": (
                self.previous_date.isoformat() if self.previous_date else None
            ),
            "_afs": self._afs,
        }

    @staticmethod
    def _parse_date(raw):
        """Convert a stored date into a naive local ``datetime``.

        *raw* is either an ISO-8601 string (a trailing ``Z`` is accepted) or
        a number of milliseconds since the epoch. Strings carrying a UTC
        offset are converted to local time and stripped of their tzinfo:
        the rest of the class works with naive local datetimes, and mixing
        aware and naive values raises ``TypeError`` on comparison and
        subtraction.
        """
        if isinstance(raw, str):
            parsed = datetime.datetime.fromisoformat(raw.replace("Z", "+00:00"))
            if parsed.tzinfo is not None:
                parsed = parsed.astimezone().replace(tzinfo=None)
            return parsed
        return datetime.datetime.fromtimestamp(raw / 1000)

    @classmethod
    def load(cls, sm, data):
        """Rebuild an item from the mapping produced by ``data()``."""
        item = cls(sm)
        item.value = data.get("value")
        item.repetition = data.get("repetition", -1)
        item.lapse = data.get("lapse", 0)
        item.of = data.get("of", 1.0)
        item.optimum_interval = data.get("optimumInterval", sm.interval_base)
        # Copy so the item never aliases the caller's list.
        item._afs = list(data.get("_afs") or [])
        due_raw = data.get("dueDate")
        if due_raw:
            item.due_date = cls._parse_date(due_raw)
        prev_raw = data.get("previousDate")
        if prev_raw:
            item.previous_date = cls._parse_date(prev_raw)
        if item._afs:
            # Use the same weighting as _update_af so that a data()/load()
            # round trip restores the A-Factor that was in effect.
            item.af(item._weighted_af())
        return item
# ============================================================================
# SM (global scheduler)
# ============================================================================
class SM:
    """Global scheduler: owns the shared matrices and a due-date ordered queue."""

    def __init__(self):
        self.requested_fi = 10.0
        self.interval_base = 3 * 60 * 60 * 1000  # 3 hours in ms
        self.q = []
        self.fi_g = FI_G(self)
        self.forgetting_curves = ForgettingCurves(self)
        self.rfm = RFM(self)
        self.ofm = OFM(self)

    def _find_index_to_insert(self, item, r=None):
        """Binary-search the queue position at which *item* keeps it sorted.

        *r* is the sequence of candidate queue indices (all of them by
        default). Items that share a due date are placed after the existing
        ones.
        """
        if r is None:
            r = range(len(self.q))
        if not r:
            return 0
        due = item.due_date
        lo, hi = 0, len(r)
        # Narrow [lo, hi) to a single candidate; the probed element stays in
        # the upper half when the new item is not strictly earlier.
        while hi - lo > 1:
            mid = lo + (hi - lo) // 2
            if due < self.q[r[mid]].due_date:
                hi = mid
            else:
                lo = mid
        slot = r[lo]
        if due < self.q[slot].due_date:
            return slot
        return slot + 1

    def add_item(self, value):
        """Create a new item for *value* and queue it by due date."""
        fresh = Item(self, value)
        self.q.insert(self._find_index_to_insert(fresh), fresh)

    def next_item(self, is_advanceable=False):
        """Return the head of the queue if it is due (or if advancing is allowed)."""
        if not self.q:
            return None
        now = datetime.datetime.now()
        head = self.q[0]
        if is_advanceable or head.due_date < now:
            return head
        return None

    def answer(self, grade, item, now=None):
        """Apply *grade* to *item* and re-queue it at its new due date."""
        if now is None:
            now = datetime.datetime.now()
        self._update(grade, item, now)
        self.discard(item)
        self.q.insert(self._find_index_to_insert(item), item)

    def _update(self, grade, item, now=None):
        """Feed a review into the shared models, then into the item itself."""
        if now is None:
            now = datetime.datetime.now()
        # Items that have never succeeded (repetition < 0) carry no curve data.
        if item.repetition >= 0:
            self.forgetting_curves.register_point(grade, item, now)
            self.ofm.update()
            self.fi_g.update(grade, item, now)
        item.answer(grade, now)

    def discard(self, item):
        """Remove *item* from the queue if it is present."""
        try:
            self.q.remove(item)
        except ValueError:
            pass

    def data(self):
        """Return the JSON-serialisable state of the scheduler."""
        queued = [entry.data() for entry in self.q]
        return {
            "requestedFI": self.requested_fi,
            "intervalBase": self.interval_base,
            "q": queued,
            "fi_g": self.fi_g.data(),
            "forgettingCurves": self.forgetting_curves.data(),
            "version": 1,
        }

    @classmethod
    def load(cls, data):
        """Rebuild a scheduler from the mapping produced by ``data()``."""
        sm = cls()
        sm.requested_fi = data.get("requestedFI", 10.0)
        sm.interval_base = data.get("intervalBase", 3 * 60 * 60 * 1000)
        sm.q = [Item.load(sm, raw) for raw in data.get("q", [])]
        sm.fi_g = FI_G.load(sm, data.get("fi_g", {}))
        sm.forgetting_curves = ForgettingCurves.load(
            sm, data.get("forgettingCurves", {})
        )
        # The derived matrices depend on the curves loaded just above.
        sm.rfm = RFM(sm)
        sm.ofm = OFM(sm)
        return sm
# ============================================================================
# Global state management
# ============================================================================
# Location of the persisted scheduler state, below the configured data directory.
_GLOBAL_STATE_FILE = pathlib.Path(
    config_var.get()["global"]["paths"]["data"],
    "global",
    "sm15m_global_state.json",
)
def _get_global_sm():
    """Load the shared ``SM`` state from disk, creating it when unavailable.

    If the state file is missing, or cannot be read or parsed, a fresh ``SM``
    is created and saved. A load failure is logged together with its
    traceback so that the reason for the reset can be diagnosed.

    Returns:
        The loaded or newly created ``SM`` instance.
    """
    if os.path.exists(_GLOBAL_STATE_FILE):
        try:
            with open(_GLOBAL_STATE_FILE, "r", encoding="utf-8") as f:
                return SM.load(json.load(f))
        except Exception:
            # Best effort: fall through to a fresh instance, but keep the cause.
            logger.warning("SM-15M 全局状态文件加载失败,创建新实例", exc_info=True)
    sm = SM()
    _save_global_sm(sm)
    return sm
def _save_global_sm(sm):
    """Persist the state of *sm* to the global state file (best effort).

    The state is serialised before any file is touched, written to a
    temporary sibling file, and then moved over the target with
    ``os.replace``. A failure during serialisation or writing therefore
    leaves the previous state file intact instead of truncating it.
    Any error is logged and swallowed.
    """
    try:
        payload = json.dumps(sm.data(), indent=2)
        _GLOBAL_STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = _GLOBAL_STATE_FILE.with_name(_GLOBAL_STATE_FILE.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, _GLOBAL_STATE_FILE)
    except Exception:
        logger.exception("SM-15M 全局状态保存失败")
# ============================================================================
# SM15MAlgorithm (HeurAMS interface)
# ============================================================================
class SM15MAlgorithm(BaseAlgorithm):
    """HeurAMS adapter that stores SM-15M item state inside ``algodata``.

    The per-card state lives under ``algodata[algo_name]``; the shared
    matrices live in the global ``SM`` state file.
    """

    algo_name = "SM-15M"
    desc = "基于 sm.js 的 SM-15 间隔重复算法"

    class AlgodataDict(TypedDict):
        # SM-15M specific
        lapse: int
        repetition: int
        of_val: float  # O-Factor
        optimum_interval_days: int  # optimum interval (days)
        afs: list  # A-Factor history
        af: float  # current A-Factor
        # Millisecond precision (sub-day scheduling)
        last_date_ms: int
        next_date_ms: int
        # BaseAlgorithm compatibility (day precision, backward compatible)
        real_rept: int
        rept: int
        interval: int
        last_date: int
        next_date: int
        is_activated: int
        last_modify: float

    defaults = {
        "lapse": 0,
        "repetition": -1,
        "of_val": 1.0,
        "optimum_interval_days": 0,
        "afs": [],
        "af": 0.0,
        "real_rept": 0,
        "rept": 0,
        "interval": 0,
        "last_date": 0,
        "next_date": 0,
        "is_activated": 0,
        # Millisecond-precision fields
        "last_date_ms": 0,
        "next_date_ms": 0,
        "last_modify": get_timestamp(),
    }

    @staticmethod
    def _datetime_from_ms(epoch_ms):
        """Convert epoch milliseconds into a naive local ``datetime``.

        This is the inverse of the ``naive_dt.timestamp() * 1000`` conversion
        used by ``_item_to_algodata``: ``timestamp()`` treats a naive
        datetime as local time, so the value must be restored as local time
        as well. Adding the milliseconds to ``datetime(1970, 1, 1)`` would
        yield UTC wall-clock time and skew every interval by the UTC offset.
        """
        return datetime.datetime.fromtimestamp(epoch_ms / 1000)

    @classmethod
    def _algodata_to_item(cls, algodata, sm):
        """Rebuild an ``Item`` bound to *sm* from the stored algodata."""
        data = algodata.get(cls.algo_name, cls.defaults.copy())
        item = Item(sm)
        item.repetition = data.get("repetition", -1)
        item.lapse = data.get("lapse", 0)
        item.of = data.get("of_val", 1.0)
        item._afs = list(data.get("afs", []))

        # Prefer the stored A-Factor; otherwise derive one from the history.
        af = data.get("af", 0.0)
        if af > 0:
            item.af(af)
        elif item._afs:
            item.af(sum(item._afs) / len(item._afs))

        # NOTE(review): the interval is stored in whole days, so sub-day
        # optimum intervals round to 0 and are restored as the base interval.
        opt_days = data.get("optimum_interval_days", 0)
        item.optimum_interval = (
            opt_days * 24 * 60 * 60 * 1000 if opt_days > 0 else sm.interval_base
        )

        if item.repetition < 0:
            # A new or lapsed item has no previous review in the Item model
            # (Item.answer clears previous_date on a lapse). The stored
            # last_date fields are not cleared on a lapse, so restoring them
            # here would inflate the elapsed interval of the next review.
            item.previous_date = None
        else:
            # Millisecond precision first, falling back to day precision.
            last_date_ms = data.get("last_date_ms", 0)
            if last_date_ms:
                item.previous_date = cls._datetime_from_ms(last_date_ms)
            else:
                last_date = data.get("last_date", 0)
                item.previous_date = (
                    daystamp_to_datetime(last_date).replace(tzinfo=None)
                    if last_date > 0
                    else None
                )

        next_date_ms = data.get("next_date_ms", 0)
        if next_date_ms:
            item.due_date = cls._datetime_from_ms(next_date_ms)
        else:
            next_date = data.get("next_date", 0)
            item.due_date = (
                daystamp_to_datetime(next_date).replace(tzinfo=None)
                if next_date > 0
                else datetime.datetime(1970, 1, 1)
            )
        item.value = {"_restored": True}
        return item

    @classmethod
    def _item_to_algodata(cls, item, algodata):
        """Write the state of *item* back into *algodata* and return it."""
        if cls.algo_name not in algodata:
            algodata[cls.algo_name] = cls.defaults.copy()
        data = algodata[cls.algo_name]
        data["lapse"] = item.lapse
        data["repetition"] = item.repetition
        data["of_val"] = item.of
        opt_ms = max(item.optimum_interval, 0)
        data["optimum_interval_days"] = round(opt_ms / (24 * 60 * 60 * 1000))
        data["afs"] = list(item._afs)
        data["af"] = item.af() if item._af is not None else 0.0
        # Millisecond precision
        if item.previous_date:
            data["last_date_ms"] = int(item.previous_date.timestamp() * 1000)
            data["last_date"] = datetime_to_daystamp(item.previous_date)
        data["next_date_ms"] = int(item.due_date.timestamp() * 1000)
        data["next_date"] = datetime_to_daystamp(item.due_date)
        data["interval"] = max(
            0, data["next_date"] - (data.get("last_date", 0) or 0)
        )
        data["last_modify"] = get_timestamp()
        return algodata

    @classmethod
    def revisor(
        cls, algodata: dict, feedback: int = 5, is_new_activation: bool = False
    ):
        """Apply one review result to *algodata* in place.

        Args:
            algodata: per-card algorithm data; updated in place.
            feedback: the grade of the review; ``-1`` means "no feedback"
                and leaves everything untouched.
            is_new_activation: when true the card state is reset before the
                grade is applied.
        """
        logger.debug(
            "SM-15M.revisor 开始, feedback=%d, is_new_activation=%s",
            feedback,
            is_new_activation,
        )
        if feedback == -1:
            return
        sm = _get_global_sm()
        item = cls._algodata_to_item(algodata, sm)
        if is_new_activation:
            item.repetition = -1
            item.lapse = 0
            item.of = 1.0
            item.optimum_interval = sm.interval_base
            item.previous_date = None
            item.due_date = datetime.datetime(1970, 1, 1)
            item._afs = []
            item._af = None
            item.af(2.5)
        sm.answer(feedback, item)
        # The item is a transient copy of the card state kept in algodata.
        # Leaving it in the scheduler queue would persist one more entry in
        # the global state file on every review.
        sm.discard(item)
        _save_global_sm(sm)
        cls._item_to_algodata(item, algodata)
        algodata[cls.algo_name]["real_rept"] += 1
        if feedback >= THRESHOLD_RECALL:
            algodata[cls.algo_name]["rept"] += 1
        logger.debug(
            "SM-15M.revisor 完成: repetition=%d, of=%.4f, next_date=%d",
            item.repetition,
            item.of,
            algodata[cls.algo_name]["next_date"],
        )

    @classmethod
    def is_due(cls, algodata):
        """Return whether the card is due for review."""
        data = algodata.get(cls.algo_name, cls.defaults.copy())
        # Millisecond precision first
        next_date_ms = data.get("next_date_ms", 0)
        if next_date_ms:
            # Read the clock once so the logged value is the compared value.
            now_ms = get_timestamp_ms()
            result = next_date_ms <= now_ms
            logger.debug(
                "SM-15M.is_due: next_date_ms=%d, now_ms=%d, result=%s",
                next_date_ms,
                now_ms,
                result,
            )
            return result
        # Fall back to day precision
        next_date = data.get("next_date", 0)
        current = get_daystamp()
        result = next_date <= current
        logger.debug(
            "SM-15M.is_due: next_date=%d, current=%d, result=%s",
            next_date,
            current,
            result,
        )
        return result

    @classmethod
    def get_rating(cls, algodata):
        """Return the stored A-Factor as a string."""
        data = algodata.get(cls.algo_name, cls.defaults.copy())
        af = data.get("af", 0.0)
        logger.debug("SM-15M.get_rating: af=%f", af)
        return str(af)

    @classmethod
    def nextdate(cls, algodata) -> int:
        """Return the stored day-precision next review date."""
        data = algodata.get(cls.algo_name, cls.defaults.copy())
        n = data.get("next_date", 0)
        logger.debug("SM-15M.nextdate: %d", n)
        return n