diff --git a/opennlp-api/src/main/java/opennlp/tools/namefind/OffsetMappingNameFinder.java b/opennlp-api/src/main/java/opennlp/tools/namefind/OffsetMappingNameFinder.java new file mode 100644 index 000000000..d9a0dc08b --- /dev/null +++ b/opennlp-api/src/main/java/opennlp/tools/namefind/OffsetMappingNameFinder.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package opennlp.tools.namefind; + +import opennlp.tools.util.Span; + +/** + * A {@link TokenNameFinder} that can additionally report detected spans in the character coordinates + * of the original input, mapping back through any text normalization applied before detection. + * + *

An implementation that normalizes input before detection (for example an ONNX model that folds + * Unicode whitespace or dashes) returns spans from {@link #find(String[])} in the coordinates of the + * normalized text, which no longer line up with the caller's input when a fold changes the length. + * {@link #findInOriginal(String[])} maps those spans back to original-input coordinates. This is a + * separate capability interface rather than a method on {@link TokenNameFinder} because the classic + * contract reports token-index spans, for which an original-character mapping is not meaningful; an + * interface-typed caller tests for the capability ({@code finder instanceof OffsetMappingNameFinder}) + * instead of depending on a concrete implementation.

+ */ +public interface OffsetMappingNameFinder extends TokenNameFinder { + + /** + * Finds names and returns their {@link Span spans} in the character coordinates of the original + * input, regardless of any normalization applied before detection. + * + * @param tokens The tokens to search. + * @return The detected spans, in original-input character coordinates. + */ + Span[] findInOriginal(String[] tokens); +} diff --git a/opennlp-core/opennlp-ml/opennlp-dl/README.md b/opennlp-core/opennlp-ml/opennlp-dl/README.md index 04a7715d4..ff6a72d40 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/README.md +++ b/opennlp-core/opennlp-ml/opennlp-dl/README.md @@ -22,6 +22,37 @@ Named entity models are commonly cased, so lower casing is disabled by default. Set `InferenceOptions#setLowerCase(true)` only for models trained with uncased input. +### Unicode text handling + +Long input is split into overlapping chunks on the full Unicode `White_Space` +set (not Java's `\s`), so no-break space, ideographic space, and the other UCD +whitespace characters are recognized as delimiters. `NameFinderDL` locates +reconstructed entity text in the original input with a cursor-based matcher that +treats span spaces as flexible Unicode whitespace and compares other code points +case-insensitively, so `Span#getCoveredText(...)` works on text from PDFs, the +web, and multilingual sources. + +Optional input folding is off by default and controlled through +`InferenceOptions`: + +```java +InferenceOptions options = new InferenceOptions(); +options.setNormalizeWhitespace(true); // each Unicode whitespace -> ASCII space (offset-preserving) +options.setNormalizeDashes(true); // Unicode dashes -> hyphen-minus (offset note below) +NameFinderDL finder = new NameFinderDL(model, vocab, ids2Labels, options, sentenceDetector); +``` + +Whitespace folding is length-preserving, so it never moves offsets. Dash folding can shrink a +non-BMP dash by one UTF-16 unit, but `NameFinderDL.findInOriginal` maps decoded spans back through +the normalization `Alignment`, so reported spans stay correct in the original input even for +non-BMP dashes. (`NameFinderDL.find` returns normalized-text offsets, which differ from the +original only in that non-BMP-dash case.) + +The same options apply to `DocumentCategorizerDL`. The underlying +`CharClass` / `CodePointSet` engine and the broader normalization pipeline live +in `opennlp.tools.util.normalizer` and are documented in the OpenNLP manual +chapter *Text Normalization*. + Export a Hugging Face NER model to ONNX, e.g.: ```bash @@ -30,12 +61,31 @@ python -m transformers.onnx --model=dslim/bert-base-NER --feature token-classifi ## DocumentCategorizerDL +Uses the same Unicode whitespace chunking and optional `InferenceOptions` +normalization as `NameFinderDL` (see above). + Export a Huggingface classification (e.g. sentiment) model to ONNX, e.g.: ```bash python -m transformers.onnx --model=nlptown/bert-base-multilingual-uncased-sentiment --feature sequence-classification exported ``` +## Behavior changes in this release + +Integrators upgrading from an earlier `opennlp-dl` should note these intentional changes (OPENNLP-1850): + +- `NameFinderDL.find(...)` reports spans in the coordinates of the joined input it ran inference on, + which differ from the original text only when length-changing dash folding is enabled. Use the new + `NameFinderDL.findInOriginal(...)` (from `OffsetMappingNameFinder`) for original-text coordinates. +- Spans that overlap at chunk boundaries are now merged longest-wins; `find(...)` previously returned + every decoded span, overlaps included. +- Chunking splits on the Unicode `White_Space` set rather than `String#split("\\s+")`, and + whitespace-only input now yields no spans without running the model. +- `DocumentCategorizerDL.categorize(...)` now rejects `null`/empty input, and a document with no + non-whitespace token, with `IllegalArgumentException` rather than running the model on empty input. +- The example label constants `NameFinderDL.I_PER` and `NameFinderDL.B_PER` were removed; supply your + own label strings (any `B-`/`I-` pair works, as described above). + ## SentenceVectors Convert a sentence vectors model to ONNX, e.g.: diff --git a/opennlp-core/opennlp-ml/opennlp-dl/pom.xml b/opennlp-core/opennlp-ml/opennlp-dl/pom.xml index 76d27f6fa..55d887ca3 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/pom.xml +++ b/opennlp-core/opennlp-ml/opennlp-dl/pom.xml @@ -37,6 +37,14 @@ org.apache.opennlp opennlp-api + + + org.apache.opennlp + opennlp-runtime + test + @@ -45,13 +53,6 @@ ${onnxruntime.version} - - - org.apache.opennlp - opennlp-runtime - test - - org.slf4j slf4j-api diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java index fe7cc99d5..561ed2d37 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java @@ -40,6 +40,10 @@ import opennlp.tools.tokenize.BertTokenizer; import opennlp.tools.tokenize.Tokenizer; import opennlp.tools.tokenize.WordpieceTokenizer; +import opennlp.tools.util.Span; +import opennlp.tools.util.normalizer.AlignedText; +import opennlp.tools.util.normalizer.Alignment; +import opennlp.tools.util.normalizer.CharClass; /** * Base class for OpenNLP deep-learning classes using ONNX Runtime. @@ -60,6 +64,17 @@ public abstract class AbstractDL implements AutoCloseable { protected record ChunkRange(int start, int end) { } + /** + * A rejoined chunk paired with its half-open character span in the text it was split from, so a + * chunk's decoded entities can be located within the region the chunk actually covers. + * + * @param text The chunk text, the chunk's whitespace tokens rejoined with single ASCII spaces. + * @param start The inclusive character offset of the chunk in the source text. + * @param end The exclusive character offset of the chunk in the source text. + */ + protected record TextChunk(String text, int start, int end) { + } + private static final Pattern JSON_ENTRY_PATTERN = Pattern.compile("\"((?:[^\"\\\\]|\\\\.)*)\"\\s*:\\s*(\\d+)"); @@ -316,15 +331,146 @@ protected static void validateSplitOptions(final InferenceOptions options) { */ protected static void validateSplitOptions(final int documentSplitSize, final int splitOverlapSize) { if (documentSplitSize <= 0) { - throw new IllegalArgumentException("documentSplitSize must be greater than zero."); + throw new IllegalArgumentException("The documentSplitSize must be greater than zero."); } if (splitOverlapSize < 0) { - throw new IllegalArgumentException("splitOverlapSize must not be negative."); + throw new IllegalArgumentException("The splitOverlapSize must not be negative."); } if (splitOverlapSize >= documentSplitSize) { throw new IllegalArgumentException( - "splitOverlapSize must be smaller than documentSplitSize."); + "The splitOverlapSize must be smaller than documentSplitSize."); + } + } + + /** + * Unicode-aware whitespace. Input is tokenized on the full Unicode {@code White_Space} set + * rather than the six ASCII characters Java's {@code \s} recognizes, and the same class is + * reused by subclasses that need to match against whitespace in the source text. + */ + protected static final CharClass WHITESPACE = CharClass.whitespace(); + + /** Unicode dashes (excluding the mathematical minus signs), used for optional input folding. */ + protected static final CharClass DASHES = CharClass.dashes(); + + /** + * Optionally folds Unicode whitespace and/or dashes in the input to their ASCII forms before + * inference, returning just the folded text. This is suitable for callers that do not map model + * output back to character offsets, such as whole-document classification. When the result must + * be mapped back to the original text (for example to report entity spans), use + * {@link #normalizeInputAligned(String, boolean, boolean)} instead, which also returns an + * {@link Alignment} that stays correct even when a fold changes the string length. + * + * @param text The input text. + * @param normalizeWhitespace Whether to fold whitespace to ASCII spaces. + * @param normalizeDashes Whether to fold dashes to the ASCII hyphen. + * @return The optionally normalized text. + */ + protected static String normalizeInput(final String text, final boolean normalizeWhitespace, + final boolean normalizeDashes) { + String result = text; + if (normalizeWhitespace) { + result = WHITESPACE.normalize(result).toString(); + } + if (normalizeDashes) { + result = DASHES.normalize(result).toString(); + } + return result; + } + + /** + * Like {@link #normalizeInput(String, boolean, boolean)} but also produces an {@link Alignment} + * from the folded text back to {@code text}, so model output positions map to original character + * offsets even when a fold changes the string length (a supplementary dash shrinking, or, for + * folds that may be added later, an expansion such as an ellipsis to three dots). + * + * @param text The input text. + * @param normalizeWhitespace Whether to fold whitespace to ASCII spaces. + * @param normalizeDashes Whether to fold dashes to the ASCII hyphen. + * @return The optionally normalized text paired with its alignment back to {@code text}. + */ + protected static AlignedText normalizeInputAligned(final String text, + final boolean normalizeWhitespace, final boolean normalizeDashes) { + // Compose each enabled fold's alignment with the running alignment so the returned mapping is + // correct no matter whether a stage changes length. Whitespace folding here is a one-for-one + // replacement and so is length-preserving today; only dash folding moves offsets (a + // supplementary-plane dash shrinks from two chars to one). Composing through andThen rather + // than relying on the whitespace stage staying length-preserving keeps findInOriginal() correct + // if that ever changes. + AlignedText result = identityAligned(text, text); + if (normalizeWhitespace) { + result = compose(result, WHITESPACE.normalizeAligned(result.normalized())); + } + if (normalizeDashes) { + result = compose(result, DASHES.normalizeAligned(result.normalized())); + } + return result; + } + + // Threads a fold stage onto the running alignment: accumulated maps original -> current and next + // maps current -> next.normalized(), so the composition maps original -> next.normalized(). + private static AlignedText compose(final AlignedText accumulated, final AlignedText next) { + return new AlignedText(accumulated.original(), next.normalized(), + accumulated.alignment().andThen(next.alignment())); + } + + // An AlignedText whose alignment is the identity, for the case where no length-changing fold was + // applied so the folded text has the same length and offsets as the original. + private static AlignedText identityAligned(final String original, final String normalized) { + final Alignment alignment = + new Alignment.Builder().equal(normalized.length()).build(normalized.length()); + return new AlignedText(original, normalized, alignment); + } + + /** + * Splits {@code text} on Unicode whitespace and groups the resulting tokens into overlapping + * chunks, each rejoined with single ASCII spaces, ready for WordPiece tokenization. The split + * uses the Unicode {@code White_Space} set, so spacing such as a no-break space or the + * ideographic space is recognized, and it yields no empty tokens from leading, trailing, or + * repeated whitespace. + * + * @param text The input text. + * @param documentSplitSize The maximum number of whitespace tokens per chunk. + * @param splitOverlapSize The number of tokens shared between consecutive chunks. + * @return The chunk strings, in order. + */ + protected static List whitespaceChunks(final String text, final int documentSplitSize, + final int splitOverlapSize) { + final List chunks = whitespaceChunkSpans(text, documentSplitSize, splitOverlapSize); + final List groups = new ArrayList<>(chunks.size()); + for (final TextChunk chunk : chunks) { + groups.add(chunk.text()); + } + return groups; + } + + /** + * Like {@link #whitespaceChunks(String, int, int)} but also carries each chunk's character span + * in {@code text}, so a chunk can be decoded bounded to the region it covers and overlapping + * chunks yield overlapping candidate spans rather than silently dropping a boundary entity. + * + * @param text The input text. + * @param documentSplitSize The maximum number of whitespace tokens per chunk. + * @param splitOverlapSize The number of tokens shared between consecutive chunks. + * @return The chunks, in order, each with its character span in {@code text}. + */ + protected static List whitespaceChunkSpans(final String text, + final int documentSplitSize, final int splitOverlapSize) { + final List tokenSpans = WHITESPACE.splitSpans(text); + final List chunks = new ArrayList<>(); + for (final ChunkRange range : chunkRanges(tokenSpans.size(), documentSplitSize, + splitOverlapSize)) { + final StringBuilder rejoined = new StringBuilder(); + for (int i = range.start(); i < range.end(); i++) { + if (i > range.start()) { + rejoined.append(' '); + } + rejoined.append(text, tokenSpans.get(i).getStart(), tokenSpans.get(i).getEnd()); + } + final int start = tokenSpans.get(range.start()).getStart(); + final int end = tokenSpans.get(range.end() - 1).getEnd(); + chunks.add(new TextChunk(rejoined.toString(), start, end)); } + return chunks; } /** @@ -340,7 +486,7 @@ protected static void validateSplitOptions(final int documentSplitSize, final in protected static List chunkRanges(final int tokenCount, final int documentSplitSize, final int splitOverlapSize) { if (tokenCount < 0) { - throw new IllegalArgumentException("tokenCount must not be negative."); + throw new IllegalArgumentException("The tokenCount must not be negative."); } validateSplitOptions(documentSplitSize, splitOverlapSize); diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java index 344c5846d..b59aac226 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/InferenceOptions.java @@ -26,6 +26,8 @@ public class InferenceOptions { private int documentSplitSize = 250; private int splitOverlapSize = 50; private Boolean lowerCase; + private boolean normalizeWhitespace; + private boolean normalizeDashes; public boolean isIncludeAttentionMask() { return includeAttentionMask; @@ -75,6 +77,46 @@ public void setSplitOverlapSize(int splitOverlapSize) { this.splitOverlapSize = splitOverlapSize; } + /** {@return whether input whitespace is normalized to ASCII spaces before inference} */ + public boolean isNormalizeWhitespace() { + return normalizeWhitespace; + } + + /** + * Replaces every Unicode whitespace character in the input with an ASCII space before inference. + * This is offset preserving (each whitespace code point maps to one space), so any spans a model + * produces still align with the input. Off by default. + * + *

This is a one-for-one replacement, not the collapse-and-trim whitespace fold of the runtime + * {@code TextNormalizer.whitespace()} rung: runs of whitespace are not merged and leading or + * trailing whitespace is not removed, so offsets are preserved.

+ * + * @param normalizeWhitespace Whether to normalize whitespace. + */ + public void setNormalizeWhitespace(boolean normalizeWhitespace) { + this.normalizeWhitespace = normalizeWhitespace; + } + + /** {@return whether input dashes are normalized to the ASCII hyphen before inference} */ + public boolean isNormalizeDashes() { + return normalizeDashes; + } + + /** + * Replaces Unicode dashes in the input with the ASCII hyphen-minus before inference. This is + * offset preserving for the dash characters in the Basic Multilingual Plane (the common case). + * The mathematical minus signs are not affected. Off by default. + * + *

A supplementary-plane dash shrinks from two chars to one, which shifts later offsets, so + * with this enabled {@code find(...)} reports offsets into the normalized text in that case. Use + * {@code NameFinderDL.findInOriginal(...)} for offsets mapped back to the original input.

+ * + * @param normalizeDashes Whether to normalize dashes. + */ + public void setNormalizeDashes(boolean normalizeDashes) { + this.normalizeDashes = normalizeDashes; + } + /** * Returns whether tokenization should lower case the input text and strip * accents, as required by uncased models. diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java index 7aa36e494..d8ecd3567 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java @@ -85,6 +85,8 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor private final boolean includeTokenTypeIds; private final int documentSplitSize; private final int splitOverlapSize; + private final boolean normalizeWhitespace; + private final boolean normalizeDashes; /** * Test-only constructor that injects an already-built {@link OrtSession} (or {@code null}), @@ -101,6 +103,8 @@ public class DocumentCategorizerDL extends AbstractDL implements DocumentCategor this.includeTokenTypeIds = inferenceOptions.isIncludeTokenTypeIds(); this.documentSplitSize = inferenceOptions.getDocumentSplitSize(); this.splitOverlapSize = inferenceOptions.getSplitOverlapSize(); + this.normalizeWhitespace = inferenceOptions.isNormalizeWhitespace(); + this.normalizeDashes = inferenceOptions.isNormalizeDashes(); } /** @@ -132,6 +136,8 @@ public DocumentCategorizerDL(File model, File vocabulary, Map c this.includeTokenTypeIds = inferenceOptions.isIncludeTokenTypeIds(); this.documentSplitSize = inferenceOptions.getDocumentSplitSize(); this.splitOverlapSize = inferenceOptions.getSplitOverlapSize(); + this.normalizeWhitespace = inferenceOptions.isNormalizeWhitespace(); + this.normalizeDashes = inferenceOptions.isNormalizeDashes(); } @@ -165,12 +171,15 @@ public DocumentCategorizerDL(File model, File vocabulary, File config, this.includeTokenTypeIds = inferenceOptions.isIncludeTokenTypeIds(); this.documentSplitSize = inferenceOptions.getDocumentSplitSize(); this.splitOverlapSize = inferenceOptions.getSplitOverlapSize(); + this.normalizeWhitespace = inferenceOptions.isNormalizeWhitespace(); + this.normalizeDashes = inferenceOptions.isNormalizeDashes(); } private static InferenceOptions validateConstructorArguments( final InferenceOptions inferenceOptions, final Object categoriesOrConfig, final ClassificationScoringStrategy classificationScoringStrategy) { + Objects.requireNonNull(inferenceOptions, "inferenceOptions"); Objects.requireNonNull(categoriesOrConfig, "categoriesOrConfig"); Objects.requireNonNull(classificationScoringStrategy, "classificationScoringStrategy"); return inferenceOptions; @@ -183,24 +192,41 @@ private static InferenceOptions validateConstructorArguments( * * @param strings The document to categorize; {@code strings[0]} is classified. * @return The per-category probabilities. - * @throws IllegalArgumentException If {@code strings} is {@code null} or empty. + * @throws IllegalArgumentException If {@code strings} is {@code null} or empty, or if + * {@code strings[0]} has no tokens to classify (it is empty or only whitespace). * @throws IllegalStateException If inference fails or the model returns an unexpected output. */ @Override public double[] categorize(String[] strings) { if (strings == null || strings.length == 0) { - throw new IllegalArgumentException("strings must contain at least one document to categorize"); + throw new IllegalArgumentException( + "The strings argument must contain at least one document to categorize"); } final List tokens = tokenize(strings[0]); + if (tokens.isEmpty()) { + throw new IllegalArgumentException( + "The document to categorize must contain at least one non-whitespace token"); + } final List scores = new LinkedList<>(); for (final Tokens t : tokens) { scores.add(softmax(infer(t))); } - return classificationScoringStrategy.score(scores); + final double[] distribution = classificationScoringStrategy.score(scores); + return requireMatchingCategoryCount(distribution, categories.size()); + } + + // Package-visible so the model/category-count mismatch guard can be exercised without a live model. + static double[] requireMatchingCategoryCount(final double[] distribution, final int expected) { + if (distribution.length != expected) { + throw new IllegalStateException("The model produced " + distribution.length + + " category scores but the categorizer is configured with " + expected + + " categories; the model and the category configuration do not match"); + } + return distribution; } /** @@ -236,6 +262,12 @@ private float[] infer(final Tokens t) { inputs.values().forEach(OnnxTensor::close); } + return logitsFromOutput(output); + } + + // Package-visible so the output-shape dispatch, including the null and unexpected-type failures, + // can be exercised without a live model session. + static float[] logitsFromOutput(final Object output) { // Some models return a 2D array (e.g. BERT), others a 1D array (e.g. RoBERTa). A different // shape is a model-contract violation, surfaced on its own rather than as "inference failed". if (output instanceof float[][] v) { @@ -243,7 +275,8 @@ private float[] infer(final Tokens t) { } else if (output instanceof float[] v) { return v; } - throw new IllegalStateException("Unexpected model output type: " + output.getClass().getName()); + throw new IllegalStateException("Unexpected model output type: " + + (output == null ? "null" : output.getClass().getName())); } @Override @@ -327,21 +360,15 @@ private int getKey(String value) { } - private List tokenize(final String text) { + private List tokenize(final String input) { + final String text = normalizeInput(input, normalizeWhitespace, normalizeDashes); final List t = new LinkedList<>(); - // Segment long input text into overlapping chunks configured by InferenceOptions before - // feeding each chunk into BERT. + // Segment long input text into overlapping chunks (split on Unicode whitespace) configured by + // InferenceOptions before feeding each chunk into BERT. // https://medium.com/analytics-vidhya/text-classification-with-bert-using-transformers-for-long-text-inputs-f54833994dfd - final String[] whitespaceTokenized = text.split("\\s+"); - - for (ChunkRange chunkRange : chunkRanges( - whitespaceTokenized.length, documentSplitSize, splitOverlapSize)) { - - // The group is that subsection of string. - final String group = String.join(" ", - Arrays.copyOfRange(whitespaceTokenized, chunkRange.start(), chunkRange.end())); + for (final String group : whitespaceChunks(text, documentSplitSize, splitOverlapSize)) { // Now we can tokenize the group and continue. final String[] tokens = tokenizer.tokenize(group); @@ -402,6 +429,14 @@ static double[] softmax(final float[] input) { // identical to the naive form. Results are kept in double precision throughout. double max = Double.NEGATIVE_INFINITY; for (final float value : input) { + // Reject any non-finite logit, not just NaN: a +Infinity logit makes max == +Inf, so + // value - max is Inf - Inf == NaN and the whole distribution silently goes NaN. Subtracting + // the maximum already handles merely-large finite logits, so only NaN/Infinity reach here. + if (!Float.isFinite(value)) { + throw new IllegalStateException( + "the model produced a non-finite logit (NaN or Infinity); cannot compute a " + + "classification distribution"); + } max = Math.max(max, value); } diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java index e5b5c89b5..9eaa80e2f 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java @@ -22,14 +22,14 @@ import java.nio.LongBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.regex.Matcher; -import java.util.regex.Pattern; +import java.util.TreeMap; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtException; @@ -41,12 +41,15 @@ import opennlp.dl.InferenceOptions; import opennlp.dl.Tokens; import opennlp.tools.commons.ThreadSafe; -import opennlp.tools.namefind.TokenNameFinder; +import opennlp.tools.namefind.OffsetMappingNameFinder; import opennlp.tools.sentdetect.SentenceDetector; +import opennlp.tools.tokenize.WordpieceTokenizer; import opennlp.tools.util.Span; +import opennlp.tools.util.normalizer.AlignedText; +import opennlp.tools.util.normalizer.Alignment; /** - * An implementation of {@link TokenNameFinder} that uses ONNX models. + * An implementation of {@link opennlp.tools.namefind.TokenNameFinder} that uses ONNX models. * *

Tokenization performs BERT basic tokenization (text normalization) * before wordpiece, see {@link opennlp.tools.tokenize.BertTokenizer}. Input @@ -64,18 +67,20 @@ * guarantee applies until {@link #close()} is called; callers must not race * {@code close()} with inference methods.

* - * @see TokenNameFinder + * @see opennlp.tools.namefind.TokenNameFinder * @see InferenceOptions */ @ThreadSafe -public class NameFinderDL extends AbstractDL implements TokenNameFinder { +public class NameFinderDL extends AbstractDL implements OffsetMappingNameFinder { - /** Example person labels; retained for reference. Decoding handles any B-/I- type. */ - public static final String I_PER = "I-PER"; - public static final String B_PER = "B-PER"; public static final String SEPARATOR = "[SEP]"; public static final String CLS_TOKEN = "[CLS]"; + // Tokenizer-added markers (BERT and RoBERTa) that must never appear in a reconstructed span. + private static final Set SPECIAL_TOKENS = Set.of( + CLS_TOKEN, SEPARATOR, + WordpieceTokenizer.ROBERTA_CLS_TOKEN, WordpieceTokenizer.ROBERTA_SEP_TOKEN); + /** Prefix used by BIO labels for the first token in an entity span. */ public static final String PREFIX_BEGIN = "B-"; @@ -104,9 +109,11 @@ public class NameFinderDL extends AbstractDL implements TokenNameFinder { private final boolean includeTokenTypeIds; private final int documentSplitSize; private final int splitOverlapSize; + private final boolean normalizeWhitespace; + private final boolean normalizeDashes; /** - * Instantiates a {@link TokenNameFinder name finder} using ONNX models. + * Instantiates a {@link opennlp.tools.namefind.TokenNameFinder name finder} using ONNX models. * * @param model The ONNX model file. * @param vocabulary The model file's vocabulary file. @@ -126,7 +133,7 @@ public NameFinderDL(File model, File vocabulary, Map ids2Labels } /** - * Instantiates a {@link TokenNameFinder name finder} using ONNX models. + * Instantiates a {@link opennlp.tools.namefind.TokenNameFinder name finder} using ONNX models. * * @param model The ONNX model file. * @param vocabulary The model file's vocabulary file. @@ -153,6 +160,8 @@ public NameFinderDL(File model, File vocabulary, Map ids2Labels this.includeTokenTypeIds = inferenceOptions.isIncludeTokenTypeIds(); this.documentSplitSize = inferenceOptions.getDocumentSplitSize(); this.splitOverlapSize = inferenceOptions.getSplitOverlapSize(); + this.normalizeWhitespace = inferenceOptions.isNormalizeWhitespace(); + this.normalizeDashes = inferenceOptions.isNormalizeDashes(); this.sentenceDetector = sentenceDetector; } @@ -160,6 +169,7 @@ public NameFinderDL(File model, File vocabulary, Map ids2Labels private static InferenceOptions validateConstructorArguments( final InferenceOptions inferenceOptions, final Map ids2Labels, final SentenceDetector sentenceDetector) { + Objects.requireNonNull(inferenceOptions, "inferenceOptions"); Objects.requireNonNull(ids2Labels, "ids2Labels"); Objects.requireNonNull(sentenceDetector, "sentenceDetector"); return inferenceOptions; @@ -168,9 +178,16 @@ private static InferenceOptions validateConstructorArguments( /** * {@inheritDoc} * - *

This method joins the provided tokens with spaces, sentence-splits the joined text, - * runs each sentence through the ONNX token-classification model, decodes BIO labels into - * {@link Span spans}, and resolves those spans back to character offsets in the joined text.

+ *

Joins the provided tokens with spaces, sentence-splits the joined text, runs each sentence + * through the ONNX token-classification model, decodes BIO labels into {@link Span spans}, and + * resolves those spans to character offsets in the joined text after any optional input + * normalization.

+ * + *

Note: this returns correct original offsets in every case except one. Whitespace folding is + * length-preserving, so it never moves offsets. Only dash folding can change the input length, and + * only for a non-BMP dash; so when {@code normalizeDashes} is enabled and the input contains a + * supplementary-plane dash, the returned spans are offsets into the normalized text rather than + * the original. For an exact original mapping in that case, use {@link #findInOriginal(String[])}.

* * @throws IllegalStateException Thrown if inference fails, if the model output shape is not * the expected {@code float[batch][token][label]} form, if the model output contains @@ -178,44 +195,131 @@ private static InferenceOptions validateConstructorArguments( * present in the configured label map. * @throws IllegalArgumentException Thrown if a token produced for the input is not present in * the vocabulary, which indicates the vocabulary file does not match the model. + * @throws NullPointerException Thrown if {@code input} is {@code null}. */ @Override public Span[] find(String[] input) { + return locate(input).spans().toArray(new Span[0]); + } - final List spans = new ArrayList<>(); + /** + * Finds names and returns their {@link Span spans} in coordinates of the original joined input + * ({@code String.join(" ", input)}), regardless of any whitespace or dash normalization applied + * before inference. Spans are mapped back through the normalization {@link Alignment}, so a fold + * that changes the input length (a supplementary dash shrinking, or an expansion) does not shift + * the reported offsets. This implements {@link OffsetMappingNameFinder}, so an interface-typed + * caller can reach the offset-correct path with + * {@code finder instanceof OffsetMappingNameFinder}. + * + * @param input The tokens to search. + * @return The detected spans, in original-input character coordinates. + * @throws IllegalStateException Thrown under the same conditions as {@link #find(String[])}. + * @throws IllegalArgumentException Thrown under the same conditions as {@link #find(String[])}. + */ + @Override + public Span[] findInOriginal(String[] input) { + final DecodedSpans decoded = locate(input); + final Alignment alignment = decoded.aligned().alignment(); + final List mapped = new ArrayList<>(decoded.spans().size()); + for (final Span span : decoded.spans()) { + final Span original = alignment.toOriginalSpan(span.getStart(), span.getEnd()); + mapped.add(new Span(original.getStart(), original.getEnd(), span.getType(), span.getProb())); + } + return mapped.toArray(new Span[0]); + } + + // Shared core: normalize the joined input (capturing the alignment back to the original), then + // decode each overlapping chunk bounded to its own character region and resolve overlaps. Bounding + // per chunk lets a boundary entity that two consecutive chunks both cover surface as overlapping + // candidates, which mergeOverlappingSpans collapses to the longer (more complete) span instead of + // silently keeping whichever a single forward cursor reached first. + private DecodedSpans locate(String[] input) { + + Objects.requireNonNull(input, "input"); // Join the tokens here because they will be tokenized using Wordpiece during inference. - final String text = String.join(" ", input); + final AlignedText normalized = + normalizeInputAligned(String.join(" ", input), normalizeWhitespace, normalizeDashes); + final String text = normalized.normalizedString(); // sentPosDetect (not sentDetect) so each sentence's offset in the full text is known. final Span[] sentenceSpans = sentenceDetector.sentPosDetect(text); + final List candidates = new ArrayList<>(); for (final Span sentenceSpan : sentenceSpans) { - // Floor the character cursor at this sentence's start, then thread it forward across the - // sentence's chunks so a repeated surface form is located at its next occurrence. Flooring - // per sentence keeps an entity from being matched against an identical surface form in an - // earlier sentence -- even one that produced no spans, which would otherwise leave the - // cursor behind and mis-locate the match. - int searchStart = sentenceSpan.getStart(); - - // The WordPiece tokenized text. This changes the spacing in the text. - final List wordpieceTokens = tokenize(sentenceSpan.getCoveredText(text).toString()); - - for (final Tokens tokens : wordpieceTokens) { - final List decoded = - decodeSpans(text, tokens.tokens(), infer(tokens), ids2Labels, searchStart, - sentenceSpan.getEnd()); - spans.addAll(decoded); - if (!decoded.isEmpty()) { - searchStart = decoded.get(decoded.size() - 1).getEnd(); - } + final int sentenceStart = sentenceSpan.getStart(); + final String sentence = sentenceSpan.getCoveredText(text).toString(); + + // The WordPiece tokenized text, in overlapping chunks. This changes the spacing in the text. + for (final ChunkTokens chunk : tokenize(sentence)) { + // Decode within the chunk's own character region in the full text. Keeping each chunk's + // entities inside the region it was built from locates a repeated surface form in the right + // chunk rather than mis-matching it to an earlier occurrence, while still letting two + // overlapping chunks both emit a boundary entity for mergeOverlappingSpans to reconcile. + final int regionStart = sentenceStart + chunk.start(); + final int regionEnd = sentenceStart + chunk.end(); + candidates.addAll(decodeSpans(text, chunk.tokens().tokens(), infer(chunk.tokens()), + ids2Labels, regionStart, regionEnd)); } } - return spans.toArray(new Span[0]); + return new DecodedSpans(mergeOverlappingSpans(candidates), normalized); + } + + private record DecodedSpans(List spans, AlignedText aligned) { + } + // A chunk's WordPiece tokens paired with the chunk's half-open character span in the full text. + private record ChunkTokens(Tokens tokens, int start, int end) { + } + + // Ordering for overlap resolution: longest span first, then higher probability. The dominant + // detection is kept and any later span overlapping it is dropped. + private static final Comparator BY_LENGTH_THEN_PROBABILITY = + Comparator.comparingInt(Span::length).reversed() + .thenComparing(Comparator.comparingDouble(Span::getProb).reversed()); + + /** + * Resolves spans that overlap in character coordinates, as happens when an entity falls in the + * shared region of two consecutive overlapping chunks and is decoded by both. The longer span is + * kept (the more complete decode) and ties break toward the higher probability; any span that + * overlaps an already kept one is dropped. Adjacent but disjoint spans are never merged, so + * neighbouring distinct entities and repeated surface forms at different offsets are preserved. + * The returned list is in document order. The choice is length-dominant rather than type-aware: + * when two overlapping spans carry different entity types, the longer still wins regardless of + * type, which is the intended heuristic for the rare cross-type overlap at a chunk boundary. + * + * @param spans The decoded candidate spans, in the order they were produced. + * @return The overlap-free spans, ordered by start offset. + */ + static List mergeOverlappingSpans(final List spans) { + if (spans.size() < 2) { + // Return a fresh list so the caller always owns the result, matching the >= 2 path below + // (which returns a new list); the input is never handed back aliased. + return new ArrayList<>(spans); + } + final List byDominance = new ArrayList<>(spans); + byDominance.sort(BY_LENGTH_THEN_PROBABILITY); + // Kept spans never overlap each other, so they form a start-sorted partition. A candidate can + // only intersect the kept span that starts at or just before it (floor) or the next kept span + // that starts within it (ceiling); checking those two is O(log n) instead of scanning every kept + // span, making the whole longest-wins pass O(n log n) rather than O(n^2). Keyed by start, the map + // also yields the result already in document order. + final TreeMap kept = new TreeMap<>(); + for (final Span candidate : byDominance) { + final Map.Entry before = kept.floorEntry(candidate.getStart()); + if (before != null && before.getValue().getEnd() > candidate.getStart()) { + continue; + } + final Map.Entry after = kept.ceilingEntry(candidate.getStart()); + if (after != null && after.getKey() < candidate.getEnd()) { + continue; + } + kept.put(candidate.getStart(), candidate); + } + return new ArrayList<>(kept.values()); } /** @@ -356,7 +460,7 @@ static List decodeSpans(String text, String[] tokens, float[][] tokenLabel continue; } - final SpanMatch match = findByRegex(text, spanText, characterStart, searchEnd); + final SpanMatch match = findInSource(text, spanText, characterStart, searchEnd); if (match.start() != -1) { spans.add(new Span(match.start(), match.end(), entityType, entity.probability())); characterStart = match.end(); @@ -487,9 +591,9 @@ static double labelProbability(float[] scores, int labelIndex) { /** * Reconstructs source-like text from a span of WordPiece tokens. * - *

Special BERT tokens are skipped, {@code ##} continuations are merged into the preceding - * surface form, and simple punctuation spacing is normalized so the result can be located in - * the caller's original text.

+ *

Special BERT and RoBERTa tokens are skipped, {@code ##} continuations are merged into the + * preceding surface form, and simple punctuation spacing is normalized so the result can be + * located in the caller's original text.

* * @param tokens The WordPiece token sequence. * @param startIndex The first token index to include. @@ -503,7 +607,7 @@ static String buildSpanText(String[] tokens, int startIndex, int endIndex) { for (int x = startIndex; x <= endIndex && x < tokens.length; x++) { final String token = tokens[x]; - if (CLS_TOKEN.equals(token) || SEPARATOR.equals(token)) { + if (SPECIAL_TOKENS.contains(token)) { continue; } @@ -567,35 +671,82 @@ private static int maxIndex(float[] arr) { /** * Locates reconstructed span text in a bounded region of the original input text. * + *

Matching is a single forward cursor scan, not a regular expression. Each space in the + * reconstructed span matches a run of zero or more Unicode whitespace characters in the source + * (so an entity whose WordPiece pieces were rejoined with spaces, such as {@code "AT & T"} for + * {@code "AT&T"}, is still located), and every other code point matches case-insensitively. + * Using a cursor avoids {@link java.util.regex.Pattern}/{@link java.util.regex.Matcher} + * allocation and the ReDoS surface of regular expressions, and recognizes Unicode whitespace + * that Java's {@code \s} does not.

+ * * @param text The original text. - * @param span The reconstructed span text. + * @param span The reconstructed span text, with sub-tokens separated by single ASCII spaces. * @param searchStart The first character offset to search from. * @param searchEnd The exclusive upper bound of the region to search. * @return The matched character offsets, or {@code (-1, -1)} when the reconstructed text * cannot be found in the requested region. */ - private static SpanMatch findByRegex(String text, String span, int searchStart, int searchEnd) { - - // Reconstructed span text normalizes whitespace, so match flexibly: a space in the span may - // map to any run of whitespace OR none in the source (e.g. punctuation/'&' inside "U.S.A", - // "AT&T" that wordpiece tokenization split apart). Use \s* rather than \s+ so such entities - // are still located instead of being silently dropped. - final String regex = Pattern.quote(span).replace(" ", "\\E\\s*\\Q"); + private static SpanMatch findInSource(String text, String span, int searchStart, int searchEnd) { - final Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE); - final Matcher matcher = pattern.matcher(text); final int regionStart = Math.min(Math.max(searchStart, 0), text.length()); final int regionEnd = Math.min(Math.max(searchEnd, regionStart), text.length()); - matcher.region(regionStart, regionEnd); - if (matcher.find()) { - return new SpanMatch(matcher.start(), matcher.end()); + int start = regionStart; + while (start < regionEnd) { + final int end = matchAt(text, span, start, regionEnd); + if (end != -1) { + return new SpanMatch(start, end); + } + start += Character.charCount(text.codePointAt(start)); } return new SpanMatch(-1, -1); } + /** + * Attempts to match {@code span} against {@code text} beginning at {@code start} and bounded by + * {@code regionEnd}. A space in {@code span} consumes a run of zero or more Unicode whitespace + * code points in the source; every other code point must match case-insensitively. + * + * @return The exclusive end offset of the match in {@code text}, or {@code -1} if no match + * begins at {@code start}. + */ + private static int matchAt(String text, String span, int start, int regionEnd) { + + int t = start; + int s = 0; + + while (s < span.length()) { + final int spanCp = span.codePointAt(s); + if (spanCp == ' ') { + while (t < regionEnd && WHITESPACE.contains(text.codePointAt(t))) { + t += Character.charCount(text.codePointAt(t)); + } + s += 1; + } else { + if (t >= regionEnd) { + return -1; + } + final int textCp = text.codePointAt(t); + if (!equalsIgnoreCase(spanCp, textCp)) { + return -1; + } + t += Character.charCount(textCp); + s += Character.charCount(spanCp); + } + } + + return t; + + } + + private static boolean equalsIgnoreCase(int a, int b) { + return a == b + || Character.toLowerCase(a) == Character.toLowerCase(b) + || Character.toUpperCase(a) == Character.toUpperCase(b); + } + private record LabelPrediction(String label, double probability) { } @@ -609,24 +760,18 @@ private record EntityPrediction(int endIndex, double probability) { private record SpanMatch(int start, int end) { } - private List tokenize(final String text) { + private List tokenize(final String text) { - final List t = new LinkedList<>(); + final List t = new LinkedList<>(); - // Segment long input text into overlapping chunks configured by InferenceOptions before - // feeding each chunk into BERT. + // Segment long input text into overlapping chunks (split on Unicode whitespace) configured by + // InferenceOptions before feeding each chunk into BERT, keeping each chunk's character span so + // its decoded spans can be bounded to the region the chunk covers. // https://medium.com/analytics-vidhya/text-classification-with-bert-using-transformers-for-long-text-inputs-f54833994dfd - final String[] whitespaceTokenized = text.split("\\s+"); - - for (ChunkRange chunkRange : chunkRanges( - whitespaceTokenized.length, documentSplitSize, splitOverlapSize)) { - - // The group is that subsection of string. - final String group = String.join(" ", - Arrays.copyOfRange(whitespaceTokenized, chunkRange.start(), chunkRange.end())); + for (final TextChunk chunk : whitespaceChunkSpans(text, documentSplitSize, splitOverlapSize)) { // Now we can tokenize the group and continue. - final String[] tokens = tokenizer.tokenize(group); + final String[] tokens = tokenizer.tokenize(chunk.text()); final long[] ids = tokenIds(tokens, vocab); @@ -636,7 +781,7 @@ private List tokenize(final String text) { final long[] types = new long[ids.length]; Arrays.fill(types, 0); - t.add(new Tokens(tokens, ids, mask, types)); + t.add(new ChunkTokens(new Tokens(tokens, ids, mask, types), chunk.start(), chunk.end())); } diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/AbstractDLChunkingTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/AbstractDLChunkingTest.java new file mode 100644 index 000000000..13957ea9b --- /dev/null +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/AbstractDLChunkingTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package opennlp.dl; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import opennlp.tools.util.Span; +import opennlp.tools.util.normalizer.AlignedText; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Model-free tests for {@link AbstractDL#whitespaceChunks(String, int, int)} and + * {@link AbstractDL#whitespaceChunkSpans(String, int, int)}, the shared tokenize-and-chunk seam + * used by both {@code NameFinderDL} and {@code DocumentCategorizerDL}. + */ +public class AbstractDLChunkingTest { + + @Test + void testSplitsOnUnicodeWhitespaceNotJustAscii() { + // A no-break space (U+00A0) and an ideographic space (U+3000) are not matched by Java's \s + // but must still separate tokens; the chunk is rejoined with single ASCII spaces. + final String nbsp = new String(Character.toChars(0x00A0)); + final String ideographic = new String(Character.toChars(0x3000)); + assertEquals(List.of("alpha beta gamma"), + AbstractDL.whitespaceChunks("alpha" + nbsp + "beta" + ideographic + "gamma", 100, 0)); + } + + @Test + void testDropsEmptyTokensFromLeadingTrailingAndRepeatedWhitespace() { + // Unlike split("\\s+"), the Unicode-aware split yields no empty leading or trailing tokens. + assertEquals(List.of("a b c"), AbstractDL.whitespaceChunks(" a b\tc ", 100, 0)); + } + + @Test + void testAppliesChunkSizeWithoutOverlap() { + assertEquals(List.of("a b", "c d"), AbstractDL.whitespaceChunks("a b c d", 2, 0)); + } + + @Test + void testAppliesChunkOverlap() { + assertEquals(List.of("a b", "b c", "c d"), AbstractDL.whitespaceChunks("a b c d", 2, 1)); + } + + @Test + void testEmptyTextYieldsNoChunks() { + assertEquals(List.of(), AbstractDL.whitespaceChunks("", 100, 0)); + } + + @Test + void testNormalizeInputIsOptInAndOffsetPreserving() { + final String nbsp = new String(Character.toChars(0x00A0)); + final String emDash = new String(Character.toChars(0x2014)); + final String input = "a" + nbsp + "b" + emDash + "c"; + + // Off by default: unchanged. + assertEquals(input, AbstractDL.normalizeInput(input, false, false)); + + // Whitespace only: the no-break space becomes a space, and the length is preserved. + final String ws = AbstractDL.normalizeInput(input, true, false); + assertEquals("a b" + emDash + "c", ws); + assertEquals(input.length(), ws.length()); + + // Dashes only: the em dash becomes an ASCII hyphen. + assertEquals("a" + nbsp + "b-c", AbstractDL.normalizeInput(input, false, true)); + + // Both. + assertEquals("a b-c", AbstractDL.normalizeInput(input, true, true)); + } + + @Test + void testNormalizeInputAlignedMapsSpanBackAcrossSupplementaryDash() { + // The aligned path used by NameFinderDL.findInOriginal: a span found in the folded text maps + // back to a tight original span even though the supplementary dash shrank by one unit. + final String yezidi = new String(Character.toChars(0x10EAD)); + final String input = "a" + yezidi + "b"; // a(0), dash(1,2), b(3) -> length 4 + final AlignedText at = AbstractDL.normalizeInputAligned(input, false, true); + assertEquals("a-b", at.normalized()); + + final Span b = at.toOriginalSpan(2, 3); // "b" in the folded text + assertEquals(3, b.getStart()); // maps back to original offset 3, not shifted to 2 + assertEquals(4, b.getEnd()); + + final Span hyphen = at.toOriginalSpan(1, 2); // the folded hyphen covers the two-unit dash + assertEquals(1, hyphen.getStart()); + assertEquals(3, hyphen.getEnd()); + } + + @Test + void testNormalizeInputAlignedIsIdentityWhenLengthPreserved() { + // Whitespace folding is length-preserving, so the alignment maps every position to itself. + final AlignedText at = AbstractDL.normalizeInputAligned("a\tb", true, false); + assertEquals("a b", at.normalized()); + for (int i = 0; i < 3; i++) { + final Span s = at.toOriginalSpan(i, i + 1); + assertEquals(i, s.getStart()); + assertEquals(i + 1, s.getEnd()); + } + } + + @Test + void testNormalizeInputAlignedComposesWhitespaceAndDash() { + // With both folds enabled the alignment is composed across stages, so a span found in the fully + // folded text still maps back to the original even though the supplementary dash shrank a unit. + final String nbsp = new String(Character.toChars(0x00A0)); + final String yezidi = new String(Character.toChars(0x10EAD)); + final String input = "a" + nbsp + "b" + yezidi + "c"; // a(0) nbsp(1) b(2) dash(3,4) c(5) + final AlignedText at = AbstractDL.normalizeInputAligned(input, true, true); + assertEquals("a b-c", at.normalized()); // a(0) space(1) b(2) hyphen(3) c(4) + + final Span c = at.toOriginalSpan(4, 5); // "c" maps past the two-unit dash + assertEquals(5, c.getStart()); + assertEquals(6, c.getEnd()); + + final Span hyphen = at.toOriginalSpan(3, 4); // folded hyphen covers the supplementary dash + assertEquals(3, hyphen.getStart()); + assertEquals(5, hyphen.getEnd()); + + final Span space = at.toOriginalSpan(1, 2); // the no-break space, length preserved + assertEquals(1, space.getStart()); + assertEquals(2, space.getEnd()); + } + + @Test + void testWhitespaceChunkSpansCarryCharacterOffsets() { + final String text = "a b c d"; // tokens a[0,1) b[2,3) c[4,5) d[6,7) + final List chunks = AbstractDL.whitespaceChunkSpans(text, 2, 1); + + assertEquals(3, chunks.size()); + assertEquals("a b", chunks.get(0).text()); + assertEquals(0, chunks.get(0).start()); + assertEquals(3, chunks.get(0).end()); + assertEquals("b c", chunks.get(1).text()); + assertEquals(2, chunks.get(1).start()); + assertEquals(5, chunks.get(1).end()); + assertEquals("c d", chunks.get(2).text()); + assertEquals(4, chunks.get(2).start()); + assertEquals(7, chunks.get(2).end()); + + // For single-ASCII-space input the reported span covers the chunk's rejoined text exactly. + for (final AbstractDL.TextChunk chunk : chunks) { + assertEquals(chunk.text(), text.substring(chunk.start(), chunk.end())); + } + } +} diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java index 80087d97c..3e8239114 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLTest.java @@ -85,6 +85,25 @@ void testCategorizeRejectsMalformedInput() { assertThrows(IllegalArgumentException.class, () -> categorizer.categorize(new String[0])); } + @Test + void testCategorizeRejectsTokenlessContent() { + // A non-empty array whose document has no tokens (empty or only whitespace, including Unicode + // whitespace) is rejected up front rather than crashing downstream on an empty score list. + final DocumentCategorizerDL categorizer = categorizerWithoutSession(); + final String nbsp = new String(Character.toChars(0x00A0)); + + assertThrows(IllegalArgumentException.class, () -> categorizer.categorize(new String[] {""})); + assertThrows(IllegalArgumentException.class, () -> categorizer.categorize(new String[] {" "})); + assertThrows(IllegalArgumentException.class, () -> categorizer.categorize(new String[] {nbsp})); + } + + @Test + void testConstructorRejectsNullInferenceOptions() { + assertThrows(NullPointerException.class, () -> + new DocumentCategorizerDL(null, null, vocab(), categories(), + new AverageClassificationScoringStrategy(), null)); + } + @Test void testTokenIdsMapsTokensToVocabularyIds() { final long[] ids = DocumentCategorizerDL.tokenIds( @@ -103,6 +122,27 @@ void testTokenIdsRejectsTokensMissingFromVocabulary() { "the error message should name the missing token: " + e.getMessage()); } + @Test + void testSoftmaxRejectsNaNLogit() { + // A NaN logit would otherwise poison the whole distribution into NaN scores; fail loudly instead. + final IllegalStateException e = assertThrows(IllegalStateException.class, () -> + DocumentCategorizerDL.softmax(new float[] {0f, Float.NaN, 0f})); + assertTrue(e.getMessage().contains("NaN"), e.getMessage()); + } + + @Test + void testSoftmaxRejectsInfiniteLogit() { + // A +Infinity logit (not NaN, so it slips past an isNaN-only guard) poisons the distribution too: + // max becomes +Inf, so value - max is Inf - Inf == NaN, every exp() is NaN, and categorize() would + // silently return all-NaN scores. It must fail loud like the NaN case. -Infinity is non-finite too. + final IllegalStateException pos = assertThrows(IllegalStateException.class, () -> + DocumentCategorizerDL.softmax(new float[] {0f, Float.POSITIVE_INFINITY, 0f})); + assertTrue(pos.getMessage().contains("non-finite") || pos.getMessage().contains("Infinity"), + pos.getMessage()); + assertThrows(IllegalStateException.class, () -> + DocumentCategorizerDL.softmax(new float[] {0f, Float.NEGATIVE_INFINITY, 0f})); + } + @Test void testSoftmaxIsUniformForEqualLogitsAndSumsToOne() { final double[] out = DocumentCategorizerDL.softmax(new float[] {0f, 0f, 0f}); @@ -146,4 +186,34 @@ void testSoftmaxMatchesReferenceDistribution() { assertEquals(0.24472847, out[1], 1e-6); assertEquals(0.66524096, out[2], 1e-6); } + + @Test + void testLogitsFromOutputDispatchesOnModelShape() { + // A 2D output (BERT-style) takes row 0; a 1D output (RoBERTa-style) is taken as-is. + assertArrayEquals(new float[] {1f, 2f}, + DocumentCategorizerDL.logitsFromOutput(new float[][] {{1f, 2f}})); + assertArrayEquals(new float[] {3f, 4f}, + DocumentCategorizerDL.logitsFromOutput(new float[] {3f, 4f})); + } + + @Test + void testLogitsFromOutputFailsLoudlyOnNullAndUnexpectedType() { + // A null or otherwise-shaped model output is a contract violation, not an "inference failed". + final IllegalStateException onNull = assertThrows(IllegalStateException.class, + () -> DocumentCategorizerDL.logitsFromOutput(null)); + assertTrue(onNull.getMessage().contains("null"), onNull.getMessage()); + assertThrows(IllegalStateException.class, + () -> DocumentCategorizerDL.logitsFromOutput("not a tensor")); + } + + @Test + void testRequireMatchingCategoryCountFailsLoudlyOnMismatch() { + // A distribution whose length differs from the configured category count means the model and + // the categorizer configuration do not match; the matching case passes the array through. + final double[] ok = {0.5, 0.5}; + assertArrayEquals(ok, DocumentCategorizerDL.requireMatchingCategoryCount(ok, 2)); + final IllegalStateException e = assertThrows(IllegalStateException.class, + () -> DocumentCategorizerDL.requireMatchingCategoryCount(new double[] {1.0}, 2)); + assertTrue(e.getMessage().contains("do not match"), e.getMessage()); + } } diff --git a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLTest.java b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLTest.java index c0a8aede2..4bf698990 100644 --- a/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLTest.java +++ b/opennlp-core/opennlp-ml/opennlp-dl/src/test/java/opennlp/dl/namefinder/NameFinderDLTest.java @@ -17,6 +17,7 @@ package opennlp.dl.namefinder; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,6 +29,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -169,7 +171,7 @@ void testDecodeSpansSearchStartLocatesNextOccurrence() { void testDecodeSpansLocatesEntityWithInternalPunctuation() { // WordPiece splits "AT&T" into separate AT / & / T tokens, so the reconstructed span text // ("AT & T") must still be located in the contiguous source. Regression guard for the - // flexible-whitespace (\s*) matching in findByRegex. + // flexible-whitespace matching in findInSource (a span space matches zero source whitespace). final String text = "Buy AT&T stock"; final String[] tokens = {"[CLS]", "Buy", "AT", "&", "T", "stock", "[SEP]"}; final float[][] scores = { @@ -184,6 +186,37 @@ void testDecodeSpansLocatesEntityWithInternalPunctuation() { assertEquals("AT&T", spans.get(0).getCoveredText(text)); } + @Test + void testDecodeSpansMatchesEntitySeparatedByNoBreakSpace() { + // The source separates "New" and "York" with a no-break space (U+00A0). Java's \s does not + // match it, so the previous regex matcher would have dropped this LOC span; the Unicode-aware + // cursor matcher locates it and the covered text includes the no-break space. + final String nbsp = new String(Character.toChars(0x00A0)); + final String text = "Visit New" + nbsp + "York today"; + final String[] tokens = {"[CLS]", "New", "York", "[SEP]"}; + final float[][] scores = {scoresFor(0), scoresFor(3), scoresFor(4), scoresFor(0)}; + + final List spans = NameFinderDL.decodeSpans(text, tokens, scores, ID_TO_LABELS); + + assertEquals(1, spans.size()); + assertEquals("LOC", spans.get(0).getType()); + assertEquals("New" + nbsp + "York", spans.get(0).getCoveredText(text)); + } + + @Test + void testDecodeSpansMatchesEntitySeparatedByIdeographicSpace() { + // Same idea with the CJK ideographic space (U+3000), another character outside Java's \s. + final String ideographic = new String(Character.toChars(0x3000)); + final String text = "from New" + ideographic + "York city"; + final String[] tokens = {"[CLS]", "New", "York", "[SEP]"}; + final float[][] scores = {scoresFor(0), scoresFor(3), scoresFor(4), scoresFor(0)}; + + final List spans = NameFinderDL.decodeSpans(text, tokens, scores, ID_TO_LABELS); + + assertEquals(1, spans.size()); + assertEquals("New" + ideographic + "York", spans.get(0).getCoveredText(text)); + } + @Test void testDecodeSpansDoesNotMatchBeyondSearchEnd() { final String text = "London was quiet. Later Paris was loud."; @@ -199,7 +232,7 @@ void testDecodeSpansDoesNotMatchBeyondSearchEnd() { @Test void testDecodeSpansMatchesSourceCaseInsensitively() { // The reconstructed span text may differ in case from the source (e.g. an uncased model); - // findByRegex matches case-insensitively, so the span is still located at the source offsets. + // findInSource matches case-insensitively, so the span is still located at the source offsets. final String text = "Visit PARIS today"; final String[] tokens = {"[CLS]", "Visit", "paris", "today", "[SEP]"}; final float[][] scores = { @@ -324,6 +357,72 @@ void testDecodeSpansEmitsRepeatedEntityAtDistinctOffsets() { assertEquals(15, spans.get(1).getEnd()); } + @Test + void testMergeOverlappingSpansKeepsLongestAndPreservesDisjoint() { + // Containment: the longer span absorbs the shorter overlapping one. + final List contained = NameFinderDL.mergeOverlappingSpans(new ArrayList<>(List.of( + new Span(0, 8, "LOC", 0.9), new Span(0, 13, "LOC", 0.8)))); + assertEquals(1, contained.size()); + assertEquals(0, contained.get(0).getStart()); + assertEquals(13, contained.get(0).getEnd()); + + // Partial overlap: the longer span wins, the shorter overlapping one is dropped. + final List partial = NameFinderDL.mergeOverlappingSpans(new ArrayList<>(List.of( + new Span(0, 5, "LOC", 0.7), new Span(3, 12, "LOC", 0.6)))); + assertEquals(1, partial.size()); + assertEquals(3, partial.get(0).getStart()); + assertEquals(12, partial.get(0).getEnd()); + + // Equal length overlap: the higher probability wins the tie. + final List tie = NameFinderDL.mergeOverlappingSpans(new ArrayList<>(List.of( + new Span(0, 5, "PER", 0.6), new Span(2, 7, "PER", 0.9)))); + assertEquals(1, tie.size()); + assertEquals(2, tie.get(0).getStart()); + assertEquals(7, tie.get(0).getEnd()); + + // Adjacent but disjoint spans are both kept and returned in document order. + final List disjoint = NameFinderDL.mergeOverlappingSpans(new ArrayList<>(List.of( + new Span(5, 10, "LOC", 0.9), new Span(0, 5, "PER", 0.9)))); + assertEquals(2, disjoint.size()); + assertEquals(0, disjoint.get(0).getStart()); + assertEquals(5, disjoint.get(1).getStart()); + } + + @Test + void testMergeOverlappingSpansReturnsAFreshListForTrivialInput() { + // The size < 2 fast path must hand back a new list, not the caller's, so the result is always + // owned by the caller -- the same ownership contract as the merging path. + final List single = new ArrayList<>(List.of(new Span(0, 5, "PER", 0.9))); + final List merged = NameFinderDL.mergeOverlappingSpans(single); + assertEquals(1, merged.size()); + assertNotSame(single, merged); + assertTrue(NameFinderDL.mergeOverlappingSpans(new ArrayList()).isEmpty()); + } + + @Test + void testChunkAssemblyKeepsFullerOverlappingSpan() { + // Mirrors how locate() decodes two overlapping chunks bounded to their own character regions. + // Chunk 1 covers up to "York" and labels "New York"; the overlapping chunk 2 covers "New York + // City" and labels it. Both candidates are produced and mergeOverlappingSpans keeps the fuller + // "New York City" rather than dropping it, which is the chunk-boundary case a single forward + // cursor mishandled. + final String text = "Alice visited New York City."; // "New York" = [14,22), "City" ends at 27 + final String[] chunk1 = {"[CLS]", "New", "York", "[SEP]"}; + final float[][] scores1 = {scoresFor(0), scoresFor(3), scoresFor(4), scoresFor(0)}; + final String[] chunk2 = {"[CLS]", "New", "York", "City", "[SEP]"}; + final float[][] scores2 = + {scoresFor(0), scoresFor(3), scoresFor(4), scoresFor(4), scoresFor(0)}; + + final List candidates = new ArrayList<>(); + candidates.addAll(NameFinderDL.decodeSpans(text, chunk1, scores1, ID_TO_LABELS, 14, 22)); + candidates.addAll(NameFinderDL.decodeSpans(text, chunk2, scores2, ID_TO_LABELS, 14, 27)); + assertEquals(2, candidates.size(), "both chunks emit a candidate for the boundary entity"); + + final List merged = NameFinderDL.mergeOverlappingSpans(candidates); + assertEquals(1, merged.size()); + assertEquals("New York City", merged.get(0).getCoveredText(text)); + } + @Test void testDecodeSpansLocatesEntityWithRegexMetacharacters() { // WordPiece splits "C++" into C / + / + tokens, so the reconstructed span text contains regex @@ -390,6 +489,14 @@ void testLabelProbabilityIgnoresNaNInDenominator() { assertBounded(p); } + @Test + void testBuildSpanTextSkipsRobertaSpecialTokens() { + // RoBERTa markers (, ) must be skipped during span reconstruction, the same way the + // BERT [CLS]/[SEP] markers are, so they never leak into a reconstructed entity span. + assertEquals("New York", + NameFinderDL.buildSpanText(new String[] {"", "New", "York", ""}, 0, 3)); + } + private static float[] scoresFor(int labelIndex) { final float[] scores = new float[ID_TO_LABELS.size()]; for (int i = 0; i < scores.length; i++) { diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java index bf7002a06..5778b7438 100644 --- a/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java +++ b/opennlp-eval-tests/src/test/java/opennlp/dl/doccat/DocumentCategorizerDLEval.java @@ -239,7 +239,7 @@ public void categorizeWithGpu() throws Exception { try (final DocumentCategorizerDL documentCategorizerDL = new DocumentCategorizerDL(model, vocab, getCategories(), new AverageClassificationScoringStrategy(), - new InferenceOptions())) { + inferenceOptions)) { final double[] result = documentCategorizerDL.categorize(new String[] {"I am happy"}); logger.debug(Arrays.toString(result)); diff --git a/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java b/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java index e19742b96..f06d03a2b 100644 --- a/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java +++ b/opennlp-eval-tests/src/test/java/opennlp/dl/namefinder/NameFinderDLEval.java @@ -407,6 +407,157 @@ public void invalidModel() { } + /** + * End-to-end offset safety: with dash normalization enabled and a non-BMP dash before an entity, + * the fold shrinks the text by one UTF-16 unit, so the entity sits at a smaller offset in the + * normalized text than in the original. {@link NameFinderDL#findInOriginal(String[])} must report + * the entity at its true offset in the original input, not the one-unit-shorter normalized offset. + */ + @Test + public void findInOriginalMapsSpansAcrossNonBmpDash() throws Exception { + + final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx"); + final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt"); + + final InferenceOptions options = new InferenceOptions(); + options.setNormalizeDashes(true); + + // Yezidi hyphen (U+10EAD): a Unicode Dash code point in the supplementary plane (two UTF-16 + // units) that folds to a single ASCII hyphen, shifting every following character left by one. + final String yezidi = new String(Character.toChars(0x10EAD)); + final String[] tokens = {yezidi, "George", "Washington", "was", "president", + "of", "the", "United", "States", "."}; + final String original = String.join(" ", tokens); + + try (final NameFinderDL nameFinderDL = + new NameFinderDL(model, vocab, getIds2Labels(), options, sentenceDetector)) { + + final Span[] spans = nameFinderDL.findInOriginal(tokens); + + Span person = null; + for (final Span span : spans) { + if ("PER".equals(span.getType())) { + person = span; + } + } + Assertions.assertNotNull(person, "the PER entity should still be detected after the dash"); + // Mapped back through the alignment, the span covers the entity in the ORIGINAL input (which + // still contains the two-unit dash); without the mapping it would be shifted left by one. + Assertions.assertEquals("George Washington", person.getCoveredText(original)); + Assertions.assertEquals(original.indexOf("George Washington"), person.getStart()); + } + + } + + /** + * End-to-end chunk-boundary safety with the real model. With a small split size and an overlap, + * "United States" straddles a chunk boundary: the first chunk ends at "United" while the + * overlapping next chunk sees the full entity. The decoder must keep the complete "United States" + * span rather than a truncated copy from whichever chunk decoded first, and the output spans must + * not overlap. + */ + @Test + public void findKeepsFullEntityAcrossChunkBoundary() throws Exception { + + final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx"); + final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt"); + + final InferenceOptions options = new InferenceOptions(); + // chunk 1 = tokens [0,7) (ends at "United"); chunk 2 = tokens [5,9) ("the United States ."). + options.setDocumentSplitSize(7); + options.setSplitOverlapSize(2); + + final String[] tokens = {"George", "Washington", "was", "president", + "of", "the", "United", "States", "."}; + final String text = String.join(" ", tokens); + + try (final NameFinderDL nameFinderDL = + new NameFinderDL(model, vocab, getIds2Labels(), options, sentenceDetector)) { + + final Span[] spans = nameFinderDL.find(tokens); + + boolean fullLocation = false; + boolean person = false; + for (final Span span : spans) { + if ("LOC".equals(span.getType()) && "United States".equals(span.getCoveredText(text))) { + fullLocation = true; + } + if ("PER".equals(span.getType()) && "George Washington".equals(span.getCoveredText(text))) { + person = true; + } + } + Assertions.assertTrue(fullLocation, + "the full 'United States' entity must survive the chunk boundary, not a truncated copy"); + Assertions.assertTrue(person, "the PER entity should still be detected"); + + assertNoOverlappingSpans(spans, text); + } + + } + + /** + * End-to-end de-duplication with the real model. With a generous overlap, "United States" sits + * entirely within the shared region of two chunks and is decoded by both. The decoder must report + * it once, not twice. + */ + @Test + public void findDeduplicatesEntityDecodedInBothOverlappingChunks() throws Exception { + + final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx"); + final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt"); + + final InferenceOptions options = new InferenceOptions(); + // chunk 1 = tokens [0,8); chunk 2 = tokens [4,9); "United States" (tokens 6,7) is in both. + options.setDocumentSplitSize(8); + options.setSplitOverlapSize(4); + + final String[] tokens = {"George", "Washington", "was", "president", + "of", "the", "United", "States", "."}; + final String text = String.join(" ", tokens); + + try (final NameFinderDL nameFinderDL = + new NameFinderDL(model, vocab, getIds2Labels(), options, sentenceDetector)) { + + final Span[] spans = nameFinderDL.find(tokens); + + int locations = 0; + for (final Span span : spans) { + if ("LOC".equals(span.getType()) && "United States".equals(span.getCoveredText(text))) { + locations++; + } + } + Assertions.assertEquals(1, locations, + "the entity in the chunk overlap must be reported once, not duplicated"); + + assertNoOverlappingSpans(spans, text); + } + + } + + // Asserts no two spans overlap in character coordinates, the post-merge invariant of locate(). + private static void assertNoOverlappingSpans(final Span[] spans, final String text) { + for (int i = 0; i < spans.length; i++) { + for (int j = i + 1; j < spans.length; j++) { + Assertions.assertFalse(spans[i].intersects(spans[j]), + "spans must not overlap after merge: " + spans[i].getCoveredText(text) + + " vs " + spans[j].getCoveredText(text)); + } + } + } + + @Test + public void findRejectsNullInput() throws Exception { + // Public entry points fail fast on a null token array rather than deeper inside String.join. + final File model = new File(getOpennlpDataDir(), "onnx/namefinder/model.onnx"); + final File vocab = new File(getOpennlpDataDir(), "onnx/namefinder/vocab.txt"); + + try (final NameFinderDL nameFinderDL = + new NameFinderDL(model, vocab, getIds2Labels(), sentenceDetector)) { + Assertions.assertThrows(NullPointerException.class, () -> nameFinderDL.find(null)); + Assertions.assertThrows(NullPointerException.class, () -> nameFinderDL.findInOriginal(null)); + } + } + private Map getIds2Labels() { final Map ids2Labels = new HashMap<>();