feat(config): add RSAConfig + named profiles (default-16gb, paper-16k, paper-headline-40k)

This commit is contained in:
transcrilive
2026-05-10 02:36:51 +02:00
parent e2237f788c
commit 08ae956986
2 changed files with 131 additions and 0 deletions

View File

@@ -0,0 +1,71 @@
"""RSAConfig dataclass + named profiles. Zero MLX dependencies — purely declarative."""
from __future__ import annotations
from dataclasses import dataclass, replace
from typing import Literal
@dataclass(frozen=True)
class RSAConfig:
"""Markovian RSA hyperparameters.
Defaults target the 16 GB practical profile (N=4, T=2, chunk=16K),
not the paper headline (N=16, chunk=40K). See named constructors for
paper-aligned profiles.
"""
rounds: int = 2
parallel: int = 4
aggregation_subsample: int = 4
chunk_tokens: int = 16384
tail_tokens: int = 4096
final_tokens: int | None = None
temperature: float = 1.0
top_p: float = 0.95
top_k: int = -1
seed: int | None = None
serial: bool = False
auto_serial: bool = True
memory_fraction: float = 0.80
max_context_tokens: int | None = None
aggregation_template: Literal["zaya_v1"] = "zaya_v1"
answer_selection: Literal["first_final_candidate"] = "first_final_candidate"
def __post_init__(self) -> None:
if self.chunk_tokens <= 0:
raise ValueError("chunk_tokens must be positive")
if self.tail_tokens <= 0:
raise ValueError("tail_tokens must be positive")
if self.parallel < 1:
raise ValueError("parallel must be >= 1")
if self.rounds < 1:
raise ValueError("rounds must be >= 1")
if self.aggregation_subsample > self.parallel:
raise ValueError(
f"aggregation_subsample ({self.aggregation_subsample}) "
f"must be <= parallel ({self.parallel})"
)
if not 0.0 < self.memory_fraction <= 1.0:
raise ValueError("memory_fraction must be in (0, 1]")
def effective_final_tokens(self) -> int:
return self.final_tokens if self.final_tokens is not None else self.chunk_tokens
@classmethod
def default_16gb(cls) -> "RSAConfig":
return cls(parallel=2, chunk_tokens=16384, aggregation_subsample=2)
@classmethod
def paper_16k(cls) -> "RSAConfig":
return cls(parallel=4, chunk_tokens=16384, top_p=1.0)
@classmethod
def paper_headline_40k(cls) -> "RSAConfig":
return cls(
parallel=16,
aggregation_subsample=4,
chunk_tokens=40960,
final_tokens=40960,
top_p=1.0,
)
def replace(self, **kwargs) -> "RSAConfig":
return replace(self, **kwargs)