Building a custom Java tokenizer

Hi,

I am trying to build a custom tokenizer for tokenizing Java code using the tokenizers library. I have a set of tokens that should not be split into subwords (for example, Java keywords, operators, separators, and common class names), but I want Java identifiers to be split into subword tokens (for example, getAge, setName, etc.).

To achieve this, I add the tokens that should not be split to the tokenizer before training, and then train the tokenizer on a large Java code file. Below is the code for this:

import json

from tokenizers import Tokenizer, decoders, models
from tokenizers.pre_tokenizers import PreTokenizer
from tokenizers.trainers import WordPieceTrainer

tokenizer = Tokenizer(models.WordPiece(unk_token='<unk>'))
tokenizer.pre_tokenizer = PreTokenizer.custom(JavaPreTokenizer())

with open(IMPORTANT_TOKENS_FILE) as fp:
    important_tokens = json.load(fp)

special_tokens = ["</s>", "<unk>", "<pad>"]
# Register the no-split tokens (plus specials) so they are kept whole
num = tokenizer.add_tokens(important_tokens + special_tokens)
# print(f"{num} tokens added")

trainer = WordPieceTrainer(vocab_size=32000, special_tokens=special_tokens)
tokenizer.decoder = decoders.WordPiece()
tokenizer.train([DATASET_FILE], trainer=trainer)
print(f"Vocab size: {tokenizer.get_vocab_size()}")

The custom pre-tokenizer is implemented using the javalang library, a Python library for parsing and tokenizing Java code. Below is the custom pre-tokenizer:

import javalang
from tokenizers import PreTokenizedString


class JavaPreTokenizer:

    def java_tokenize(self, i, normalized_string):
        string = str(normalized_string)
        # Lex the span with javalang; this raises LexerError on malformed input
        javalang_tokens = list(javalang.tokenizer.tokenize(string))
        splits = []
        original_pos = 0
        for javalang_token in javalang_tokens:
            length = len(javalang_token.value)
            # javalang does not expose absolute character offsets, so scan
            # forward to locate where this token starts in the original string
            while original_pos < len(string) and javalang_token.value != string[original_pos:original_pos + length]:
                original_pos += 1
            if original_pos >= len(string):
                raise ValueError(f'Could not find token "{javalang_token.value}" in string "{string}"')
            # Slice the NormalizedString itself so alignment offsets are preserved
            splits.append(normalized_string[original_pos:original_pos + length])
            original_pos += length
        return splits

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.java_tokenize)
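
On its own, the pre-tokenizer behaves as I expect on complete snippets. A quick check (the snippet and the expected offsets are illustrative):

pre = PreTokenizer.custom(JavaPreTokenizer())
print(pre.pre_tokenize_str("int getAge();"))
# expected: [('int', (0, 3)), ('getAge', (4, 10)), ('(', (10, 11)),
#            (')', (11, 12)), (';', (12, 13))]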

The issue I am facing is that when I try to tokenize “class Person implements Comparable<Person> {\n String name = \"Hello World!\";”, the pre-tokenizer receives only the fragment ' "Hello World' and fails with the error message: Exception: LexerError: Unterminated character/string literal at " ", line 1: "Hello Worl.

The failure itself is understandable, because javalang cannot tokenize an unterminated string literal. But why does this happen? Why is the pre-tokenizer given only part of the input?

After some debugging, I saw that the pre-tokenizer is applied only to some parts of the input, namely [' Person ', ' ', 'Person', ' ', '\n ', ' name ', ' "Hello World']. I had assumed the pre-tokenizer would be applied to the whole input.

When I checked the saved tokenizer (I saved it after wrapping it with the transformers library), I noticed that none of the parts passed to the pre-tokenizer appear in its added tokens section. It seems that only the spans that do not match any added token are passed to the pre-tokenizer.
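
To confirm this, here is a minimal sketch (LoggingPreTokenizer is a throwaway helper I wrote only for this test, and the two added tokens are illustrative) that just logs the spans handed to a custom pre-tokenizer during encode:

from tokenizers import Tokenizer, models
from tokenizers.pre_tokenizers import PreTokenizer

class LoggingPreTokenizer:
    def log_split(self, i, normalized_string):
        # Print each span the framework hands us, then leave it unsplit
        print(repr(str(normalized_string)))
        return [normalized_string]

    def pre_tokenize(self, pretok):
        pretok.split(self.log_split)

tok = Tokenizer(models.WordPiece(vocab={"<unk>": 0}, unk_token="<unk>"))
tok.add_tokens(["class", "implements"])
tok.pre_tokenizer = PreTokenizer.custom(LoggingPreTokenizer())
tok.encode("class Person implements Runnable {")
# expected to print ' Person ' and ' Runnable {': the spans matching the
# added tokens ('class', 'implements') never reach the pre-tokenizer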

How can I fix this issue?