Add no-repeat-ngram repetition guard to the constrained decoder #504
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,52 @@ | ||||||||||
| import Foundation | ||||||||||
|
|
||||||||||
| /// File overview: | ||||||||||
| /// Pure no-repeat-ngram logic for the deterministic constrained decoder. Given the tokens generated | ||||||||||
| /// so far, it returns the token ids that must not be emitted next because doing so would repeat an | ||||||||||
| /// n-gram that already appeared in the output. | ||||||||||
| /// | ||||||||||
| /// Why this file exists: | ||||||||||
| /// The constrained decoder selects each token by raw-logit argmax. Greedy argmax has no inherent | ||||||||||
| /// resistance to repetition (the engine's `repetition_penalty` lives in its own sampler, which the | ||||||||||
| /// constrained path bypasses), so a base model can fall into a loop — "I think that I think that …" | ||||||||||
| /// or a single token emitted forever. A hard no-repeat-ngram block is the standard, deterministic | ||||||||||
| /// remedy: it forbids completing any n-gram that the output already contains. Keeping it pure makes | ||||||||||
| /// the rule exhaustively testable and keeps the decode loop a thin driver. | ||||||||||
| enum RepetitionGuard { | ||||||||||
| /// The token ids that would, if emitted next, repeat an `ngramSize`-gram already present in | ||||||||||
| /// `history`. A token `t` is blocked when the last `ngramSize - 1` tokens of `history` (the | ||||||||||
| /// pending prefix) already occur earlier in `history` immediately followed by `t`; emitting `t` | ||||||||||
| /// would reproduce that whole n-gram a second time. | ||||||||||
| /// | ||||||||||
| /// Returns an empty set when `ngramSize < 2` (a 1-gram block would forbid every token that ever | ||||||||||
| /// appeared, killing normal repetition like "the … the") or when `history` is too short to hold a | ||||||||||
| /// full prefix. Operates on token ids, not text, so it is independent of detokenization and works | ||||||||||
| /// the same for any vocabulary. | ||||||||||
| static func blockedTokens(history: [Int], ngramSize: Int) -> Set<Int> { | ||||||||||
| let prefixLength = ngramSize - 1 | ||||||||||
| guard ngramSize >= 2, history.count >= prefixLength else { | ||||||||||
| return [] | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // The pending prefix is the suffix of history that a next token would extend into an n-gram. | ||||||||||
| let prefix = Array(history.suffix(prefixLength)) | ||||||||||
|
|
||||||||||
| var blocked: Set<Int> = [] | ||||||||||
| // Every earlier position whose `prefixLength`-gram equals the pending prefix contributes the | ||||||||||
| // token that followed it: emitting that token now would repeat the n-gram. | ||||||||||
| var start = 0 | ||||||||||
| let lastPrefixStart = history.count - prefixLength | ||||||||||
| while start < lastPrefixStart { | ||||||||||
| var matches = true | ||||||||||
| for offset in 0 ..< prefixLength where history[start + offset] != prefix[offset] { | ||||||||||
| matches = false | ||||||||||
| break | ||||||||||
| } | ||||||||||
| if matches { | ||||||||||
| blocked.insert(history[start + prefixLength]) | ||||||||||
| } | ||||||||||
| start += 1 | ||||||||||
| } | ||||||||||
| return blocked | ||||||||||
| } | ||||||||||
| } | ||||||||||
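The file comment describes the decode loop as a thin driver around this rule. As a rough illustration of what that driver step might look like, the sketch below masks the blocked ids before a greedy argmax. `blockedFollowers` is a condensed restatement of the rule so the snippet compiles on its own, and `guardedArgmax` is a hypothetical name; neither is code from this PR, and the real decoder may wire the guard in differently.

```swift
/// Condensed restatement of the guard's rule (illustrative, not the PR's code):
/// collect every token that followed an earlier occurrence of the pending prefix.
func blockedFollowers(history: [Int], ngramSize: Int) -> Set<Int> {
    let prefixLength = ngramSize - 1
    guard ngramSize >= 2, history.count >= prefixLength else { return [] }
    let pending = history.suffix(prefixLength)
    var blocked: Set<Int> = []
    // Visit every earlier window that still has a follower token after it.
    for start in 0 ..< (history.count - prefixLength) {
        let window = history[start ..< start + prefixLength]
        if window.elementsEqual(pending) {
            blocked.insert(history[start + prefixLength])
        }
    }
    return blocked
}

/// Hypothetical driver step: greedy argmax over `logits`, skipping blocked ids.
/// Returns nil when every candidate is blocked or `logits` is empty.
/// Ties keep the lowest token id.
func guardedArgmax(logits: [Float], history: [Int], ngramSize: Int) -> Int? {
    let blocked = blockedFollowers(history: history, ngramSize: ngramSize)
    var best: Int? = nil
    for (tokenID, logit) in logits.enumerated() where !blocked.contains(tokenID) {
        if let current = best, logits[current] >= logit { continue }
        best = tokenID
    }
    return best
}

// Token 5 has the highest logit, but after [5, 5, 5] it would repeat the trigram,
// so the next-best admissible token is chosen instead.
let next = guardedArgmax(
    logits: [0.1, 0.2, 0.0, 0.0, 0.0, 9.0, 0.3],
    history: [5, 5, 5],
    ngramSize: 3
)
```

Returning `nil` rather than falling back to a blocked token leaves the "nothing left to emit" decision to the caller, which is where a stop reason would be recorded.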
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| import XCTest | ||
| @testable import Cotabby | ||
|
|
||
| /// Pure tests for the no-repeat-ngram block set. The guard operates on token ids only, so cases are | ||
| /// written as small id sequences with the expected blocked followers. | ||
| final class RepetitionGuardTests: XCTestCase { | ||
|
|
||
| func test_ngramSizeBelowTwo_blocksNothing() { | ||
| // A 1-gram block would forbid every token that ever appeared; the guard refuses that. | ||
| XCTAssertEqual(RepetitionGuard.blockedTokens(history: [1, 1, 2], ngramSize: 1), []) | ||
| XCTAssertEqual(RepetitionGuard.blockedTokens(history: [1, 1, 2], ngramSize: 0), []) | ||
| } | ||
|
|
||
| func test_historyShorterThanPrefix_blocksNothing() { | ||
| // n=3 needs a 2-token pending prefix; one token cannot form it. | ||
| XCTAssertEqual(RepetitionGuard.blockedTokens(history: [7], ngramSize: 3), []) | ||
| } | ||
|
|
||
| func test_noRepeatedPrefix_blocksNothing() { | ||
| XCTAssertEqual(RepetitionGuard.blockedTokens(history: [1, 2, 3], ngramSize: 3), []) | ||
| } | ||
|
|
||
| func test_repeatedPrefix_blocksItsFollower() { | ||
| // Pending prefix [1,2] occurred earlier at index 0, followed by 1, so emitting 1 would repeat | ||
| // the trigram [1,2,1]. Block 1. | ||
| XCTAssertEqual(RepetitionGuard.blockedTokens(history: [1, 2, 1, 2], ngramSize: 3), [1]) | ||
| } | ||
|
|
||
| func test_singleTokenRun_blocksAfterThreeWithTrigram() { | ||
| // Three identical tokens are allowed; the fourth would repeat the trigram [5,5,5]. | ||
| XCTAssertEqual(RepetitionGuard.blockedTokens(history: [5, 5], ngramSize: 3), []) | ||
| XCTAssertEqual(RepetitionGuard.blockedTokens(history: [5, 5, 5], ngramSize: 3), [5]) | ||
| } | ||
|
|
||
| func test_multipleFollowers_allBlocked() { | ||
| // [1,2] appears twice, followed by 9 then 8; both followers are blocked. | ||
| let blocked = RepetitionGuard.blockedTokens(history: [1, 2, 9, 1, 2, 8, 1, 2], ngramSize: 3) | ||
| XCTAssertEqual(blocked, [9, 8]) | ||
| } | ||
|
|
||
| func test_bigramOrder_blocksRepeatedBigram() { | ||
| // n=2: pending prefix is the last single token. [1] occurred at index 0 followed by 2, so | ||
| // emitting 2 would repeat the bigram [1,2]. | ||
| XCTAssertEqual(RepetitionGuard.blockedTokens(history: [1, 2, 1], ngramSize: 2), [2]) | ||
| } | ||
|
|
||
| func test_prefixPresentButNotPending_notBlocked() { | ||
| // [1,2] appears early but the pending prefix is [3,4]; nothing repeats, so nothing is blocked. | ||
| XCTAssertEqual(RepetitionGuard.blockedTokens(history: [1, 2, 9, 3, 4], ngramSize: 3), []) | ||
| } | ||
| } |
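The hand-written cases above could be complemented by a brute-force oracle that states the rule directly: a candidate is blocked exactly when appending it to the pending prefix produces an n-gram that already occurs in the history. The sketch below is one way to write such an oracle. `oracleBlockedTokens` and its `vocabularySize` parameter are hypothetical and not part of this PR; in a real test its result would be compared against `RepetitionGuard.blockedTokens` over many small histories.

```swift
/// Hypothetical test oracle (not part of the PR). It checks each candidate id in
/// 0 ..< vocabularySize independently, so it is slow but hard to get wrong.
/// Assumes `vocabularySize` is non-negative.
func oracleBlockedTokens(history: [Int], ngramSize: Int, vocabularySize: Int) -> Set<Int> {
    // A history shorter than n cannot contain any full n-gram to repeat.
    guard ngramSize >= 2, history.count >= ngramSize else { return [] }
    let pending = Array(history.suffix(ngramSize - 1))
    var blocked: Set<Int> = []
    for candidate in 0 ..< vocabularySize {
        let ngram = pending + [candidate]
        // Compare against every full n-gram window of the existing history.
        for start in 0 ... (history.count - ngramSize) {
            if Array(history[start ..< start + ngramSize]) == ngram {
                blocked.insert(candidate)
                break
            }
        }
    }
    return blocked
}

// The same inputs as three of the cases above.
let repeatedPrefix = oracleBlockedTokens(history: [1, 2, 1, 2], ngramSize: 3, vocabularySize: 10)
let twoFollowers = oracleBlockedTokens(history: [1, 2, 9, 1, 2, 8, 1, 2], ngramSize: 3, vocabularySize: 10)
let bigram = oracleBlockedTokens(history: [1, 2, 1], ngramSize: 2, vocabularySize: 10)
```

Because the oracle enumerates candidates instead of scanning for followers, agreement between the two implementations is evidence that the sliding-window bounds in the guard are right.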
`stopReason` when repetition guard exhausts all candidates

`"no_admissible_token"` is already emitted when the byte-prefix constraint returns an empty admissible set; now the same string is logged when the repetition guard blocks every surviving candidate. A post-hoc log search won't distinguish between the two cases. Consider a distinct value such as `"repetition_guard_exhausted"` so decode diagnostics can tell a structural constraint failure apart from a repetition block.
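A minimal sketch of how the two cases might be kept apart, assuming the decoder has both the admissible set and the guard's blocked set in hand at the point where it gives up. Only the two raw strings come from this comment; the `StopReason` type, its case names, and the `stopReason(admissible:blocked:)` helper are hypothetical.

```swift
/// Hypothetical stop-reason type; only the raw strings are taken from the review comment.
enum StopReason: String {
    case noAdmissibleToken = "no_admissible_token"
    case repetitionGuardExhausted = "repetition_guard_exhausted"
}

/// Picks the diagnostic for an empty candidate set.
/// `admissible` is what the byte-prefix constraint allowed; `blocked` is the guard's output.
/// Returns nil when at least one candidate survives both filters.
func stopReason(admissible: Set<Int>, blocked: Set<Int>) -> StopReason? {
    // The constraint itself left nothing: a structural failure.
    if admissible.isEmpty { return .noAdmissibleToken }
    // Candidates existed, but the guard removed all of them.
    if admissible.isSubset(of: blocked) { return .repetitionGuardExhausted }
    return nil
}

let structural = stopReason(admissible: [], blocked: [1])
let repetition = stopReason(admissible: [1, 2], blocked: [1, 2, 3])
let stillDecoding = stopReason(admissible: [1, 2], blocked: [1])
```

Checking the empty admissible set first keeps the existing `"no_admissible_token"` meaning unchanged, so current log searches would continue to match the structural case only.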