This approach also seems to work.
I think the state should also record where placement should continue within the group of same-type pieces currently being placed. I'm not yet sure how the Python code that ChatGPT generated (below) deals with this:
Code: Select all
from __future__ import annotations

from dataclasses import dataclass, field
from functools import lru_cache
from itertools import combinations
from random import randrange
from typing import Callable, Iterable, Sequence
# A set of board squares encoded as an int: bit s is set iff square s is in the set.
Mask = int
# A permutation of range(num_squares): perm[s] is the image of square s.
Perm = tuple[int, ...]
def bits(mask: Mask) -> list[int]:
    """Return the indices of the set bits of mask, in increasing order.

    Raises:
        ValueError: if mask is negative. Python ints behave like infinite
            two's complement, so a negative value has infinitely many set
            bits and the strip-lowest-bit loop would never terminate.
    """
    if mask < 0:
        raise ValueError(f"mask must be non-negative, got {mask}")
    out: list[int] = []
    while mask:
        lsb = mask & -mask  # isolate the lowest set bit
        out.append(lsb.bit_length() - 1)
        mask ^= lsb  # clear it
    return out
def bit_count(mask: Mask) -> int:
    """Return the number of set bits in mask (its population count).

    int.bit_count() only exists on Python 3.10 and newer. On older
    interpreters fall back to counting the '1' digits of bin(mask). That
    gives the same answer, including for negative ints, where both count
    the bits of abs(mask).
    """
    native = getattr(mask, "bit_count", None)
    if native is not None:
        return native()
    return bin(mask).count("1")
def mask_from_squares(squares: Iterable[int]) -> Mask:
    """Return the mask with bit s set for every square s in squares.

    Duplicate squares are harmless, since setting a bit twice is a no-op.
    An empty iterable yields the empty mask 0.
    """
    result: Mask = 0
    for bit in (1 << square for square in squares):
        result |= bit
    return result
def mask_key(mask: Mask) -> tuple[int, ...]:
    """Sort key that compares masks by their increasing lists of squares.

    This is lexicographic order on square lists, which differs from integer
    order. For example, {0, 5} (0b100001 == 33) sorts before {1} (0b10 == 2).
    """
    squares = bits(mask)
    return tuple(squares)
@dataclass(frozen=True)
class FiniteBoardAction:
    """A finite board together with a finite permutation group acting on squares.

    The group is represented explicitly as permutations of range(num_squares).
    perms[0] must be the identity. Closure under composition is NOT checked;
    the caller is responsible for passing a genuine group. No multiplication
    table is needed: (sub)groups are represented as tuples of indices into
    self.perms.

    For speed, the constructor precomputes the image of each singleton bit
    under each permutation. Applying a permutation to a mask is then only a
    loop over the set bits, OR'ing the precomputed singleton images.

    Several methods are memoized with functools.lru_cache, which hashes self
    on every call. The dataclass-generated __hash__ would re-hash every entry
    of perms each time, so the hash is computed once in __post_init__ instead.
    These caches live on the class and keep every instance alive for the
    lifetime of the process.
    """

    num_squares: int
    perms: tuple[Perm, ...]
    # bit_image[g][s] == 1 << perms[g][s]; filled in by __post_init__.
    bit_image: tuple[tuple[int, ...], ...] = field(init=False, repr=False, compare=False)
    # hash((num_squares, perms)), computed once in __post_init__.
    _hash: int = field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Validate the permutations and build the per-permutation bit tables.

        Raises:
            ValueError: if perms is empty, any entry is not a permutation of
                range(num_squares), or perms[0] is not the identity.
        """
        if not self.perms:
            raise ValueError("at least the identity permutation is required")
        # Accept any sequence of sequences, but store tuples so that the
        # instance is hashable; the lru_cache'd methods below hash self.
        perms = tuple(tuple(p) for p in self.perms)
        n = self.num_squares
        identity = tuple(range(n))
        for p in perms:
            if len(p) != n or sorted(p) != list(identity):
                raise ValueError("invalid permutation")
        if perms[0] != identity:
            raise ValueError("perms[0] must be the identity")
        object.__setattr__(self, "perms", perms)
        object.__setattr__(
            self, "bit_image", tuple(tuple(1 << t for t in p) for p in perms)
        )
        object.__setattr__(self, "_hash", hash((n, perms)))

    def __hash__(self) -> int:
        # Consistent with the generated __eq__, which compares exactly
        # (num_squares, perms).
        return self._hash

    @property
    def full_group(self) -> tuple[int, ...]:
        """Indices of all group elements, i.e. the whole group as a 'subgroup'."""
        return tuple(range(len(self.perms)))

    @lru_cache(maxsize=None)
    def apply_mask(self, perm_index: int, mask: Mask) -> Mask:
        """Return the image of the square set mask under perms[perm_index]."""
        img = self.bit_image[perm_index]
        out = 0
        x = mask
        while x:
            lsb = x & -x
            out |= img[lsb.bit_length() - 1]
            x ^= lsb
        return out

    @lru_cache(maxsize=None)
    def mask_key_cached(self, mask: Mask) -> tuple[int, ...]:
        """Memoized mask_key(mask)."""
        return mask_key(mask)

    @lru_cache(maxsize=None)
    def canonical_mask(self, mask: Mask, subgroup: tuple[int, ...]) -> Mask:
        """Return the image of mask under subgroup that is smallest by mask_key."""
        return min((self.apply_mask(g, mask) for g in subgroup), key=self.mask_key_cached)

    @lru_cache(maxsize=None)
    def stabilizer_of_mask(self, mask: Mask, subgroup: tuple[int, ...]) -> tuple[int, ...]:
        """Return the elements of subgroup (in order) that map mask onto itself."""
        return tuple(g for g in subgroup if self.apply_mask(g, mask) == mask)

    def transform_position(self, masks: tuple[Mask, ...], g: int) -> tuple[Mask, ...]:
        """Apply group element g to every mask of a position."""
        return tuple(self.apply_mask(g, m) for m in masks)

    def canonical_position(self, masks: tuple[Mask, ...]) -> tuple[Mask, ...]:
        """Return the image of the position under the full group that is smallest
        when compared group by group with mask_key."""
        return min((self.transform_position(masks, g) for g in self.full_group),
                   key=lambda ms: tuple(self.mask_key_cached(m) for m in ms))
def chess_d4_action(size: int = 8) -> FiniteBoardAction:
    """Return the dihedral group D4 acting on a size x size board (default 8x8).

    Square numbering is row-major: square = size * rank + file, with rank and
    file in 0..size-1. The particular orientation does not matter, as long as
    the positions passed to a ranker use the same numbering. The identity is
    listed first, as FiniteBoardAction requires.

    Boards smaller than 8x8 are handy for exhaustive cross-checks.

    Raises:
        ValueError: if size < 1.
    """
    if size < 1:
        raise ValueError("board size must be positive")
    last = size - 1

    def sq(f: int, r: int) -> int:
        return size * r + f

    transforms: list[Callable[[int, int], tuple[int, int]]] = [
        lambda f, r: (f, r),                # identity
        lambda f, r: (last - r, f),         # rotate 90
        lambda f, r: (last - f, last - r),  # rotate 180
        lambda f, r: (r, last - f),         # rotate 270
        lambda f, r: (last - f, r),         # mirror vertical axis
        lambda f, r: (f, last - r),         # mirror horizontal axis
        lambda f, r: (r, f),                # mirror main diagonal
        lambda f, r: (last - r, last - f),  # mirror anti-diagonal
    ]
    perms: list[Perm] = []
    for tr in transforms:
        p = [0] * (size * size)
        for r in range(size):
            for f in range(size):
                nf, nr = tr(f, r)
                p[sq(f, r)] = sq(nf, nr)
        perms.append(tuple(p))
    return FiniteBoardAction(size * size, tuple(perms))
@dataclass(eq=False)
class SymmetryRanker:
    """Exact symmetry-reduced ranker/unranker for fixed material.

    material is an ordered sequence of multiplicities, one entry per piece type.
    Example: [1, 1, 2, 1] could mean WK, BK, two white rooks, one black bishop.
    Kings are not special. Pawns are not special either; this class knows only
    anonymous square-occupying types. Legality constraints such as adjacent
    kings, check, side to move, castling, en passant, etc. are intentionally
    outside the model.

    The algorithm ranks orbits of positions under action.full_group.
    Piece groups are placed one after another. Each group is placed as a whole
    (one count-subset of the free squares), so the recursive state is just
    (group_index, occupied, subgroup), where subgroup is the stabilizer of the
    groups already placed.

    A position is represented as tuple[Mask, ...], one mask per material group.
    The masks must be pairwise disjoint and have the specified popcounts.

    The recursive tables are memoized with functools.lru_cache on methods.
    Those caches are class-level and keep a reference to every ranker used.
    """

    action: FiniteBoardAction
    material: tuple[int, ...]

    def __init__(self, action: FiniteBoardAction, material: Sequence[int]):
        self.action = action
        self.material = tuple(material)
        if any(m < 0 for m in self.material):
            raise ValueError("negative multiplicity")
        if sum(self.material) > action.num_squares:
            raise ValueError("too many pieces for board")
        # Mask with every square of the board set.
        self.all_squares_mask = (1 << action.num_squares) - 1

    def validate(self, masks: tuple[Mask, ...]) -> None:
        """Check that masks is a well-formed position for this material.

        Raises:
            ValueError: if the number of groups is wrong, a group has squares
                outside the board or the wrong popcount, or two groups overlap.
        """
        if len(masks) != len(self.material):
            raise ValueError("wrong number of piece groups")
        occ = 0
        for i, (m, need) in enumerate(zip(masks, self.material)):
            if m & ~self.all_squares_mask:
                raise ValueError(f"piece group {i} has squares outside the board")
            if bit_count(m) != need:
                raise ValueError(f"piece group {i} has popcount {bit_count(m)}, expected {need}")
            if occ & m:
                raise ValueError("overlapping pieces")
            occ |= m

    @lru_cache(maxsize=None)
    def _canonical_choices(self, subgroup: tuple[int, ...], available: Mask, count: int) -> tuple[Mask, ...]:
        """Canonical representatives of count-subsets of available squares under subgroup.

        This is the exact but simple part of the implementation. It enumerates
        combinations of currently available squares and keeps one representative
        per subgroup-orbit. For high multiplicities this should be replaced by
        the orbit-inventory DP discussed in the design notes.
        """
        return tuple(row[0] for row in self._choice_rows(subgroup, available, count))

    @lru_cache(maxsize=None)
    def _choice_rows(self, subgroup: tuple[int, ...], available: Mask, count: int) -> tuple[tuple[Mask, tuple[int, ...]], ...]:
        """Precomputed rows (choice, stabilizer) for one recursive state.

        Each choice is the mask_key-smallest member of one subgroup-orbit of
        count-subsets of available; rows are sorted by mask_key. rank(),
        unrank(), and count() all need the stabilizer of each admissible
        canonical choice, so it is stored alongside.
        """
        if count == 0:
            return ((0, subgroup),)
        sqs = bits(available)
        seen: set[Mask] = set()
        reps: list[Mask] = []
        for comb in combinations(sqs, count):
            m = mask_from_squares(comb)
            if m in seen:
                continue  # already covered by an earlier combination's orbit
            orbit = tuple(self.action.apply_mask(g, m) for g in subgroup)
            seen.update(orbit)
            rep = min(orbit, key=self.action.mask_key_cached)
            reps.append(rep)
        reps.sort(key=self.action.mask_key_cached)
        return tuple((rep, self.action.stabilizer_of_mask(rep, subgroup)) for rep in reps)

    def precompute(self) -> int:
        """Force construction of the reachable count/choice tables.

        This is useful before timing rank()/unrank(). It returns the total number
        of symmetry-distinct placements, i.e. the same value as count().
        """
        return self.count()

    @lru_cache(maxsize=None)
    def count(self, group_index: int = 0, occupied: Mask = 0,
              subgroup: tuple[int, ...] | None = None) -> int:
        """Number of symmetry-distinct completions from the given recursive state.

        subgroup=None means the full group (the initial state).
        """
        if subgroup is None:
            subgroup = self.action.full_group
        if group_index == len(self.material):
            return 1
        need = self.material[group_index]
        available = self.all_squares_mask & ~occupied
        total = 0
        for choice, new_subgroup in self._choice_rows(subgroup, available, need):
            total += self.count(group_index + 1, occupied | choice, new_subgroup)
        return total

    def canonicalize(self, masks: tuple[Mask, ...]) -> tuple[Mask, ...]:
        """Canonicalize using the same recursive rule as rank()/unrank().

        A globally lexicographic representative of the full G-orbit need not be
        canonical at every recursive stabilizer step. Therefore canonicalization
        must be recursive, not just min(g * position) over the full group.

        Raises:
            ValueError: if masks is not a valid position (see validate()).
        """
        self.validate(masks)
        return self._canonicalize_from(0, 0, self.action.full_group, masks)

    def _canonicalize_from(self, group_index: int, occupied: Mask,
                           subgroup: tuple[int, ...], masks: tuple[Mask, ...]) -> tuple[Mask, ...]:
        """Canonical form of masks[group_index:] in the given recursive state.

        The canonical choice for the current group is the mask_key-smallest
        image of its mask under subgroup. That is exactly the orbit
        representative _choice_rows() lists for this state, so at most one
        choice row can ever match and no comparison between candidates is
        needed. Any element g mapping the mask to that representative is
        applied to the later groups. Two such elements differ by an element
        of the representative's stabilizer, which is the subgroup used for
        the rest of the recursion, so the result does not depend on which g
        is picked.
        """
        if group_index == len(self.material):
            return ()
        act = self.action
        current = masks[group_index]
        choice = act.canonical_mask(current, subgroup)
        new_subgroup = act.stabilizer_of_mask(choice, subgroup)
        g = next(h for h in subgroup if act.apply_mask(h, current) == choice)
        later = tuple(act.apply_mask(g, m) for m in masks[group_index + 1:])
        transformed = tuple(masks[:group_index + 1]) + later
        tail = self._canonicalize_from(group_index + 1, occupied | choice,
                                       new_subgroup, transformed)
        return (choice,) + tail

    def rank(self, masks: tuple[Mask, ...]) -> int:
        """Rank the symmetry orbit containing masks.

        Non-canonical input is accepted: it is first canonicalized according to
        the same recursive canonical path used by unrank(). The returned rank is
        in 0..count()-1.

        Raises:
            ValueError: if masks is not a valid position.
        """
        masks = self.canonicalize(masks)
        rank = 0
        occupied = 0
        subgroup = self.action.full_group
        for gi, current in enumerate(masks):
            need = self.material[gi]
            available = self.all_squares_mask & ~occupied
            if current & ~available or bit_count(current) != need:
                raise ValueError("invalid canonical position")
            for choice, new_subgroup in self._choice_rows(subgroup, available, need):
                if choice == current:
                    occupied |= choice
                    subgroup = new_subgroup
                    break
                # Every completion starting with an earlier choice ranks lower.
                rank += self.count(gi + 1, occupied | choice, new_subgroup)
            else:
                raise ValueError("position is not represented by this ranker state")
        return rank

    def unrank(self, index: int) -> tuple[Mask, ...]:
        """Return the canonical representative of the position-orbit with this rank.

        Raises:
            ValueError: if index is outside 0..count()-1.
        """
        total = self.count()
        if not (0 <= index < total):
            raise ValueError(f"index {index} outside range 0..{total - 1}")
        masks: list[Mask] = []
        occupied = 0
        subgroup = self.action.full_group
        for gi, need in enumerate(self.material):
            available = self.all_squares_mask & ~occupied
            for choice, new_subgroup in self._choice_rows(subgroup, available, need):
                block = self.count(gi + 1, occupied | choice, new_subgroup)
                if index >= block:
                    index -= block
                    continue
                masks.append(choice)
                occupied |= choice
                subgroup = new_subgroup
                break
            else:
                raise AssertionError("unrank failed to find a block")
        return tuple(masks)

    def squares_of(self, masks: tuple[Mask, ...]) -> list[list[int]]:
        """Return each group's occupied squares as an increasing list."""
        return [bits(m) for m in masks]
def demo() -> None:
    """Self-check SymmetryRanker on the 8x8 board with D4 symmetry.

    Prints the number of symmetry-distinct placements for material [1, 1, 2].
    Ranks one example position and checks that all of its symmetric images
    get the same rank. Then round-trips 10000 random ranks through unrank()
    and rank(). The checks use assert, so they are skipped under `python -O`.
    """
    action = chess_d4_action()
    # Example material: three anonymous piece types with multiplicities 1, 1, 2.
    # This could be WK, BK, two white knights, but the ranker treats all of them
    # merely as typed pieces occupying squares.
    ranker = SymmetryRanker(action, [1, 1, 2])
    n = ranker.precompute()
    print("number of symmetry-distinct placements:", n)
    position = (
        mask_from_squares([0]),       # type 0
        mask_from_squares([63]),      # type 1
        mask_from_squares([10, 20]),  # type 2, two identical pieces
    )
    r = ranker.rank(position)
    q = ranker.unrank(r)
    print("rank:", r)
    print("canonical squares:", ranker.squares_of(q))
    assert q == ranker.canonicalize(position)
    assert ranker.rank(q) == r
    # A rank identifies a whole orbit, so every symmetric image of the
    # position must receive the same rank.
    for g in action.full_group:
        assert ranker.rank(action.transform_position(position, g)) == r
    for _ in range(10000):
        idx = randrange(n)
        rep = ranker.unrank(idx)
        assert ranker.rank(rep) == idx
        # unrank() must return the canonical representative of its orbit.
        assert ranker.canonicalize(rep) == rep
    print("OK")


if __name__ == "__main__":
    demo()
It does not precompute the full tables, although asking it to add precomputation did speed things up.
I added the round-trip test on 10000 random positions; it appears to work.