ids2Labels,
final SentenceDetector sentenceDetector) {
+ Objects.requireNonNull(inferenceOptions, "inferenceOptions");
Objects.requireNonNull(ids2Labels, "ids2Labels");
Objects.requireNonNull(sentenceDetector, "sentenceDetector");
return inferenceOptions;
@@ -168,9 +178,16 @@ private static InferenceOptions validateConstructorArguments(
/**
* {@inheritDoc}
*
- * This method joins the provided tokens with spaces, sentence-splits the joined text,
- * runs each sentence through the ONNX token-classification model, decodes BIO labels into
- * {@link Span spans}, and resolves those spans back to character offsets in the joined text.
+ * Joins the provided tokens with spaces, sentence-splits the joined text, runs each sentence
+ * through the ONNX token-classification model, decodes BIO labels into {@link Span spans}, and
+ * resolves those spans to character offsets in the joined text after any optional input
+ * normalization.
+ *
+ * Note: this returns correct original offsets in every case except one. Whitespace folding is
+ * length-preserving, so it never moves offsets. Only dash folding can change the input length, and
+ * only for a non-BMP dash; so when {@code normalizeDashes} is enabled and the input contains a
+ * supplementary-plane dash, the returned spans are offsets into the normalized text rather than
+ * the original. For an exact original mapping in that case, use {@link #findInOriginal(String[])}.
*
* @throws IllegalStateException Thrown if inference fails, if the model output shape is not
* the expected {@code float[batch][token][label]} form, if the model output contains
@@ -178,44 +195,131 @@ private static InferenceOptions validateConstructorArguments(
* present in the configured label map.
* @throws IllegalArgumentException Thrown if a token produced for the input is not present in
* the vocabulary, which indicates the vocabulary file does not match the model.
+ * @throws NullPointerException Thrown if {@code input} is {@code null}.
*/
@Override
public Span[] find(String[] input) {
+ return locate(input).spans().toArray(new Span[0]);
+ }
- final List spans = new ArrayList<>();
+ /**
+ * Finds names and returns their {@link Span spans} in coordinates of the original joined input
+ * ({@code String.join(" ", input)}), regardless of any whitespace or dash normalization applied
+ * before inference. Spans are mapped back through the normalization {@link Alignment}, so a fold
+ * that changes the input length (a supplementary dash shrinking, or an expansion) does not shift
+ * the reported offsets. This implements {@link OffsetMappingNameFinder}, so an interface-typed
+ * caller can reach the offset-correct path with
+ * {@code finder instanceof OffsetMappingNameFinder}.
+ *
+ * @param input The tokens to search.
+ * @return The detected spans, in original-input character coordinates.
+ * @throws IllegalStateException Thrown under the same conditions as {@link #find(String[])}.
+ * @throws IllegalArgumentException Thrown under the same conditions as {@link #find(String[])}.
+ */
+ @Override
+ public Span[] findInOriginal(String[] input) {
+ final DecodedSpans decoded = locate(input);
+ final Alignment alignment = decoded.aligned().alignment();
+ final List mapped = new ArrayList<>(decoded.spans().size());
+ for (final Span span : decoded.spans()) {
+ final Span original = alignment.toOriginalSpan(span.getStart(), span.getEnd());
+ mapped.add(new Span(original.getStart(), original.getEnd(), span.getType(), span.getProb()));
+ }
+ return mapped.toArray(new Span[0]);
+ }
+
+ // Shared core: normalize the joined input (capturing the alignment back to the original), then
+ // decode each overlapping chunk bounded to its own character region and resolve overlaps. Bounding
+ // per chunk lets a boundary entity that two consecutive chunks both cover surface as overlapping
+ // candidates, which mergeOverlappingSpans collapses to the longer (more complete) span instead of
+ // silently keeping whichever a single forward cursor reached first.
+ private DecodedSpans locate(String[] input) {
+
+ Objects.requireNonNull(input, "input");
// Join the tokens here because they will be tokenized using Wordpiece during inference.
- final String text = String.join(" ", input);
+ final AlignedText normalized =
+ normalizeInputAligned(String.join(" ", input), normalizeWhitespace, normalizeDashes);
+ final String text = normalized.normalizedString();
// sentPosDetect (not sentDetect) so each sentence's offset in the full text is known.
final Span[] sentenceSpans = sentenceDetector.sentPosDetect(text);
+ final List candidates = new ArrayList<>();
for (final Span sentenceSpan : sentenceSpans) {
- // Floor the character cursor at this sentence's start, then thread it forward across the
- // sentence's chunks so a repeated surface form is located at its next occurrence. Flooring
- // per sentence keeps an entity from being matched against an identical surface form in an
- // earlier sentence -- even one that produced no spans, which would otherwise leave the
- // cursor behind and mis-locate the match.
- int searchStart = sentenceSpan.getStart();
-
- // The WordPiece tokenized text. This changes the spacing in the text.
- final List wordpieceTokens = tokenize(sentenceSpan.getCoveredText(text).toString());
-
- for (final Tokens tokens : wordpieceTokens) {
- final List decoded =
- decodeSpans(text, tokens.tokens(), infer(tokens), ids2Labels, searchStart,
- sentenceSpan.getEnd());
- spans.addAll(decoded);
- if (!decoded.isEmpty()) {
- searchStart = decoded.get(decoded.size() - 1).getEnd();
- }
+ final int sentenceStart = sentenceSpan.getStart();
+ final String sentence = sentenceSpan.getCoveredText(text).toString();
+
+ // The WordPiece tokenized text, in overlapping chunks. This changes the spacing in the text.
+ for (final ChunkTokens chunk : tokenize(sentence)) {
+ // Decode within the chunk's own character region in the full text. Keeping each chunk's
+ // entities inside the region it was built from locates a repeated surface form in the right
+ // chunk rather than mis-matching it to an earlier occurrence, while still letting two
+ // overlapping chunks both emit a boundary entity for mergeOverlappingSpans to reconcile.
+ final int regionStart = sentenceStart + chunk.start();
+ final int regionEnd = sentenceStart + chunk.end();
+ candidates.addAll(decodeSpans(text, chunk.tokens().tokens(), infer(chunk.tokens()),
+ ids2Labels, regionStart, regionEnd));
}
}
- return spans.toArray(new Span[0]);
+ return new DecodedSpans(mergeOverlappingSpans(candidates), normalized);
+ }
+
+ private record DecodedSpans(List spans, AlignedText aligned) {
+ }
+ // A chunk's WordPiece tokens paired with the chunk's half-open character span in the full text.
+ private record ChunkTokens(Tokens tokens, int start, int end) {
+ }
+
+ // Ordering for overlap resolution: longest span first, then higher probability. The dominant
+ // detection is kept and any later span overlapping it is dropped.
+ private static final Comparator BY_LENGTH_THEN_PROBABILITY =
+ Comparator.comparingInt(Span::length).reversed()
+ .thenComparing(Comparator.comparingDouble(Span::getProb).reversed());
+
+ /**
+ * Resolves spans that overlap in character coordinates, as happens when an entity falls in the
+ * shared region of two consecutive overlapping chunks and is decoded by both. The longer span is
+ * kept (the more complete decode) and ties break toward the higher probability; any span that
+ * overlaps an already kept one is dropped. Adjacent but disjoint spans are never merged, so
+ * neighbouring distinct entities and repeated surface forms at different offsets are preserved.
+ * The returned list is in document order. The choice is length-dominant rather than type-aware:
+ * when two overlapping spans carry different entity types, the longer still wins regardless of
+ * type, which is the intended heuristic for the rare cross-type overlap at a chunk boundary.
+ *
+ * @param spans The decoded candidate spans, in the order they were produced.
+ * @return The overlap-free spans, ordered by start offset.
+ */
+ static List mergeOverlappingSpans(final List spans) {
+ if (spans.size() < 2) {
+ // Return a fresh list so the caller always owns the result, matching the >= 2 path below
+ // (which returns a new list); the input is never handed back aliased.
+ return new ArrayList<>(spans);
+ }
+ final List byDominance = new ArrayList<>(spans);
+ byDominance.sort(BY_LENGTH_THEN_PROBABILITY);
+ // Kept spans never overlap each other, so they form a start-sorted partition. A candidate can
+ // only intersect the kept span that starts at or just before it (floor) or the next kept span
+ // that starts within it (ceiling); checking those two is O(log n) instead of scanning every kept
+ // span, making the whole longest-wins pass O(n log n) rather than O(n^2). Keyed by start, the map
+ // also yields the result already in document order.
+ final TreeMap kept = new TreeMap<>();
+ for (final Span candidate : byDominance) {
+ final Map.Entry before = kept.floorEntry(candidate.getStart());
+ if (before != null && before.getValue().getEnd() > candidate.getStart()) {
+ continue;
+ }
+ final Map.Entry after = kept.ceilingEntry(candidate.getStart());
+ if (after != null && after.getKey() < candidate.getEnd()) {
+ continue;
+ }
+ kept.put(candidate.getStart(), candidate);
+ }
+ return new ArrayList<>(kept.values());
}
/**
@@ -356,7 +460,7 @@ static List decodeSpans(String text, String[] tokens, float[][] tokenLabel
continue;
}
- final SpanMatch match = findByRegex(text, spanText, characterStart, searchEnd);
+ final SpanMatch match = findInSource(text, spanText, characterStart, searchEnd);
if (match.start() != -1) {
spans.add(new Span(match.start(), match.end(), entityType, entity.probability()));
characterStart = match.end();
@@ -487,9 +591,9 @@ static double labelProbability(float[] scores, int labelIndex) {
/**
* Reconstructs source-like text from a span of WordPiece tokens.
*
- * Special BERT tokens are skipped, {@code ##} continuations are merged into the preceding
- * surface form, and simple punctuation spacing is normalized so the result can be located in
- * the caller's original text.
+ * Special BERT and RoBERTa tokens are skipped, {@code ##} continuations are merged into the
+ * preceding surface form, and simple punctuation spacing is normalized so the result can be
+ * located in the caller's original text.
*
* @param tokens The WordPiece token sequence.
* @param startIndex The first token index to include.
@@ -503,7 +607,7 @@ static String buildSpanText(String[] tokens, int startIndex, int endIndex) {
for (int x = startIndex; x <= endIndex && x < tokens.length; x++) {
final String token = tokens[x];
- if (CLS_TOKEN.equals(token) || SEPARATOR.equals(token)) {
+ if (SPECIAL_TOKENS.contains(token)) {
continue;
}
@@ -567,35 +671,82 @@ private static int maxIndex(float[] arr) {
/**
* Locates reconstructed span text in a bounded region of the original input text.
*
+ * Matching is a single forward cursor scan, not a regular expression. Each space in the
+ * reconstructed span matches a run of zero or more Unicode whitespace characters in the source
+ * (so an entity whose WordPiece pieces were rejoined with spaces, such as {@code "AT & T"} for
+ * {@code "AT&T"}, is still located), and every other code point matches case-insensitively.
+ * Using a cursor avoids {@link java.util.regex.Pattern}/{@link java.util.regex.Matcher}
+ * allocation and the ReDoS surface of regular expressions, and recognizes Unicode whitespace
+ * that Java's {@code \s} does not.
+ *
* @param text The original text.
- * @param span The reconstructed span text.
+ * @param span The reconstructed span text, with sub-tokens separated by single ASCII spaces.
* @param searchStart The first character offset to search from.
* @param searchEnd The exclusive upper bound of the region to search.
* @return The matched character offsets, or {@code (-1, -1)} when the reconstructed text
* cannot be found in the requested region.
*/
- private static SpanMatch findByRegex(String text, String span, int searchStart, int searchEnd) {
-
- // Reconstructed span text normalizes whitespace, so match flexibly: a space in the span may
- // map to any run of whitespace OR none in the source (e.g. punctuation/'&' inside "U.S.A",
- // "AT&T" that wordpiece tokenization split apart). Use \s* rather than \s+ so such entities
- // are still located instead of being silently dropped.
- final String regex = Pattern.quote(span).replace(" ", "\\E\\s*\\Q");
+ private static SpanMatch findInSource(String text, String span, int searchStart, int searchEnd) {
- final Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE);
- final Matcher matcher = pattern.matcher(text);
final int regionStart = Math.min(Math.max(searchStart, 0), text.length());
final int regionEnd = Math.min(Math.max(searchEnd, regionStart), text.length());
- matcher.region(regionStart, regionEnd);
- if (matcher.find()) {
- return new SpanMatch(matcher.start(), matcher.end());
+ int start = regionStart;
+ while (start < regionEnd) {
+ final int end = matchAt(text, span, start, regionEnd);
+ if (end != -1) {
+ return new SpanMatch(start, end);
+ }
+ start += Character.charCount(text.codePointAt(start));
}
return new SpanMatch(-1, -1);
}
+ /**
+ * Attempts to match {@code span} against {@code text} beginning at {@code start} and bounded by
+ * {@code regionEnd}. A space in {@code span} consumes a run of zero or more Unicode whitespace
+ * code points in the source; every other code point must match case-insensitively.
+ *
+ * @return The exclusive end offset of the match in {@code text}, or {@code -1} if no match
+ * begins at {@code start}.
+ */
+ private static int matchAt(String text, String span, int start, int regionEnd) {
+
+ int t = start;
+ int s = 0;
+
+ while (s < span.length()) {
+ final int spanCp = span.codePointAt(s);
+ if (spanCp == ' ') {
+ while (t < regionEnd && WHITESPACE.contains(text.codePointAt(t))) {
+ t += Character.charCount(text.codePointAt(t));
+ }
+ s += 1;
+ } else {
+ if (t >= regionEnd) {
+ return -1;
+ }
+ final int textCp = text.codePointAt(t);
+ if (!equalsIgnoreCase(spanCp, textCp)) {
+ return -1;
+ }
+ t += Character.charCount(textCp);
+ s += Character.charCount(spanCp);
+ }
+ }
+
+ return t;
+
+ }
+
+ private static boolean equalsIgnoreCase(int a, int b) {
+ return a == b
+ || Character.toLowerCase(a) == Character.toLowerCase(b)
+ || Character.toUpperCase(a) == Character.toUpperCase(b);
+ }
+
private record LabelPrediction(String label, double probability) {
}
@@ -609,24 +760,18 @@ private record EntityPrediction(int endIndex, double probability) {
private record SpanMatch(int start, int end) {
}
- private List tokenize(final String text) {
+ private List tokenize(final String text) {
- final List t = new LinkedList<>();
+ final List t = new LinkedList<>();
- // Segment long input text into overlapping chunks configured by InferenceOptions before
- // feeding each chunk into BERT.
+ // Segment long input text into overlapping chunks (split on Unicode whitespace) configured by
+ // InferenceOptions before feeding each chunk into BERT, keeping each chunk's character span so
+ // its decoded spans can be bounded to the region the chunk covers.
// https://medium.com/analytics-vidhya/text-classification-with-bert-using-transformers-for-long-text-inputs-f54833994dfd
- final String[] whitespaceTokenized = text.split("\\s+");
-
- for (ChunkRange chunkRange : chunkRanges(
- whitespaceTokenized.length, documentSplitSize, splitOverlapSize)) {
-
- // The group is that subsection of string.
- final String group = String.join(" ",
- Arrays.copyOfRange(whitespaceTokenized, chunkRange.start(), chunkRange.end()));
+ for (final TextChunk chunk : whitespaceChunkSpans(text, documentSplitSize, splitOverlapSize)) {
// Now we can tokenize the group and continue.
- final String[] tokens = tokenizer.tokenize(group);
+ final String[] tokens = tokenizer.tokenize(chunk.text());
final long[] ids = tokenIds(tokens, vocab);
@@ -636,7 +781,7 @@ private List tokenize(final String text) {
final long[] types = new long[ids.length];
Arrays.fill(types, 0);
- t.add(new Tokens(tokens, ids, mask, types));
+ t.add(new ChunkTokens(new Tokens(tokens, ids, mask, types), chunk.start(), chunk.end()));
}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/AbstractDLChunkingTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/AbstractDLChunkingTest.java
new file mode 100644
index 000000000..13957ea9b
--- /dev/null
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/AbstractDLChunkingTest.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package opennlp.dl;
+
+import java.util.List;
+
+import org.junit.jupiter.api.Test;
+
+import opennlp.tools.util.Span;
+import opennlp.tools.util.normalizer.AlignedText;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Model-free tests for {@link AbstractDL#whitespaceChunks(String, int, int)} and
+ * {@link AbstractDL#whitespaceChunkSpans(String, int, int)}, the shared tokenize-and-chunk seam
+ * used by both {@code NameFinderDL} and {@code DocumentCategorizerDL}.
+ */
+public class AbstractDLChunkingTest {
+
+ @Test
+ void testSplitsOnUnicodeWhitespaceNotJustAscii() {
+ // A no-break space (U+00A0) and an ideographic space (U+3000) are not matched by Java's \s
+ // but must still separate tokens; the chunk is rejoined with single ASCII spaces.
+ final String nbsp = new String(Character.toChars(0x00A0));
+ final String ideographic = new String(Character.toChars(0x3000));
+ assertEquals(List.of("alpha beta gamma"),
+ AbstractDL.whitespaceChunks("alpha" + nbsp + "beta" + ideographic + "gamma", 100, 0));
+ }
+
+ @Test
+ void testDropsEmptyTokensFromLeadingTrailingAndRepeatedWhitespace() {
+ // Unlike split("\\s+"), the Unicode-aware split yields no empty leading or trailing tokens.
+ assertEquals(List.of("a b c"), AbstractDL.whitespaceChunks(" a b\tc ", 100, 0));
+ }
+
+ @Test
+ void testAppliesChunkSizeWithoutOverlap() {
+ assertEquals(List.of("a b", "c d"), AbstractDL.whitespaceChunks("a b c d", 2, 0));
+ }
+
+ @Test
+ void testAppliesChunkOverlap() {
+ assertEquals(List.of("a b", "b c", "c d"), AbstractDL.whitespaceChunks("a b c d", 2, 1));
+ }
+
+ @Test
+ void testEmptyTextYieldsNoChunks() {
+ assertEquals(List.of(), AbstractDL.whitespaceChunks("", 100, 0));
+ }
+
+ @Test
+ void testNormalizeInputIsOptInAndOffsetPreserving() {
+ final String nbsp = new String(Character.toChars(0x00A0));
+ final String emDash = new String(Character.toChars(0x2014));
+ final String input = "a" + nbsp + "b" + emDash + "c";
+
+ // Off by default: unchanged.
+ assertEquals(input, AbstractDL.normalizeInput(input, false, false));
+
+ // Whitespace only: the no-break space becomes a space, and the length is preserved.
+ final String ws = AbstractDL.normalizeInput(input, true, false);
+ assertEquals("a b" + emDash + "c", ws);
+ assertEquals(input.length(), ws.length());
+
+ // Dashes only: the em dash becomes an ASCII hyphen.
+ assertEquals("a" + nbsp + "b-c", AbstractDL.normalizeInput(input, false, true));
+
+ // Both.
+ assertEquals("a b-c", AbstractDL.normalizeInput(input, true, true));
+ }
+
+ @Test
+ void testNormalizeInputAlignedMapsSpanBackAcrossSupplementaryDash() {
+ // The aligned path used by NameFinderDL.findInOriginal: a span found in the folded text maps
+ // back to a tight original span even though the supplementary dash shrank by one unit.
+ final String yezidi = new String(Character.toChars(0x10EAD));
+ final String input = "a" + yezidi + "b"; // a(0), dash(1,2), b(3) -> length 4
+ final AlignedText at = AbstractDL.normalizeInputAligned(input, false, true);
+ assertEquals("a-b", at.normalized());
+
+ final Span b = at.toOriginalSpan(2, 3); // "b" in the folded text
+ assertEquals(3, b.getStart()); // maps back to original offset 3, not shifted to 2
+ assertEquals(4, b.getEnd());
+
+ final Span hyphen = at.toOriginalSpan(1, 2); // the folded hyphen covers the two-unit dash
+ assertEquals(1, hyphen.getStart());
+ assertEquals(3, hyphen.getEnd());
+ }
+
+ @Test
+ void testNormalizeInputAlignedIsIdentityWhenLengthPreserved() {
+ // Whitespace folding is length-preserving, so the alignment maps every position to itself.
+ final AlignedText at = AbstractDL.normalizeInputAligned("a\tb", true, false);
+ assertEquals("a b", at.normalized());
+ for (int i = 0; i < 3; i++) {
+ final Span s = at.toOriginalSpan(i, i + 1);
+ assertEquals(i, s.getStart());
+ assertEquals(i + 1, s.getEnd());
+ }
+ }
+
+ @Test
+ void testNormalizeInputAlignedComposesWhitespaceAndDash() {
+ // With both folds enabled the alignment is composed across stages, so a span found in the fully
+ // folded text still maps back to the original even though the supplementary dash shrank a unit.
+ final String nbsp = new String(Character.toChars(0x00A0));
+ final String yezidi = new String(Character.toChars(0x10EAD));
+ final String input = "a" + nbsp + "b" + yezidi + "c"; // a(0) nbsp(1) b(2) dash(3,4) c(5)
+ final AlignedText at = AbstractDL.normalizeInputAligned(input, true, true);
+ assertEquals("a b-c", at.normalized()); // a(0) space(1) b(2) hyphen(3) c(4)
+
+ final Span c = at.toOriginalSpan(4, 5); // "c" maps past the two-unit dash
+ assertEquals(5, c.getStart());
+ assertEquals(6, c.getEnd());
+
+ final Span hyphen = at.toOriginalSpan(3, 4); // folded hyphen covers the supplementary dash
+ assertEquals(3, hyphen.getStart());
+ assertEquals(5, hyphen.getEnd());
+
+ final Span space = at.toOriginalSpan(1, 2); // the no-break space, length preserved
+ assertEquals(1, space.getStart());
+ assertEquals(2, space.getEnd());
+ }
+
+ @Test
+ void testWhitespaceChunkSpansCarryCharacterOffsets() {
+ final String text = "a b c d"; // tokens a[0,1) b[2,3) c[4,5) d[6,7)
+ final List chunks = AbstractDL.whitespaceChunkSpans(text, 2, 1);
+
+ assertEquals(3, chunks.size());
+ assertEquals("a b", chunks.get(0).text());
+ assertEquals(0, chunks.get(0).start());
+ assertEquals(3, chunks.get(0).end());
+ assertEquals("b c", chunks.get(1).text());
+ assertEquals(2, chunks.get(1).start());
+ assertEquals(5, chunks.get(1).end());
+ assertEquals("c d", chunks.get(2).text());
+ assertEquals(4, chunks.get(2).start());
+ assertEquals(7, chunks.get(2).end());
+
+ // For single-ASCII-space input the reported span covers the chunk's rejoined text exactly.
+ for (final AbstractDL.TextChunk chunk : chunks) {
+ assertEquals(chunk.text(), text.substring(chunk.start(), chunk.end()));
+ }
+ }
+}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
index 80087d97c..3e8239114 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java
@@ -85,6 +85,25 @@ void testCategorizeRejectsMalformedInput() {
assertThrows(IllegalArgumentException.class, () -> categorizer.categorize(new String[0]));
}
+ @Test
+ void testCategorizeRejectsTokenlessContent() {
+ // A non-empty array whose document has no tokens (empty or only whitespace, including Unicode
+ // whitespace) is rejected up front rather than crashing downstream on an empty score list.
+ final DocumentCategorizerDL categorizer = categorizerWithoutSession();
+ final String nbsp = new String(Character.toChars(0x00A0));
+
+ assertThrows(IllegalArgumentException.class, () -> categorizer.categorize(new String[] {""}));
+ assertThrows(IllegalArgumentException.class, () -> categorizer.categorize(new String[] {" "}));
+ assertThrows(IllegalArgumentException.class, () -> categorizer.categorize(new String[] {nbsp}));
+ }
+
+ @Test
+ void testConstructorRejectsNullInferenceOptions() {
+ assertThrows(NullPointerException.class, () ->
+ new DocumentCategorizerDL(null, null, vocab(), categories(),
+ new AverageClassificationScoringStrategy(), null));
+ }
+
@Test
void testTokenIdsMapsTokensToVocabularyIds() {
final long[] ids = DocumentCategorizerDL.tokenIds(
@@ -103,6 +122,27 @@ void testTokenIdsRejectsTokensMissingFromVocabulary() {
"the error message should name the missing token: " + e.getMessage());
}
+ @Test
+ void testSoftmaxRejectsNaNLogit() {
+ // A NaN logit would otherwise poison the whole distribution into NaN scores; fail loudly instead.
+ final IllegalStateException e = assertThrows(IllegalStateException.class, () ->
+ DocumentCategorizerDL.softmax(new float[] {0f, Float.NaN, 0f}));
+ assertTrue(e.getMessage().contains("NaN"), e.getMessage());
+ }
+
+ @Test
+ void testSoftmaxRejectsInfiniteLogit() {
+ // A +Infinity logit (not NaN, so it slips past an isNaN-only guard) poisons the distribution too:
+ // max becomes +Inf, so value - max is Inf - Inf == NaN, every exp() is NaN, and categorize() would
+ // silently return all-NaN scores. It must fail loud like the NaN case. -Infinity is non-finite too.
+ final IllegalStateException pos = assertThrows(IllegalStateException.class, () ->
+ DocumentCategorizerDL.softmax(new float[] {0f, Float.POSITIVE_INFINITY, 0f}));
+ assertTrue(pos.getMessage().contains("non-finite") || pos.getMessage().contains("Infinity"),
+ pos.getMessage());
+ assertThrows(IllegalStateException.class, () ->
+ DocumentCategorizerDL.softmax(new float[] {0f, Float.NEGATIVE_INFINITY, 0f}));
+ }
+
@Test
void testSoftmaxIsUniformForEqualLogitsAndSumsToOne() {
final double[] out = DocumentCategorizerDL.softmax(new float[] {0f, 0f, 0f});
@@ -146,4 +186,34 @@ void testSoftmaxMatchesReferenceDistribution() {
assertEquals(0.24472847, out[1], 1e-6);
assertEquals(0.66524096, out[2], 1e-6);
}
+
+ @Test
+ void testLogitsFromOutputDispatchesOnModelShape() {
+ // A 2D output (BERT-style) takes row 0; a 1D output (RoBERTa-style) is taken as-is.
+ assertArrayEquals(new float[] {1f, 2f},
+ DocumentCategorizerDL.logitsFromOutput(new float[][] {{1f, 2f}}));
+ assertArrayEquals(new float[] {3f, 4f},
+ DocumentCategorizerDL.logitsFromOutput(new float[] {3f, 4f}));
+ }
+
+ @Test
+ void testLogitsFromOutputFailsLoudlyOnNullAndUnexpectedType() {
+ // A null or otherwise-shaped model output is a contract violation, not an "inference failed".
+ final IllegalStateException onNull = assertThrows(IllegalStateException.class,
+ () -> DocumentCategorizerDL.logitsFromOutput(null));
+ assertTrue(onNull.getMessage().contains("null"), onNull.getMessage());
+ assertThrows(IllegalStateException.class,
+ () -> DocumentCategorizerDL.logitsFromOutput("not a tensor"));
+ }
+
+ @Test
+ void testRequireMatchingCategoryCountFailsLoudlyOnMismatch() {
+ // A distribution whose length differs from the configured category count means the model and
+ // the categorizer configuration do not match; the matching case passes the array through.
+ final double[] ok = {0.5, 0.5};
+ assertArrayEquals(ok, DocumentCategorizerDL.requireMatchingCategoryCount(ok, 2));
+ final IllegalStateException e = assertThrows(IllegalStateException.class,
+ () -> DocumentCategorizerDL.requireMatchingCategoryCount(new double[] {1.0}, 2));
+ assertTrue(e.getMessage().contains("do not match"), e.getMessage());
+ }
}
diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLTest.java
index c0a8aede2..4bf698990 100644
--- a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLTest.java
+++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLTest.java
@@ -17,6 +17,7 @@
package opennlp.dl.namefinder;
+import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -28,6 +29,7 @@
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -169,7 +171,7 @@ void testDecodeSpansSearchStartLocatesNextOccurrence() {
void testDecodeSpansLocatesEntityWithInternalPunctuation() {
// WordPiece splits "AT&T" into separate AT / & / T tokens, so the reconstructed span text
// ("AT & T") must still be located in the contiguous source. Regression guard for the
- // flexible-whitespace (\s*) matching in findByRegex.
+ // flexible-whitespace matching in findInSource (a span space matches zero source whitespace).
final String text = "Buy AT&T stock";
final String[] tokens = {"[CLS]", "Buy", "AT", "&", "T", "stock", "[SEP]"};
final float[][] scores = {
@@ -184,6 +186,37 @@ void testDecodeSpansLocatesEntityWithInternalPunctuation() {
assertEquals("AT&T", spans.get(0).getCoveredText(text));
}
+ @Test
+ void testDecodeSpansMatchesEntitySeparatedByNoBreakSpace() {
+ // The source separates "New" and "York" with a no-break space (U+00A0). Java's \s does not
+ // match it, so the previous regex matcher would have dropped this LOC span; the Unicode-aware
+ // cursor matcher locates it and the covered text includes the no-break space.
+ final String nbsp = new String(Character.toChars(0x00A0));
+ final String text = "Visit New" + nbsp + "York today";
+ final String[] tokens = {"[CLS]", "New", "York", "[SEP]"};
+ final float[][] scores = {scoresFor(0), scoresFor(3), scoresFor(4), scoresFor(0)};
+
+ final List spans = NameFinderDL.decodeSpans(text, tokens, scores, ID_TO_LABELS);
+
+ assertEquals(1, spans.size());
+ assertEquals("LOC", spans.get(0).getType());
+ assertEquals("New" + nbsp + "York", spans.get(0).getCoveredText(text));
+ }
+
+ @Test
+ void testDecodeSpansMatchesEntitySeparatedByIdeographicSpace() {
+ // Same idea with the CJK ideographic space (U+3000), another character outside Java's \s.
+ final String ideographic = new String(Character.toChars(0x3000));
+ final String text = "from New" + ideographic + "York city";
+ final String[] tokens = {"[CLS]", "New", "York", "[SEP]"};
+ final float[][] scores = {scoresFor(0), scoresFor(3), scoresFor(4), scoresFor(0)};
+
+ final List spans = NameFinderDL.decodeSpans(text, tokens, scores, ID_TO_LABELS);
+
+ assertEquals(1, spans.size());
+ assertEquals("New" + ideographic + "York", spans.get(0).getCoveredText(text));
+ }
+
@Test
void testDecodeSpansDoesNotMatchBeyondSearchEnd() {
final String text = "London was quiet. Later Paris was loud.";
@@ -199,7 +232,7 @@ void testDecodeSpansDoesNotMatchBeyondSearchEnd() {
@Test
void testDecodeSpansMatchesSourceCaseInsensitively() {
// The reconstructed span text may differ in case from the source (e.g. an uncased model);
- // findByRegex matches case-insensitively, so the span is still located at the source offsets.
+ // findInSource matches case-insensitively, so the span is still located at the source offsets.
final String text = "Visit PARIS today";
final String[] tokens = {"[CLS]", "Visit", "paris", "today", "[SEP]"};
final float[][] scores = {
@@ -324,6 +357,72 @@ void testDecodeSpansEmitsRepeatedEntityAtDistinctOffsets() {
assertEquals(15, spans.get(1).getEnd());
}
+ @Test
+ void testMergeOverlappingSpansKeepsLongestAndPreservesDisjoint() {
+ // Containment: the longer span absorbs the shorter overlapping one.
+ final List contained = NameFinderDL.mergeOverlappingSpans(new ArrayList<>(List.of(
+ new Span(0, 8, "LOC", 0.9), new Span(0, 13, "LOC", 0.8))));
+ assertEquals(1, contained.size());
+ assertEquals(0, contained.get(0).getStart());
+ assertEquals(13, contained.get(0).getEnd());
+
+ // Partial overlap: the longer span wins, the shorter overlapping one is dropped.
+ final List partial = NameFinderDL.mergeOverlappingSpans(new ArrayList<>(List.of(
+ new Span(0, 5, "LOC", 0.7), new Span(3, 12, "LOC", 0.6))));
+ assertEquals(1, partial.size());
+ assertEquals(3, partial.get(0).getStart());
+ assertEquals(12, partial.get(0).getEnd());
+
+ // Equal length overlap: the higher probability wins the tie.
+ final List tie = NameFinderDL.mergeOverlappingSpans(new ArrayList<>(List.of(
+ new Span(0, 5, "PER", 0.6), new Span(2, 7, "PER", 0.9))));
+ assertEquals(1, tie.size());
+ assertEquals(2, tie.get(0).getStart());
+ assertEquals(7, tie.get(0).getEnd());
+
+ // Adjacent but disjoint spans are both kept and returned in document order.
+ final List disjoint = NameFinderDL.mergeOverlappingSpans(new ArrayList<>(List.of(
+ new Span(5, 10, "LOC", 0.9), new Span(0, 5, "PER", 0.9))));
+ assertEquals(2, disjoint.size());
+ assertEquals(0, disjoint.get(0).getStart());
+ assertEquals(5, disjoint.get(1).getStart());
+ }
+
+ @Test
+ void testMergeOverlappingSpansReturnsAFreshListForTrivialInput() {
+ // The size < 2 fast path must hand back a new list, not the caller's, so the result is always
+ // owned by the caller -- the same ownership contract as the merging path.
+ final List single = new ArrayList<>(List.of(new Span(0, 5, "PER", 0.9)));
+ final List merged = NameFinderDL.mergeOverlappingSpans(single);
+ assertEquals(1, merged.size());
+ assertNotSame(single, merged);
+ assertTrue(NameFinderDL.mergeOverlappingSpans(new ArrayList()).isEmpty());
+ }
+
+ @Test
+ void testChunkAssemblyKeepsFullerOverlappingSpan() {
+ // Mirrors how locate() decodes two overlapping chunks bounded to their own character regions.
+ // Chunk 1 covers up to "York" and labels "New York"; the overlapping chunk 2 covers "New York
+ // City" and labels it. Both candidates are produced and mergeOverlappingSpans keeps the fuller
+ // "New York City" rather than dropping it, which is the chunk-boundary case a single forward
+ // cursor mishandled.
+ final String text = "Alice visited New York City."; // "New York" = [14,22), "City" ends at 27
+ final String[] chunk1 = {"[CLS]", "New", "York", "[SEP]"};
+ final float[][] scores1 = {scoresFor(0), scoresFor(3), scoresFor(4), scoresFor(0)};
+ final String[] chunk2 = {"[CLS]", "New", "York", "City", "[SEP]"};
+ final float[][] scores2 =
+ {scoresFor(0), scoresFor(3), scoresFor(4), scoresFor(4), scoresFor(0)};
+
+ final List candidates = new ArrayList<>();
+ candidates.addAll(NameFinderDL.decodeSpans(text, chunk1, scores1, ID_TO_LABELS, 14, 22));
+ candidates.addAll(NameFinderDL.decodeSpans(text, chunk2, scores2, ID_TO_LABELS, 14, 27));
+ assertEquals(2, candidates.size(), "both chunks emit a candidate for the boundary entity");
+
+ final List merged = NameFinderDL.mergeOverlappingSpans(candidates);
+ assertEquals(1, merged.size());
+ assertEquals("New York City", merged.get(0).getCoveredText(text));
+ }
+
@Test
void testDecodeSpansLocatesEntityWithRegexMetacharacters() {
// WordPiece splits "C++" into C / + / + tokens, so the reconstructed span text contains regex
@@ -390,6 +489,14 @@ void testLabelProbabilityIgnoresNaNInDenominator() {
assertBounded(p);
}
+ @Test
+ void testBuildSpanTextSkipsRobertaSpecialTokens() {
+ // RoBERTa markers (, ) must be skipped during span reconstruction, the same way the
+ // BERT [CLS]/[SEP] markers are, so they never leak into a reconstructed entity span.
+ assertEquals("New York",
+ NameFinderDL.buildSpanText(new String[] {"", "New", "York", ""}, 0, 3));
+ }
+
private static float[] scoresFor(int labelIndex) {
final float[] scores = new float[ID_TO_LABELS.size()];
for (int i = 0; i < scores.length; i++) {
diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
index bf7002a06..5778b7438 100644
--- a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
+++ b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java
@@ -239,7 +239,7 @@ public void categorizeWithGpu() throws Exception {
try (final DocumentCategorizerDL documentCategorizerDL =
new DocumentCategorizerDL(model, vocab, getCategories(),
new AverageClassificationScoringStrategy(),
- new InferenceOptions())) {
+ inferenceOptions)) {
final double[] result = documentCategorizerDL.categorize(new String[] {"I am happy"});
logger.debug(Arrays.toString(result));
diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
index e19742b96..f06d03a2b 100644
--- a/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
+++ b/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java
@@ -407,6 +407,157 @@ public void invalidModel() {
}
+ /**
+ * End-to-end offset safety: with dash normalization enabled and a non-BMP dash before an entity,
+ * the fold shrinks the text by one UTF-16 unit, so the entity sits at a smaller offset in the
+ * normalized text than in the original. {@link NameFinderDL#findInOriginal(String[])} must report
+ * the entity at its true offset in the original input, not the one-unit-shorter normalized offset.
+ */
+ @Test
+ public void findInOriginalMapsSpansAcrossNonBmpDash() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx");
+ final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt");
+
+ final InferenceOptions options = new InferenceOptions();
+ options.setNormalizeDashes(true);
+
+ // Yezidi hyphen (U+10EAD): a Unicode Dash code point in the supplementary plane (two UTF-16
+ // units) that folds to a single ASCII hyphen, shifting every following character left by one.
+ final String yezidi = new String(Character.toChars(0x10EAD));
+ final String[] tokens = {yezidi, "George", "Washington", "was", "president",
+ "of", "the", "United", "States", "."};
+ final String original = String.join(" ", tokens);
+
+ try (final NameFinderDL nameFinderDL =
+ new NameFinderDL(model, vocab, getIds2Labels(), options, sentenceDetector)) {
+
+ final Span[] spans = nameFinderDL.findInOriginal(tokens);
+
+ Span person = null;
+ for (final Span span : spans) {
+ if ("PER".equals(span.getType())) {
+ person = span;
+ }
+ }
+ Assertions.assertNotNull(person, "the PER entity should still be detected after the dash");
+ // Mapped back through the alignment, the span covers the entity in the ORIGINAL input (which
+ // still contains the two-unit dash); without the mapping it would be shifted left by one.
+ Assertions.assertEquals("George Washington", person.getCoveredText(original));
+ Assertions.assertEquals(original.indexOf("George Washington"), person.getStart());
+ }
+
+ }
+
+ /**
+ * End-to-end chunk-boundary safety with the real model. With a small split size and an overlap,
+ * "United States" straddles a chunk boundary: the first chunk ends at "United" while the
+ * overlapping next chunk sees the full entity. The decoder must keep the complete "United States"
+ * span rather than a truncated copy from whichever chunk decoded first, and the output spans must
+ * not overlap.
+ */
+ @Test
+ public void findKeepsFullEntityAcrossChunkBoundary() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx");
+ final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt");
+
+ final InferenceOptions options = new InferenceOptions();
+ // chunk 1 = tokens [0,7) (ends at "United"); chunk 2 = tokens [5,9) ("the United States .").
+ options.setDocumentSplitSize(7);
+ options.setSplitOverlapSize(2);
+
+ final String[] tokens = {"George", "Washington", "was", "president",
+ "of", "the", "United", "States", "."};
+ final String text = String.join(" ", tokens);
+
+ try (final NameFinderDL nameFinderDL =
+ new NameFinderDL(model, vocab, getIds2Labels(), options, sentenceDetector)) {
+
+ final Span[] spans = nameFinderDL.find(tokens);
+
+ boolean fullLocation = false;
+ boolean person = false;
+ for (final Span span : spans) {
+ if ("LOC".equals(span.getType()) && "United States".equals(span.getCoveredText(text))) {
+ fullLocation = true;
+ }
+ if ("PER".equals(span.getType()) && "George Washington".equals(span.getCoveredText(text))) {
+ person = true;
+ }
+ }
+ Assertions.assertTrue(fullLocation,
+ "the full 'United States' entity must survive the chunk boundary, not a truncated copy");
+ Assertions.assertTrue(person, "the PER entity should still be detected");
+
+ assertNoOverlappingSpans(spans, text);
+ }
+
+ }
+
+ /**
+ * End-to-end de-duplication with the real model. With a generous overlap, "United States" sits
+ * entirely within the shared region of two chunks and is decoded by both. The decoder must report
+ * it once, not twice.
+ */
+ @Test
+ public void findDeduplicatesEntityDecodedInBothOverlappingChunks() throws Exception {
+
+ final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx");
+ final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt");
+
+ final InferenceOptions options = new InferenceOptions();
+ // chunk 1 = tokens [0,8); chunk 2 = tokens [4,9); "United States" (tokens 6,7) is in both.
+ options.setDocumentSplitSize(8);
+ options.setSplitOverlapSize(4);
+
+ final String[] tokens = {"George", "Washington", "was", "president",
+ "of", "the", "United", "States", "."};
+ final String text = String.join(" ", tokens);
+
+ try (final NameFinderDL nameFinderDL =
+ new NameFinderDL(model, vocab, getIds2Labels(), options, sentenceDetector)) {
+
+ final Span[] spans = nameFinderDL.find(tokens);
+
+ int locations = 0;
+ for (final Span span : spans) {
+ if ("LOC".equals(span.getType()) && "United States".equals(span.getCoveredText(text))) {
+ locations++;
+ }
+ }
+ Assertions.assertEquals(1, locations,
+ "the entity in the chunk overlap must be reported once, not duplicated");
+
+ assertNoOverlappingSpans(spans, text);
+ }
+
+ }
+
+ // Asserts no two spans overlap in character coordinates, the post-merge invariant of locate().
+ private static void assertNoOverlappingSpans(final Span[] spans, final String text) {
+ for (int i = 0; i < spans.length; i++) {
+ for (int j = i + 1; j < spans.length; j++) {
+ Assertions.assertFalse(spans[i].intersects(spans[j]),
+ "spans must not overlap after merge: " + spans[i].getCoveredText(text)
+ + " vs " + spans[j].getCoveredText(text));
+ }
+ }
+ }
+
+ @Test
+ public void findRejectsNullInput() throws Exception {
+ // Public entry points fail fast on a null token array rather than deeper inside String.join.
+ final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx");
+ final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt");
+
+ try (final NameFinderDL nameFinderDL =
+ new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector)) {
+ Assertions.assertThrows(NullPointerException.class, () -> nameFinderDL.find(null));
+ Assertions.assertThrows(NullPointerException.class, () -> nameFinderDL.findInOriginal(null));
+ }
+ }
+
private Map getIds2Labels() {
final Map ids2Labels = new HashMap<>();