Skip to content

Chat Template

Chat template utilities for parsing and managing conversation templates.

Chat Template Parser

Utilities for parsing and managing chat templates, including support for different model formats and conversation structures.

Features

  • Parse various chat template formats
  • Handle different conversation structures
  • Support for custom template parsing

rllm.parser.chat_template_parser

ChatTemplateParser

Source code in rllm/parser/chat_template_parser.py
class ChatTemplateParser:
    """Default chat-template parser backed by ``tokenizer.apply_chat_template``.

    Model-specific parsers are returned by :meth:`get_parser` for known model
    families; this class is the fallback. The tokenization helpers render
    messages one at a time via :meth:`parse` and concatenate the results, so
    they rely on per-message rendering matching whole-conversation rendering
    (see :meth:`verify_equivalence`).
    """

    def __init__(self, tokenizer):
        """Store ``tokenizer`` and precompute its generation prompt.

        Args:
            tokenizer: Object exposing ``apply_chat_template``, ``encode`` and
                ``name_or_path`` (presumably a Hugging Face tokenizer).
        """
        self.tokenizer = tokenizer
        self.generation_prompt = self._get_generation_prompt(tokenizer)

    def _get_generation_prompt(self, tokenizer):
        """Return the text the template appends when ``add_generation_prompt=True``.

        Renders a single empty assistant message with and without the
        generation prompt and returns the suffix by which they differ.
        """
        messages = [{"role": "assistant", "content": ""}]

        with_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        without_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)

        # NOTE(review): assumes ``without_prompt`` is a prefix of ``with_prompt``;
        # for a template that violates this, the slice yields the wrong suffix.
        return with_prompt[len(without_prompt) :]

    def parse(self, messages, add_generation_prompt=False, is_first_msg=False, **kwargs) -> str:
        """Render ``messages`` to text using the tokenizer's chat template.

        ``is_first_msg`` and extra keyword arguments (e.g. ``accumulate_reasoning``)
        are accepted so callers can use the same call signature for every parser;
        they are ignored by this default implementation.
        """
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)

    def verify_equivalence(self, messages, verbose=True):
        """Verify that parsing messages together is equivalent to parsing them individually.

        Args:
            messages (list): List of message dictionaries to test
            verbose (bool): If True, print both renderings and raise on failure;
                if False, just return the result.

        Returns:
            bool: True if the equivalence check passes, False otherwise

        Raises:
            AssertionError: If the equivalence check fails and verbose is True
        """
        # Parse all messages together
        batch_result = self.parse(messages)

        # Parse each message individually and concatenate
        concatenated_result = "".join(self.parse([message]) for message in messages)

        # Check if results are equivalent
        is_equivalent = batch_result == concatenated_result

        if verbose and not is_equivalent:
            print("Equivalence check failed!")
            print("Batch parsing result:")
            print(batch_result)
            print("\nConcatenated individual parsing result:")
            print(concatenated_result)
            raise AssertionError("Parser failed equivalence check. See above for details.")

        return is_equivalent

    @classmethod
    def get_parser(cls, tokenizer, disable_thinking=False) -> "ChatTemplateParser":
        """Pick a parser for ``tokenizer`` based on its model name and class.

        Matching uses the lower-cased ``tokenizer.name_or_path`` (when it is a
        string) and the lower-cased tokenizer class name:

        * "deepseek"/"deepscaler"/"deepcoder" in the name with a Llama tokenizer
          class -> ``DeepseekQwenChatTemplateParser``
        * "qwen"/"r2e"/"deepswe" in the name, or a Qwen tokenizer class
          -> ``QwenChatTemplateParser``
        * "llama" in the name -> ``LlamaChatTemplateParser``
        * otherwise -> the default ``ChatTemplateParser``, after it passes
          :meth:`verify_equivalence` on ``PARSER_TEST_MESSAGES``.

        Args:
            tokenizer: The tokenizer to use with the parser.
            disable_thinking: Whether the generation prompt should disable
                thinking. Only forwarded to ``QwenChatTemplateParser``.

        Returns:
            ChatTemplateParser: An instance of the selected parser.

        Raises:
            AssertionError: If the fallback default parser fails the
                equivalence check.
        """
        # Determine parser type based on tokenizer name or path
        if isinstance(tokenizer.name_or_path, str):
            model_name = tokenizer.name_or_path.lower()
            tokenizer_cls = tokenizer.__class__.__name__.lower()
            logger.info(f"model_name: {model_name}, tokenizer_cls: {tokenizer_cls}")
            if any(x in model_name for x in ("deepseek", "deepscaler", "deepcoder")) and "llama" in tokenizer_cls:
                logger.info(f"Using DeepseekQwenChatTemplateParser for {tokenizer.name_or_path}")
                return DeepseekQwenChatTemplateParser(tokenizer)
            elif "qwen" in model_name or "r2e" in model_name or "deepswe" in model_name or "qwen" in tokenizer_cls:
                logger.info(f"Using QwenChatTemplateParser for {tokenizer.name_or_path}")
                return QwenChatTemplateParser(tokenizer, disable_thinking=disable_thinking)
            elif "llama" in model_name:
                logger.info(f"Using LlamaChatTemplateParser for {tokenizer.name_or_path}")
                return LlamaChatTemplateParser(tokenizer)

        # Default to the standard parser if no specific match
        parser = ChatTemplateParser(tokenizer)
        logger.info(f"No custom parser found. Using default ChatTemplateParser for {tokenizer.name_or_path}")
        # Explicit check instead of ``assert`` so validation still runs under ``python -O``.
        if not parser.verify_equivalence(PARSER_TEST_MESSAGES):
            raise AssertionError("Parser failed equivalence check")
        return parser

    @staticmethod
    def _to_long_tensors(*sequences):
        """Convert each list of ints to a 1-D ``torch.long`` tensor."""
        return tuple(torch.tensor(seq, dtype=torch.long) for seq in sequences)

    def tokenize_and_mask(self, messages):
        """Tokenize a conversation, training only on the final assistant message.

        Everything before the last assistant message is rendered as the prompt,
        ending with the generation prompt. The last assistant message becomes
        the response, with the generation-prompt prefix and any trailing
        newlines removed. Messages after the last assistant message are ignored.

        Returns:
            tuple: ``(prompt_ids, response_ids, response_mask)`` as ``torch.long``
            tensors; ``response_mask`` is all ones.

        Raises:
            ValueError: If ``messages`` contains no assistant message.
        """
        try:
            last_assistant_idx = max(i for i, msg in enumerate(messages) if msg["role"] == "assistant")
        except ValueError:
            raise ValueError("No assistant message found in chat_completions") from None

        prompt = self.parse(messages[:last_assistant_idx], is_first_msg=True, add_generation_prompt=True, accumulate_reasoning=False)
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        response = self.parse([messages[last_assistant_idx]], is_first_msg=False, add_generation_prompt=False, accumulate_reasoning=True)
        response = response[len(self.generation_prompt) :].rstrip("\n")  # handle qwen trailing newline from eot token
        response_ids = self.tokenizer.encode(response, add_special_tokens=False)
        response_mask = [1] * len(response_ids)

        return self._to_long_tensors(prompt_ids, response_ids, response_mask)

    def tokenize_and_mask_cumulative(self, messages):
        """Tokenize a multi-turn conversation, training on every assistant message.

        Messages before the first assistant message form the prompt. From the
        first assistant message onward, each message is rendered on its own and
        appended to the response. Assistant tokens get mask 1; all other tokens
        get mask 0.

        A non-assistant message is followed by the generation prompt only when
        the next message is an assistant message, or when it is the last
        message.

        Returns:
            tuple: ``(prompt_ids, response_ids, response_mask)`` as ``torch.long``
            tensors; ``response_mask`` has the same length as ``response_ids``.

        Raises:
            ValueError: If ``messages`` contains no assistant message.
        """
        response_ids = []
        response_mask = []

        try:
            first_assistant_idx = next(i for i, msg in enumerate(messages) if msg["role"] == "assistant")
        except StopIteration:
            raise ValueError("No assistant message found in chat_completions") from None

        prompt = self.parse(messages[:first_assistant_idx], is_first_msg=True, add_generation_prompt=True, accumulate_reasoning=False)
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        last_idx = len(messages) - 1
        for i in range(first_assistant_idx, len(messages)):
            message = messages[i]
            if message["role"] == "assistant":
                response = self.parse([message], is_first_msg=False, add_generation_prompt=False, accumulate_reasoning=True)
                # The assistant rendering starts with the generation prompt, which
                # was already emitted after the preceding message; drop it here.
                response = response[len(self.generation_prompt) :]
                ids = self.tokenizer.encode(response, add_special_tokens=False)
                response_ids.extend(ids)
                response_mask.extend([1] * len(ids))
            else:
                # Only cue the assistant when an assistant turn actually follows
                # (or the conversation ends here). Otherwise back-to-back
                # non-assistant turns (e.g. tool then user) would each get a
                # stray generation prompt. That prompt does not appear when the
                # whole conversation is rendered at once.
                add_prompt = i == last_idx or messages[i + 1]["role"] == "assistant"
                response = self.parse([message], is_first_msg=False, add_generation_prompt=add_prompt, accumulate_reasoning=False)
                ids = self.tokenizer.encode(response, add_special_tokens=False)
                response_ids.extend(ids)
                response_mask.extend([0] * len(ids))

        return self._to_long_tensors(prompt_ids, response_ids, response_mask)

verify_equivalence

verify_equivalence(messages, verbose=True)

Verify that parsing messages together is equivalent to parsing them individually.

Parameters:

Name Type Description Default
messages list

List of message dictionaries to test

required
verbose bool

Whether to print both renderings and raise an AssertionError when the check fails

True

Returns:

Name Type Description
bool

True if the equivalence check passes, False otherwise

Raises:

Type Description
AssertionError

If the equivalence check fails and verbose is True

Source code in rllm/parser/chat_template_parser.py
def verify_equivalence(self, messages, verbose=True):
    """Check that rendering a conversation whole equals rendering it turn by turn.

    The messages are rendered once as a batch via ``self.parse(messages)``.
    They are then rendered one at a time, and those pieces are concatenated.
    The check passes when the two strings are identical.

    Args:
        messages (list): Message dictionaries to render.
        verbose (bool): When True, a mismatch prints both renderings and
            raises; when False, a mismatch only yields ``False``.

    Returns:
        bool: Whether the batch and piecewise renderings match.

    Raises:
        AssertionError: On a mismatch when ``verbose`` is True.
    """
    whole = self.parse(messages)
    piecewise = "".join(self.parse([single]) for single in messages)
    matches = whole == piecewise

    if not matches and verbose:
        for line in (
            "Equivalence check failed!",
            "Batch parsing result:",
            whole,
            "\nConcatenated individual parsing result:",
            piecewise,
        ):
            print(line)
        raise AssertionError("Parser failed equivalence check. See above for details.")

    return matches

get_parser classmethod

get_parser(tokenizer, disable_thinking=False) -> ChatTemplateParser

Factory method that selects the appropriate parser based on the tokenizer's name_or_path and class name.

Parameters:

Name Type Description Default
parser_type str

Not part of the actual signature; the parser type is inferred from the tokenizer's name_or_path and class name.

required
tokenizer

The tokenizer to use with the parser

required
disable_thinking

Whether the generation prompt should disable thinking (only honored by QwenChatTemplateParser).

False

Returns:

Name Type Description
ChatTemplateParser ChatTemplateParser

An instance of the requested parser

Raises:

Type Description
AssertionError

If the fallback default parser fails the equivalence check

Source code in rllm/parser/chat_template_parser.py
@classmethod
def get_parser(cls, tokenizer, disable_thinking=False) -> "ChatTemplateParser":
    """Pick a parser for ``tokenizer`` based on its model name and class.

    Matching uses the lower-cased ``tokenizer.name_or_path`` (when it is a
    string) and the lower-cased tokenizer class name:

    * "deepseek"/"deepscaler"/"deepcoder" in the name with a Llama tokenizer
      class -> ``DeepseekQwenChatTemplateParser``
    * "qwen"/"r2e"/"deepswe" in the name, or a Qwen tokenizer class
      -> ``QwenChatTemplateParser``
    * "llama" in the name -> ``LlamaChatTemplateParser``
    * otherwise -> the default ``ChatTemplateParser``, after it passes
      ``verify_equivalence`` on ``PARSER_TEST_MESSAGES``.

    Args:
        tokenizer: The tokenizer to use with the parser.
        disable_thinking: Whether the generation prompt should disable
            thinking. Only forwarded to ``QwenChatTemplateParser``.

    Returns:
        ChatTemplateParser: An instance of the selected parser.

    Raises:
        AssertionError: If the fallback default parser fails the
            equivalence check.
    """
    # Determine parser type based on tokenizer name or path
    if isinstance(tokenizer.name_or_path, str):
        model_name = tokenizer.name_or_path.lower()
        tokenizer_cls = tokenizer.__class__.__name__.lower()
        logger.info(f"model_name: {model_name}, tokenizer_cls: {tokenizer_cls}")
        if any(x in model_name for x in ("deepseek", "deepscaler", "deepcoder")) and "llama" in tokenizer_cls:
            logger.info(f"Using DeepseekQwenChatTemplateParser for {tokenizer.name_or_path}")
            return DeepseekQwenChatTemplateParser(tokenizer)
        elif "qwen" in model_name or "r2e" in model_name or "deepswe" in model_name or "qwen" in tokenizer_cls:
            logger.info(f"Using QwenChatTemplateParser for {tokenizer.name_or_path}")
            return QwenChatTemplateParser(tokenizer, disable_thinking=disable_thinking)
        elif "llama" in model_name:
            logger.info(f"Using LlamaChatTemplateParser for {tokenizer.name_or_path}")
            return LlamaChatTemplateParser(tokenizer)

    # Default to the standard parser if no specific match
    parser = ChatTemplateParser(tokenizer)
    logger.info(f"No custom parser found. Using default ChatTemplateParser for {tokenizer.name_or_path}")
    # Explicit check instead of ``assert`` so validation still runs under ``python -O``.
    if not parser.verify_equivalence(PARSER_TEST_MESSAGES):
        raise AssertionError("Parser failed equivalence check")
    return parser