[
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n\n# Parallel Scaling Law for Language Models\n\n\n_Yet Another Scaling Law beyond Parameters and Inference Time Scaling_\n\n[![Paper](https://img.shields.io/badge/arXiv-2505.10475-red)](https://arxiv.org/abs/2505.10475)\n[![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-FFD21E)](https://huggingface.co/ParScale)\n\n<div align=\"center\">\n<img src=\"figures/logo.jpg\" style=\"width: 10%;\" />\n</div>\n\n\n<p align=\"center\">\n    💡&nbsp;<a href=\"#-key-findings\">Key Findings</a>\n    | 📈&nbsp;<a href=\"#-scaling-law\">Scaling Law</a>\n    | ⚡&nbsp;<a href=\"#-cost-analysis\">Cost Analysis</a>\n    | 🔥&nbsp;<a href=\"#-models\">Models</a>\n    | 📚&nbsp;<a href=\"#-citation\">Citation</a>\n</p>\n</div>\n\n## 🌟 About\n\n- It is commonly believed that scaling language models incurs a heavy cost in either **space** (parameter scaling) or **time** (inference-time scaling).\n- We introduce a *third* paradigm for scaling LLMs, which leverages **parallel computation** during both training and inference (Parallel Scaling, or *ParScale*).\n- We apply $P$ diverse and learnable transformations to the input, execute forward passes of the model in parallel, and dynamically aggregate the $P$ outputs.\n<div align=\"center\">\n<img src=\"figures/teaser.png\" style=\"width: 80%;\" />\n</div>\n\n---\n\n## 💡 Key Findings\n<div align=\"center\">\n<img src=\"figures/scaling_comparison.png\" style=\"width: 80%;\" />\n</div>\n\nHere are the core insights and benefits distilled from our theoretical analysis and empirical evaluations:\n\n📈 **Logarithmic Scaling Law**: We theoretically and empirically establish that **scaling with $P$ parallel streams is comparable to scaling the number of parameters by** $O(\\log P)$. 
This suggests that parallel computation can serve as an efficient substitute for parameter growth, especially for larger models.\n\n✅ **Universal Applicability**: Unlike inference-time scaling, which requires specialized data and has limited applicability, ParScale works with any model architecture, optimization method, data, or downstream task.\n\n🧠 **Stronger Performance on Reasoning Tasks**: Reasoning-intensive tasks (e.g., coding or math) benefit more from ParScale, which suggests that scaling computation can effectively push the boundary of reasoning.\n\n⚡ **Superior Inference Efficiency**: ParScale can use up to **22x less memory increase** and **6x less latency increase** compared to parameter scaling that achieves the same performance improvement (batch size=1).\n\n🧱 **Cost-Efficient Training via Two-Stage Strategy**: Training a parallel-scaled model doesn't require starting from scratch. With a two-stage training strategy, we can post-train the parallel components using only a small amount of data.\n\n🔁 **Dynamic Adaptation at Inference Time**: We find that ParScale remains effective with frozen main parameters for different $P$. This illustrates the potential of dynamic parallel scaling: switching $P$ to dynamically adapt model capabilities during inference.\n\nWe release the inference code in `modeling_qwen2_parscale.py` and `configuration_qwen2_parscale.py`. Our 67 checkpoints are available at [🤗 HuggingFace](https://huggingface.co/ParScale).\n\n---\n\n## 📈 Scaling Law\n\n- We carry out large-scale pre-training experiments on the Stack-V2 and Pile corpora, varying $P$ from 1 to 8 and model size from 500M to 4.4B parameters. 
\n- We use the results to fit a new *parallel scaling law* that generalizes the Chinchilla scaling law.\n- We release our parametric fitting code in `parametric_fit.py`.\n- Feel free to try the [🤗 HuggingFace Space](https://huggingface.co/spaces/ParScale/Parallel_Scaling_Law) for a nice visualization of the parallel scaling law!\n<div align=\"center\">\n<img src=\"figures/scaling_law.png\" style=\"width: 70%;\" />\n<img src=\"figures/scaling_law2.png\" style=\"width: 70%;\" />\n</div>\n\n---\n\n## ⚡ Cost Analysis\n\n<div align=\"center\">\n<img src=\"figures/cost.png\" style=\"width: 70%;\" />\n</div>\n\n- We further compare the inference efficiency of parallel scaling and parameter scaling at equivalent performance levels.\n- We release our analysis code in `cost_analysis.py`. Before using it, you should first install [llm-analysis](https://github.com/cli99/llm-analysis):\n\n```bash\ngit clone https://github.com/cli99/llm-analysis.git\ncd llm-analysis\npip install .\n```\n\n- You can use the following command to analyze the inference memory and latency costs of our 4.4B model with $P=2$ and batch size 2:\n```bash\npython cost_analysis.py --hidden_size 2560 --intermediate_size 13824 --P 2 --batch_size 2\n```\n\n---\n\n## 🔥 Models\n\nModels marked with ✨ are our recommended strong models.\n\n### Base models for scaling training data to 1T tokens\n\nThese models are highly competitive with existing small models, including SmolLM, Gemma, and Llama-3.2.\n\n|Model|Description|Download|\n|:-:|:-:|:-:|\n|ParScale-1.8B-P1|✨ Baseline $P=1$|[🤗 ParScale/ParScale-1.8B-P1](https://huggingface.co/ParScale/ParScale-1.8B-P1)|\n|ParScale-1.8B-P2|✨ ParScale $P=2$|[🤗 ParScale/ParScale-1.8B-P2](https://huggingface.co/ParScale/ParScale-1.8B-P2)|\n|ParScale-1.8B-P4|✨ ParScale $P=4$|[🤗 ParScale/ParScale-1.8B-P4](https://huggingface.co/ParScale/ParScale-1.8B-P4)|\n|ParScale-1.8B-P8|✨ ParScale $P=8$|[🤗 
ParScale/ParScale-1.8B-P8](https://huggingface.co/ParScale/ParScale-1.8B-P8)|\n\n### Instruct models for scaling training data to 1T tokens\n\nWe post-trained the aforementioned base models on SmolTalk-1M to enable conversational capabilities.\n\n|Model|Description|Download|\n|:-:|:-:|:-:|\n|ParScale-1.8B-P1-Inst|✨ Baseline $P=1$|[🤗 ParScale/ParScale-1.8B-P1-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P1-Inst)|\n|ParScale-1.8B-P2-Inst|✨ ParScale $P=2$|[🤗 ParScale/ParScale-1.8B-P2-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P2-Inst)|\n|ParScale-1.8B-P4-Inst|✨ ParScale $P=4$|[🤗 ParScale/ParScale-1.8B-P4-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P4-Inst)|\n|ParScale-1.8B-P8-Inst|✨ ParScale $P=8$|[🤗 ParScale/ParScale-1.8B-P8-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P8-Inst)|\n\n\n### Continual Pretraining of Qwen-2.5-3B\n\nWe froze the parameters of Qwen-2.5-3B and only fine-tuned the newly introduced parameters on Stack-V2-Python. Since the following models share the same backbone parameters as Qwen-2.5-3B, they have the potential for dynamic ParScale: switching $P$ to adapt model capabilities during inference.\n\n|Model|Description|Download|\n|:-:|:-:|:-:|\n|ParScale-Qwen-3B-P2-Python|✨ ParScale $P=2$|[🤗 ParScale/ParScale-Qwen-3B-P2-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P2-Python)|\n|ParScale-Qwen-3B-P4-Python|✨ ParScale $P=4$|[🤗 ParScale/ParScale-Qwen-3B-P4-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P4-Python)|\n|ParScale-Qwen-3B-P8-Python|✨ ParScale $P=8$|[🤗 ParScale/ParScale-Qwen-3B-P8-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P8-Python)|\n\n- For full continual pretraining on Stack-V2-Python:\n\n|Model|Description|Download|\n|:-:|:-:|:-:|\n|ParScale-QwenInit-3B-P1-Python|Baseline $P=1$|[🤗 ParScale/ParScale-QwenInit-3B-P1-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P1-Python)|\n|ParScale-QwenInit-3B-P2-Python|ParScale $P=2$|[🤗 
ParScale/ParScale-QwenInit-3B-P2-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P2-Python)|\n|ParScale-QwenInit-3B-P4-Python|ParScale $P=4$|[🤗 ParScale/ParScale-QwenInit-3B-P4-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P4-Python)|\n|ParScale-QwenInit-3B-P8-Python|ParScale $P=8$|[🤗 ParScale/ParScale-QwenInit-3B-P8-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P8-Python)|\n\n- For full continual pretraining on Pile:\n\n|Model|Description|Download|\n|:-:|:-:|:-:|\n|ParScale-QwenInit-3B-P1-Pile|Baseline $P=1$|[🤗 ParScale/ParScale-QwenInit-3B-P1-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P1-Pile)|\n|ParScale-QwenInit-3B-P2-Pile|ParScale $P=2$|[🤗 ParScale/ParScale-QwenInit-3B-P2-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P2-Pile)|\n|ParScale-QwenInit-3B-P4-Pile|ParScale $P=4$|[🤗 ParScale/ParScale-QwenInit-3B-P4-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P4-Pile)|\n|ParScale-QwenInit-3B-P8-Pile|ParScale $P=8$|[🤗 ParScale/ParScale-QwenInit-3B-P8-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P8-Pile)|\n\n\n### Checkpoints Used to Fit the Scaling Law\n\nDownload link: https://huggingface.co/ParScale/ParScale-{size}-{P}-{dataset}\n\n- {size}: model size, from {0.7B, 0.9B, 1.3B, 1.8B, 3B, 4.7B}\n- {P}: number of parallel streams, from {P1, P2, P4, P8}\n- {dataset}: training dataset, from {Python, Pile}\n- $6\\times 4 \\times 2=48$ checkpoints in total.\n\n### Usage Example with 🤗 Hugging Face\n\n```python\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nname = \"ParScale/ParScale-1.8B-P8\"  # or any other checkpoint listed above\n# trust_remote_code=True is required because ParScale ships its own modeling code\nmodel = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True).to(\"cuda\")\ntokenizer = AutoTokenizer.from_pretrained(name)\ninputs = tokenizer.encode(\"Hello, how are you today?\", return_tensors=\"pt\").to(\"cuda\")\noutputs = model.generate(inputs, max_new_tokens=128)[0]\nprint(tokenizer.decode(outputs))\n```\n\n\n## 📚 
Citation\n\n```bibtex\n@article{ParScale,\n      title={Parallel Scaling Law for Language Models}, \n      author={Mouxiang Chen and Binyuan Hui and Zeyu Cui and Jiaxi Yang and Dayiheng Liu and Jianling Sun and Junyang Lin and Zhongxin Liu},\n      year={2025},\n      eprint={2505.10475},\n      archivePrefix={arXiv},\n      primaryClass={cs.LG},\n      journal={arXiv preprint arXiv:2505.10475},\n      url={https://arxiv.org/abs/2505.10475}, \n}\n```\n"
  },
  {
    "path": "configuration_qwen2_parscale.py",
    "content": "\"\"\"Qwen2 model configuration, with support for ParScale\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import rope_config_validation\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen2ParScaleConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen2Model`] extended with ParScale. It is used to\n    instantiate a Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a\n    configuration with the defaults will yield a similar configuration to that of\n    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`Qwen2Model`].\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 22016):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        num_key_value_heads (`int`, *optional*, defaults to 32):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if\n            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by mean-pooling all the original heads within that group. For more details, check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it defaults to `32`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 32768):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. 
NOTE: if you apply new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. 
The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.\n        use_sliding_window (`bool`, *optional*, defaults to `False`):\n            Whether to use sliding window attention.\n        sliding_window (`int`, *optional*, defaults to 4096):\n            Sliding window attention (SWA) window size. If not specified, it defaults to `4096`.\n        max_window_layers (`int`, *optional*, defaults to 28):\n            The number of layers that use SWA (Sliding Window Attention). 
The bottom layers use SWA while the top layers use full attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        parscale_n (`int`, *optional*, defaults to 1):\n            Number of parallel streams (P) used by ParScale. With the default value of `1`, the model is a standard\n            Qwen2 model; ParScale-specific parameters and computation are only enabled when `parscale_n > 1`.\n        parscale_n_tokens (`int`, *optional*, defaults to 48):\n            Number of learnable prefix tokens per stream whose key/value states are prepended to the keys and values\n            of every attention layer when `parscale_n > 1`.\n        parscale_attn_smooth (`float`, *optional*, defaults to 0.01):\n            Smoothing coefficient for the weights that dynamically aggregate the outputs of the `parscale_n` streams.\n\n    ```python\n    >>> from transformers import Qwen2Model, Qwen2Config\n\n    >>> # Initializing a Qwen2 style configuration\n    >>> configuration = Qwen2Config()\n\n    >>> # Initializing a model from the Qwen2-7B style configuration\n    >>> model = Qwen2Model(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen2_parscale\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    # Default tensor parallel plan for base model `Qwen2`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=4096,\n        intermediate_size=22016,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        use_sliding_window=False,\n        sliding_window=4096,\n        max_window_layers=28,\n        attention_dropout=0.0,\n        parscale_n=1,\n        parscale_n_tokens=48,\n        parscale_attn_smooth=0.01,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        
self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.use_sliding_window = use_sliding_window\n        self.sliding_window = sliding_window if use_sliding_window else None\n        self.max_window_layers = max_window_layers\n        self.parscale_n = parscale_n\n        self.parscale_n_tokens = parscale_n_tokens\n        self.parscale_attn_smooth = parscale_attn_smooth\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_dropout = attention_dropout\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n        rope_config_validation(self)\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
  {
    "path": "cost_analysis.py",
    "content": "import numpy as np\nimport json\nimport os\nfrom llm_analysis.analysis import LLMAnalysis, get_gpu_config_by_name, ModelConfig, ActivationRecomputation, BYTES_FP16\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n\n    # General model config\n    parser.add_argument('--hidden_size', type=int, required=True)\n    parser.add_argument('--intermediate_size', type=int, required=True)\n    parser.add_argument('--num_hidden_layers', type=int, default=36)\n    parser.add_argument('--num_attention_heads', type=int, default=16)\n    parser.add_argument('--max_position_embeddings', type=int, default=2048)\n    parser.add_argument('--num_key_value_heads', type=int, default=2)\n    parser.add_argument('--vocab_size', type=int, default=151936)\n\n    # Parscale config\n    parser.add_argument('--P', type=int, default=1) # Number of parallel streams\n    parser.add_argument('--parscale_prefix_tokens', type=int, default=48) # Number of prefix tokens\n\n    # Data config\n    parser.add_argument('--batch_size', type=int, default=1)\n    parser.add_argument('--input_length', type=int, default=64)\n    parser.add_argument('--output_length', type=int, default=64)\n\n    # GPU config\n    parser.add_argument('--gpu_config', type=str, default=\"a100-sxm-80gb\")\n    parser.add_argument('--flops_efficiency', type=float, default=0.7) # Recommended by llm-analysis\n    parser.add_argument('--hbm_memory_efficiency', type=float, default=0.9) # Recommended by llm-analysis\n\n    args = parser.parse_args()\n    p = args.P\n    model_config = ModelConfig(\n        name=\"\", \n        num_layers=args.num_hidden_layers, \n        n_head=args.num_attention_heads, \n        hidden_dim=args.hidden_size, vocab_size=args.vocab_size, \n        max_seq_len=args.max_position_embeddings + (args.parscale_prefix_tokens if p > 1 else 0), \n        num_key_value_heads=args.num_key_value_heads, \n        ffn_embed_dim=args.intermediate_size, \n  
      mlp_gated_linear_units=True\n    )\n    gpu_config = get_gpu_config_by_name(args.gpu_config)\n    # Use an effectively unlimited per-GPU memory budget so that llm-analysis's capacity check\n    # does not reject large configurations; only its memory and latency estimates are needed here.\n    gpu_config.mem_per_GPU_in_GB = 10000\n\n    analysis = LLMAnalysis(\n        model_config,\n        gpu_config,\n        flops_efficiency=args.flops_efficiency,\n        hbm_memory_efficiency=args.hbm_memory_efficiency,\n    )\n    seq_len = args.input_length + (args.parscale_prefix_tokens if p > 1 else 0)\n    summary_dict = analysis.inference(\n        batch_size_per_gpu=args.batch_size * p,\n        seq_len=seq_len,\n        num_tokens_to_generate=args.output_length,\n    )\n\n    # Account for the extra cost of the aggregation layer, which only exists when P > 1.\n    aggregate_param = (args.hidden_size + 1) * args.hidden_size * p if p > 1 else 0\n    aggregate_param_vs_fwd_param = aggregate_param / analysis.get_num_params_per_layer_mlp()\n    aggregate_latency = aggregate_param_vs_fwd_param * analysis.get_latency_fwd_per_layer_mlp(args.batch_size, args.input_length + args.output_length) if p > 1 else 0\n    aggregate_memory = aggregate_param * analysis.dtype_config.weight_bits / 8\n\n    prefill_activation_memory_per_gpu = max(\n        # Each layer's activation memory grows by a factor of P\n        analysis.get_activation_memory_per_layer(\n            args.batch_size * p,\n            seq_len,\n            is_inference=True,\n            layernorm_dtype_bytes=BYTES_FP16,\n        ),\n        # The output-embedding activation is not multiplied across the P streams, so it is independent of P.\n        analysis.get_activation_memory_output_embedding(\n            args.batch_size, seq_len\n        )\n    )\n\n    # Because batch_size * p is passed as the batch size, llm-analysis also estimates the embedding latency at this\n    # enlarged batch size. ParScale does not increase the embedding computation, so we subtract the estimated\n    # embedding latency and add back the latency at the real batch size. 
\n    embedding_latency_estimate_for_embedding = (\n        analysis.get_latency_fwd_input_embedding(args.batch_size * p, args.input_length + args.output_length, dtype_bytes=analysis.dtype_config.embedding_bits) + \n        analysis.get_latency_fwd_output_embedding_loss(args.batch_size * p, args.input_length + args.output_length)\n    )\n    embedding_latency_real_for_embedding = (\n        analysis.get_latency_fwd_input_embedding(args.batch_size, args.input_length + args.output_length, dtype_bytes=analysis.dtype_config.embedding_bits) + \n        analysis.get_latency_fwd_output_embedding_loss(args.batch_size, args.input_length + args.output_length)\n    )\n\n    total_memory = (\n        summary_dict['kv_cache_memory_per_gpu'] + \n        summary_dict['weight_memory_per_gpu'] + \n        aggregate_memory + \n        prefill_activation_memory_per_gpu\n    )\n    total_latency = (\n        summary_dict['total_latency'] + aggregate_latency\n        - embedding_latency_estimate_for_embedding\n        + embedding_latency_real_for_embedding\n    )\n    print(f\"Memory: {total_memory / 2**30:.3f}GB; Latency: {total_latency:.3f}s\")"
  },
  {
    "path": "modeling_qwen2_parscale.py",
    "content": "\"\"\"\nThis is the inference code for ParScale, based on Qwen2. It can also be used directly to load existing Qwen2 models (`parscale_n` defaults to 1).\nAll ParScale-specific modifications are wrapped within the condition 'parscale_n > 1'.\nIf you are interested in how ParScale is implemented, search for \"parscale_n\" in this file.\n\"\"\"\n\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\nfrom einops import repeat, rearrange\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.generation import GenerationMixin\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom transformers.processing_utils import Unpack\nfrom transformers.utils import (\n    LossKwargs,\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    logging,\n    replace_return_docstrings,\n)\nfrom .configuration_qwen2_parscale import Qwen2ParScaleConfig\n\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"meta-qwen2/Qwen2-2-7b-hf\"\n_CONFIG_FOR_DOC = \"Qwen2ParScaleConfig\"\n\n\nclass Qwen2MLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, 
bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs,\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = 
attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass ParscaleCache(DynamicCache):\n    \"\"\"DynamicCache whose per-layer key/value caches start with the learnable ParScale prefix states.\n\n    `prefix_k` and `prefix_v` hold one tensor per layer, each shaped\n    (parscale_n, num_key_value_heads, n_prefix_tokens, head_dim).\n    \"\"\"\n\n    def __init__(self, prefix_k, prefix_v) -> None:\n        super().__init__()\n        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen\n        self.key_cache: List[torch.Tensor] = prefix_k\n        self.value_cache: List[torch.Tensor] = prefix_v\n        self.parscale_n = prefix_k[0].size(0)\n        self.n_prefix_tokens = prefix_k[0].size(2)\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        if self.key_cache[layer_idx].size(0) != key_states.size(0):\n            # First update of this layer: the prefix is stored once per stream, so tile it over the batch\n            # to match the (n_parscale * batch) layout of key_states before appending new tokens.\n            self.key_cache[layer_idx] = repeat(self.key_cache[layer_idx], 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.parscale_n)\n            self.value_cache[layer_idx] = repeat(self.value_cache[layer_idx], 'n_parscale ... 
-> (n_parscale b) ...', b=key_states.size(0) // self.parscale_n)\n        return super().update(key_states, value_states, layer_idx, cache_kwargs)\n\n    def get_seq_length(self, layer_idx = 0):\n        seq_len = super().get_seq_length(layer_idx)\n        if seq_len != 0:\n            seq_len -= self.n_prefix_tokens\n        return seq_len\n\n    def reorder_cache(self, beam_idx: torch.LongTensor):\n        \"\"\"Reorders the cache for beam search, given the selected beam indices.\"\"\"\n        b = self.key_cache[0].size(0) // self.parscale_n\n        beam_idx = torch.cat([beam_idx + b * i for i in range(self.parscale_n)])\n        super().reorder_cache(beam_idx)\n\nclass Qwen2Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Qwen2ParScaleConfig, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)\n        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)\n        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)\n        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)\n        if config.parscale_n > 1:\n            self.prefix_k = nn.Parameter(torch.empty((config.parscale_n, config.num_key_value_heads, config.parscale_n_tokens, self.head_dim)))\n            self.prefix_v = nn.Parameter(torch.empty((config.parscale_n, 
config.num_key_value_heads, config.parscale_n_tokens, self.head_dim)))\n\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n        \n        if self.config.parscale_n > 1:\n\n            # Expand attention mask to contain the prefix tokens\n            n_virtual_tokens = self.config.parscale_n_tokens\n\n            if attention_mask is not None:\n                attention_mask = torch.cat([\n                    torch.zeros((attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], self.config.parscale_n_tokens), dtype=attention_mask.dtype, device=attention_mask.device), \n                    attention_mask\n                ], dim=3)\n\n            if query_states.size(2) != 1:\n                query_states = 
torch.cat([torch.zeros([query_states.size(0), query_states.size(1), n_virtual_tokens, query_states.size(3)], dtype=query_states.dtype, device=query_states.device), query_states], dim=2)\n                if attention_mask is not None:\n                    attention_mask = torch.cat([\n                        torch.zeros((attention_mask.shape[0], attention_mask.shape[1], self.config.parscale_n_tokens, attention_mask.shape[3]), dtype=attention_mask.dtype, device=attention_mask.device), \n                        attention_mask\n                    ], dim=2)\n\n        sliding_window = None\n        if (\n            self.config.use_sliding_window\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and self.layer_idx >= self.config.max_window_layers\n        ):\n            sliding_window = self.config.sliding_window\n\n        attention_interface: Callable = eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n                logger.warning_once(\n                    \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to \"\n                    'eager attention. 
This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n                )\n            else:\n                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            sliding_window=sliding_window,  # main diff with Llama\n            # is_causal=True,\n            **kwargs,\n        )\n\n        if self.config.parscale_n > 1 and query_states.size(2) != 1:\n            # Remove the prefix part\n            attn_output = attn_output[:, n_virtual_tokens:]\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass Qwen2RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Qwen2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass Qwen2DecoderLayer(nn.Module):\n    def __init__(self, config: Qwen2ParScaleConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = Qwen2Attention(config=config, 
layer_idx=layer_idx)\n        self.mlp = Qwen2MLP(config)\n        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        if config.sliding_window and config._attn_implementation != \"flash_attention_2\":\n            logger.warning_once(\n                f\"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; \"\n                \"unexpected results may be encountered.\"\n            )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = 
self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        return outputs\n\n\nclass Qwen2RotaryEmbedding(nn.Module):\n    def __init__(self, config: Qwen2ParScaleConfig, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    def _dynamic_frequency_update(self, position_ids, device):\n        \"\"\"\n        dynamic RoPE layers should recompute `inv_freq` in the following situations:\n        1 - growing beyond the cached sequence length (allow scaling)\n        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)\n        \"\"\"\n        seq_len = torch.max(position_ids) + 1\n        if seq_len > self.max_seq_len_cached:  # growth\n            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)  # TODO joao: may break with compilation\n            self.max_seq_len_cached = seq_len\n\n        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset\n            # This .to() is needed if the model has been 
moved to a device after being initialized (because\n            # the buffer is automatically moved, but not the original copy)\n            self.original_inv_freq = self.original_inv_freq.to(device)\n            self.register_buffer(\"inv_freq\", self.original_inv_freq, persistent=False)\n            self.max_seq_len_cached = self.original_max_seq_len\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        if \"dynamic\" in self.rope_type:\n            self._dynamic_frequency_update(position_ids, device=x.device)\n\n        # Core RoPE block\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n\n        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention\n        cos = cos * self.attention_scaling\n        sin = sin * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nQWEN2_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`Qwen2ParScaleConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen2 Model outputting raw hidden-states without any specific head on top.\",\n    QWEN2_START_DOCSTRING,\n)\nclass Qwen2PreTrainedModel(PreTrainedModel):\n    config_class = Qwen2ParScaleConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Qwen2DecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nQWEN2_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence token in the position embeddings. 
Selected in the range `[0,\n            config.max_position_embeddings - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (keys and values in the self-attention blocks) that can be used to speed up\n            sequential decoding. This typically consists of the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance, see our\n            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, a\n            [`~cache_utils.Cache`] instance is returned: a `ParscaleCache` when `config.parscale_n > 1`, otherwise a\n            `DynamicCache`.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attention tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Unlike `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen2 Model outputting raw hidden-states without any specific head on top.\",\n    QWEN2_START_DOCSTRING,\n)\nclass Qwen2Model(Qwen2PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`]\n\n    Args:\n        config: Qwen2ParScaleConfig\n    \"\"\"\n\n    def __init__(self, config: Qwen2ParScaleConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = Qwen2RotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        self.parscale_n = config.parscale_n\n        if config.parscale_n > 1:\n            self.aggregate_layer = torch.nn.Sequential(\n                torch.nn.Linear(config.parscale_n * config.hidden_size, config.hidden_size),\n                torch.nn.SiLU(),\n                torch.nn.Linear(config.hidden_size, config.parscale_n)\n            )\n        self.parscale_aggregate_attn_smoothing = config.parscale_attn_smooth\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: 
Optional[torch.LongTensor] = None,\n        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n        \n        if self.parscale_n > 1:\n            # Input transformation: we directly copy the input n_parscale times.
\n            # The transformation is implemented through KVCache (ParscaleCache).\n            inputs_embeds = repeat(inputs_embeds, \"b s h -> (n_parscale b) s h\", n_parscale=self.parscale_n)\n            if attention_mask is not None:\n                attention_mask = repeat(attention_mask, \"b s -> (n_parscale b) s\", n_parscale=self.parscale_n)\n            if position_ids is not None:\n                position_ids = repeat(position_ids, \"b s -> (n_parscale b) s\", n_parscale=self.parscale_n)\n            \n            # The trained prefix is saved in layer.self_attn.prefix_k / layer.self_attn.prefix_v\n            # We extract them to construct ParscaleCache.\n            if past_key_values is None or past_key_values.get_seq_length() == 0:\n                past_key_values = ParscaleCache([layer.self_attn.prefix_k for layer in self.layers], [layer.self_attn.prefix_v for layer in self.layers])\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n\n        for decoder_layer in self.layers[: self.config.num_hidden_layers]:\n    
        if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    position_embeddings,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                    **flash_attn_kwargs,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        if self.parscale_n > 1:\n            # output aggregation, based on dynamic weighted sum.\n            attn = torch.unsqueeze(torch.softmax(self.aggregate_layer(\n                rearrange(hidden_states, \"(n_parscale b) s h -> b s (h n_parscale)\", n_parscale=self.parscale_n)\n            ).float(), dim=-1), dim=-1) # [b s n_parscale 1]\n            if self.parscale_aggregate_attn_smoothing != 0.0:\n                attn = attn * (1 - self.parscale_aggregate_attn_smoothing) + (self.parscale_aggregate_attn_smoothing / self.parscale_n)\n            hidden_states = torch.sum(\n                rearrange(hidden_states, \"(n_parscale b) s h -> b s n_parscale h\", 
n_parscale=self.parscale_n) * attn, \n                dim=2, keepdim=False\n            ).to(hidden_states.dtype)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        output = BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values if use_cache else None,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n        return output if return_dict else output.to_tuple()\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and (attention_mask == 0.0).any():\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_cache_shape()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(\n            attention_mask,\n            sequence_length=sequence_length,\n            target_length=target_length,\n            dtype=dtype,\n            device=device,\n            cache_position=cache_position,\n            batch_size=input_tensor.shape[0],\n        )\n\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by the memory-efficient attention path of F.scaled_dot_product_attention.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n    @staticmethod\n    def _prepare_4d_causal_attention_mask_with_cache_position(\n        attention_mask: torch.Tensor,\n        sequence_length: int,\n        target_length: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        cache_position: torch.Tensor,\n        batch_size: int,\n        **kwargs,\n    ):\n        \"\"\"\n        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n        `(batch_size, key_value_length)`, or returns the input `attention_mask` unchanged if it is already 4D.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape\n                `(batch_size, 1, query_length, key_value_length)`.\n            sequence_length (`int`):\n                The sequence length being processed.\n            target_length (`int`):\n                The target length: when generating with static cache, the mask should be as long as the static cache,\n                to account for the 0 padding, the part of the cache that is not filled yet.\n            dtype (`torch.dtype`):\n                The dtype to use for the 4D attention mask.\n            device (`torch.device`):\n                The device to place the 4D attention mask on.\n            cache_position (`torch.Tensor`):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            batch_size (`int`):\n                Batch size.\n        \"\"\"\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # In this 
case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n            causal_mask = attention_mask\n        else:\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n\n        return causal_mask\n\n\nclass KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...\n\n\nclass Qwen2ParScaleForCausalLM(Qwen2PreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Qwen2Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n 
   def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        num_logits_to_keep: int = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n            num_logits_to_keep (`int`, *optional*):\n                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all\n                `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that\n                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen2ForCausalLM\n\n        >>> model = Qwen2ForCausalLM.from_pretrained(\"meta-qwen2/Qwen2-2-7b-hf\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"meta-qwen2/Qwen2-2-7b-hf\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = outputs[0]\n        # Only compute necessary 
logits, and do not upcast them to float if we are not computing the loss\n        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen2 Model transformer with a sequence classification head on top (linear layer).\n\n    [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    QWEN2_START_DOCSTRING,\n)\nclass Qwen2ForSequenceClassification(Qwen2PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen2Model(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility\n                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                sequence_lengths = sequence_lengths % input_ids.shape[-1]\n                sequence_lengths = sequence_lengths.to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)\n\n        if not return_dict:\n            output = 
(pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    QWEN2_START_DOCSTRING,\n)\nclass Qwen2ForTokenClassification(Qwen2PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen2Model(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=TokenClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: 
Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, TokenClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = outputs[0]\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits, labels, self.config)\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n       
 )\n\n\n@add_start_docstrings(\n    \"\"\"\nThe Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like\nSQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).\n    \"\"\",\n    QWEN2_START_DOCSTRING,\n)\nclass Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):\n    base_model_prefix = \"transformer\"\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = Qwen2Model(config)\n        self.qa_outputs = nn.Linear(config.hidden_size, 2)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.transformer.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.transformer.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs,\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        r\"\"\"\n        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). 
Positions outside of the sequence\n            are not taken into account for computing the loss.\n        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence\n            are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.transformer(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        logits = self.qa_outputs(sequence_output)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1).contiguous()\n        end_logits = end_logits.squeeze(-1).contiguous()\n\n        loss = None\n        if start_positions is not None and end_positions is not None:\n            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)\n\n        if not return_dict:\n            output = (start_logits, end_logits) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "parametric_fit.py",
    "content": "import numpy as np\nfrom scipy.optimize import minimize\nfrom sklearn.metrics import r2_score\nimport pandas as pd\n\ndef parametric_fit(param_list, p_list, loss_list):\n    \"\"\"Fit the parallel scaling law to observed losses and print a LaTeX comparison table.\n\n    Predicted loss: L(N, P) = E + (A * 1e9 / (N * (k * ln(P) + 1))) ** alpha,\n    where N is the non-embedding parameter count and P the number of parallel streams.\n    The parameters (E, A, alpha, k) minimize a Huber loss between log(true loss) and log(predicted loss).\n    \"\"\"\n    param_list = np.asarray(param_list).reshape((-1, ))\n    loss_list = np.asarray(loss_list).reshape((-1, ))\n    p_list = np.asarray(p_list).reshape((-1, ))\n\n    def huber_loss(y_true, y_pred, delta=0.001):\n        error = y_true - y_pred\n        is_small_error = np.abs(error) <= delta\n        squared_loss = np.square(error) / 2\n        linear_loss = delta * (np.abs(error) - delta / 2)\n        return np.where(is_small_error, squared_loss, linear_loss).sum()\n    \n    def pred_loss(params):\n        E, A, alpha, k = params\n        return E + (A * 1e9 / (param_list * (np.log(p_list) * k + 1))) ** alpha\n\n    def objective_function(params):\n        pred = pred_loss(params)\n        return huber_loss(np.log(loss_list), np.log(pred))\n\n    # Grid of starting points for L-BFGS-B (E and A are initialized in log space);\n    # keep the run that reaches the lowest objective value.\n    best_param = None\n    best_result = None\n    best_func = 1000000\n    for E in [-1, -0.5, 0]:\n        for log_A in [-4, -2, 0, 2, 4]:\n            for alpha in [0, 0.5, 1, 1.5, 2]:\n                for k in [0.2, 0.4, 0.6, 0.8]:\n                    initial_params = [np.exp(E), np.exp(log_A), alpha, k]\n                    bounds = [(1e-8, None), (1e-8, None), (1e-8, None), (1e-8, None)]\n                    result = minimize(objective_function, initial_params, method='L-BFGS-B', bounds=bounds)\n                    if result.fun < best_func:\n                        best_param = result.x\n                        best_result = result\n                        best_func = result.fun\n    print(f\"{best_result = }\")\n    print(f\"{best_param = }\")\n    print(f\"{best_func = }\")\n\n    pred_key = \"$\\\\mathcal L_{\\\\text{pred}}$\"\n    true_key = \"$\\\\mathcal L_{\\\\text{true}}$\"\n    df = pd.DataFrame({\n        \"$P$\": p_list,\n        \"Parameters (Non-Embedding)\": param_list,\n        
pred_key: pred_loss(best_param),\n        true_key: loss_list,\n        \"Error\": pred_loss(best_param) - loss_list\n    })\n    df['Parameters (Non-Embedding)'] = df['Parameters (Non-Embedding)'].apply(lambda x: f\"{x:,}\")\n    r2 = r2_score(df[true_key].to_numpy().reshape(-1, 1), df[pred_key].to_numpy().reshape(-1, 1))\n\n    print(df.to_latex(float_format=lambda x: f\"{x:.4f}\", index=False, column_format='rrrrr'))\n    print(f\"{r2 = }\")\n\n\nif __name__ == \"__main__\":\n\n    params = [\n        [535813376, 693753856, 1088376320, 1571472384, 2774773760, 4353203200], \n        [538195842, 696738818, 1092762882, 1577522690, 2784937986, 4368529922],\n        [540577412, 699722756, 1097148164, 1583571460, 2795100164, 4383854084],\n        [545340552, 705690632, 1105918728, 1595669000, 2815424520, 4414502408],\n    ]\n\n    stack_loss = [\n        [1.1722, 1.1496, 1.1131, 1.0817, 1.0451, 1.0213], # 1.0006], # P1 \n        [1.1507, 1.1262, 1.094, 1.0623, 1.0244, 1.0025], # P2\n        [1.1354, 1.1124, 1.0808, 1.049, 1.0126, 0.9906], # P4\n        [1.1231, 1.0997, 1.0688, 1.0383, 1.0016, 0.9794], # P8\n    ]\n\n    pile_loss = [\n        [2.1113, 2.0671, 2.0027, 1.9539, 1.8876, 1.8451], # P1\n        [2.0772, 2.0363, 1.973, 1.9266, 1.861, 1.8137], # P2\n        [2.0544, 2.0128, 1.9509, 1.904, 1.8394, 1.7938], # P4\n        [2.0364, 1.9933, 1.9318, 1.8856, 1.8218, 1.7772], # P8\n    ]\n\n    p = [\n        [1] * 6,\n        [2] * 6,\n        [4] * 6,\n        [8] * 6,\n    ]\n    \n    print(\"=\" * 10 + \" Stack-V2 Python \" + \"=\" * 10)\n    parametric_fit(params, p, stack_loss)\n    print(\"=\" * 10 + \" Pile \" + \"=\" * 10)\n    parametric_fit(params, p, pile_loss)"
  }
]