Repository: QwenLM/ParScale
Branch: main
Commit: cd6acb48ba6d
Files: 5
Total size: 82.2 KB
Directory structure:
gitextract_brf28tsi/
├── README.md
├── configuration_qwen2_parscale.py
├── cost_analysis.py
├── modeling_qwen2_parscale.py
└── parametric_fit.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
<div align="center">
# Parallel Scaling Law for Language Models
_Yet Another Scaling Law beyond Parameters and Inference Time Scaling_
[📄 arXiv:2505.10475](https://arxiv.org/abs/2505.10475)
[🤗 HuggingFace](https://huggingface.co/ParScale)
<div align="center">
<img src="figures/logo.jpg" style="width: 10%;" />
</div>
<p align="center">
💡 <a href="#-key-findings">Key Findings</a>
| 📈 <a href="#-scaling-law">Scaling Law</a>
| ⚡ <a href="#-cost-analysis">Cost Analysis</a>
| 🔥 <a href="#-models">Models</a>
| 📚 <a href="#-citation">Citation</a>
</p>
</div>
## 🌟 About
- It is commonly believed that scaling language models requires a heavy cost in either **space** (parameter scaling) or **time** (inference-time scaling).
- We introduce the *third* scaling paradigm for LLMs, which leverages **parallel computation** during both training and inference (Parallel Scaling, or *ParScale*).
- We apply $P$ diverse and learnable transformations to the input, execute forward passes of the model in parallel, and dynamically aggregate the $P$ outputs.
<div align="center">
<img src="figures/teaser.png" style="width: 80%;" />
</div>
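The three-step recipe above (transform the input, run the shared model in parallel, aggregate) can be sketched in a few lines of NumPy. This is a toy illustration under stated assumptions and is not the released implementation. The shared model `f` is a single hypothetical tanh layer. The per-stream transformations are hypothetical additive input offsets; the released code instead learns prefix key/value tokens in every attention layer. The aggregation is a softmax-weighted sum whose weights come from a hypothetical linear scoring layer over all P outputs. Those weights are then mixed with a uniform distribution by a coefficient loosely analogous to `parscale_attn_smooth`.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def parscale_forward(x, W, offsets, score_w, smooth=0.01):
    """Toy ParScale forward pass (all components are hypothetical stand-ins).

    x:       (batch, d) input features
    W:       (d, d) weights of the shared toy model f(h) = tanh(h @ W)
    offsets: (P, d) hypothetical learnable per-stream input transformations
    score_w: (P * d, P) hypothetical aggregation scoring layer
    smooth:  mixes the aggregation weights with a uniform distribution
    """
    P = offsets.shape[0]
    # 1) P diverse transformations of the same input: (P, batch, d)
    streams = x[None, :, :] + offsets[:, None, :]
    # 2) One shared model applied to every stream (can run as one big batch)
    outs = np.tanh(streams @ W)                       # (P, batch, d)
    # 3) Dynamic aggregation: the weights depend on all P outputs
    concat = np.concatenate(list(outs), axis=-1)      # (batch, P * d)
    weights = softmax(concat @ score_w, axis=-1)      # (batch, P)
    weights = (1 - smooth) * weights + smooth / P     # every stream keeps some weight
    return np.einsum("bp,pbd->bd", weights, outs), weights

rng = np.random.default_rng(0)
d, P, batch = 8, 4, 3
x = rng.normal(size=(batch, d))
y, w = parscale_forward(x, rng.normal(size=(d, d)) / np.sqrt(d),
                        rng.normal(size=(P, d)), rng.normal(size=(P * d, P)) * 0.1)
print(y.shape, w.sum(axis=1))  # output shape (3, 8); each row of weights sums to 1
```

The P streams share `W`, so they can be stacked and evaluated as one larger batch; this is where the parallelism comes from. Only the per-stream transformations and the aggregation layer add parameters.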
---
## 💡 Key Findings
<div align="center">
<img src="figures/scaling_comparison.png" style="width: 80%;" />
</div>
Here are the core insights and benefits distilled from our theoretical analysis and empirical evaluations:
📈 **Logarithmic Scaling Law**: We theoretically and empirically establish that **scaling with $P$ parallel streams is comparable to scaling the number of parameters by** $O(\log P)$. This suggests that parallel computation can serve as an efficient substitute for parameter growth, especially for larger models.
✅ **Universal Applicability**: Unlike inference-time scaling, which requires specialized data and has limited applicability, ParScale works with any model architecture, optimization method, data, or downstream task.
🧠 **Stronger Performance on Reasoning Tasks**: Reasoning-intensive tasks (e.g., coding or math) benefit more from ParScale, which suggests that scaling computation can effectively push the boundary of reasoning.
⚡ **Superior Inference Efficiency**: Compared with parameter scaling that achieves the same performance improvement, ParScale incurs up to **22x less memory increase** and **6x less latency increase** (batch size = 1).
🧱 **Cost-Efficient Training via Two-Stage Strategy**: Training a parallel-scaled model doesn't require starting from scratch. With a two-stage training strategy, we can post-train the parallel components using only a small amount of data.
🔁 **Dynamic Adaptation at Inference Time**: We find that ParScale remains effective with frozen main parameters for different $P$. This illustrates the potential of dynamic parallel scaling: switching $P$ to dynamically adapt model capabilities during inference.
We release the inference code in `modeling_qwen2_parscale.py` and `configuration_qwen2_parscale.py`. Our 67 checkpoints are available at [🤗 HuggingFace](https://huggingface.co/ParScale).
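For a sense of how few parameters the per-stream transformations add: when `parscale_n > 1`, every attention layer in `modeling_qwen2_parscale.py` holds two learnable tensors, `prefix_k` and `prefix_v`. Each has shape `(parscale_n, num_key_value_heads, parscale_n_tokens, head_dim)`. The helper below is hypothetical and not part of the repository. It counts only these prefix parameters and leaves out the aggregation layer. It assumes `head_dim = hidden_size // num_attention_heads`, which is the fallback the code uses when the config sets no `head_dim`. The example values are the layer and head defaults that `cost_analysis.py` uses for the 4.4B model.

```python
def parscale_prefix_params(P, num_layers, num_kv_heads, n_tokens, hidden_size, num_heads):
    """Hypothetical helper: count prefix_k/prefix_v parameters added by ParScale."""
    if P <= 1:
        return 0  # parscale_n = 1 is a plain Qwen2 model without prefixes
    # Assumes the config sets no explicit head_dim
    head_dim = hidden_size // num_heads
    per_tensor = P * num_kv_heads * n_tokens * head_dim
    return 2 * per_tensor * num_layers  # prefix_k + prefix_v in every layer

for P in (1, 2, 4, 8):
    n = parscale_prefix_params(P, num_layers=36, num_kv_heads=2, n_tokens=48,
                               hidden_size=2560, num_heads=16)
    print(f"P={P}: {n:,} prefix parameters")
# P=1: 0, P=2: 2,211,840, P=4: 4,423,680, P=8: 8,847,360
```

The count grows linearly in P but stays in the low millions, which is small compared with the backbone.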
---
## 📈 Scaling Law
- We carry out large-scale pre-training experiments on the Stack-V2 and Pile corpora, varying $P$ from 1 to 8 and model size from 500M to 4.4B parameters.
- We use the results to fit a new *parallel scaling law* that generalizes the Chinchilla scaling law.
- We release our parametric fitting code in `parametric_fit.py`.
- Feel free to try the [🤗 HuggingFace Space](https://huggingface.co/spaces/ParScale/Parallel_Scaling_Law) for a nice visualization of the parallel scaling law!
<div align="center">
<img src="figures/scaling_law.png" style="width: 70%;" />
<img src="figures/scaling_law2.png" style="width: 70%;" />
</div>
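For intuition, here is a sketch of a Chinchilla-style loss curve in which P parallel streams act like multiplying the parameter count N by a factor that grows logarithmically in P. This is consistent with the O(log P) finding above. The functional form `L(N, P) = E + (A / (N * (k * log P + 1)))**alpha` is an assumption for illustration. Every constant below is a placeholder, not a fitted value. `parametric_fit.py` is the reference for the actual fit.

```python
import math

def parallel_loss(N, P, A=1.0e9, alpha=0.2, E=1.7, k=0.4):
    """Hypothetical parallel scaling law; A, alpha, E, k are illustrative placeholders."""
    # Assumed form: P streams multiply the parameter count by (k * log P + 1)
    effective_N = N * (k * math.log(P) + 1)
    # Chinchilla-style power law in the effective parameter count (P = 1 is plain parameter scaling)
    return E + (A / effective_N) ** alpha

N = 1.8e9
for P in (1, 2, 4, 8):
    print(f"P={P}: loss={parallel_loss(N, P):.4f}")
```

Under this form, each doubling of P adds the same amount, k·log 2, to the multiplier, so the gains from additional streams diminish as P grows.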
---
## ⚡ Cost Analysis
<div align="center">
<img src="figures/cost.png" style="width: 70%;" />
</div>
- We further compare the inference efficiency between parallel scaling and parameter scaling at equivalent performance levels.
- We release our analysis code in `cost_analysis.py`. Before using it, you should first install [llm-analysis](https://github.com/cli99/llm-analysis):
```bash
git clone https://github.com/cli99/llm-analysis.git
cd llm-analysis
pip install .
```
- You can use the following command to analyze the inference memory and latency cost for our 4.4B model, with $P=2$ and batch size=2:
```bash
python cost_analysis.py --hidden_size 2560 --intermediate_size 13824 --P 2 --batch_size 2
```
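`cost_analysis.py` does not model ParScale inside llm-analysis directly. Instead, it rescales the inputs. The P streams are treated as a batch of `batch_size * P`. The prompt grows by `parscale_prefix_tokens` when P > 1. The aggregation layer is counted as `(hidden_size + 1) * hidden_size * P` parameters. The hypothetical helper below mirrors that bookkeeping so the numbers can be checked by hand. The 2-byte weight size is an assumption (16-bit weights).

```python
def parscale_cost_inputs(batch_size, input_length, hidden_size, P, prefix_tokens=48, weight_bytes=2):
    """Hypothetical helper mirroring the input adjustments made in cost_analysis.py."""
    # The P streams are folded into the batch dimension
    effective_batch = batch_size * P
    # The prompt is extended by the learnable prefix tokens (only when P > 1)
    seq_len = input_length + (prefix_tokens if P > 1 else 0)
    # Aggregation layer size, as counted by cost_analysis.py
    aggregate_params = (hidden_size + 1) * hidden_size * P if P > 1 else 0
    # weight_bytes=2 assumes 16-bit weights
    aggregate_mib = aggregate_params * weight_bytes / 2**20
    return {"batch": effective_batch, "seq_len": seq_len,
            "aggregate_params": aggregate_params, "aggregate_MiB": aggregate_mib}

# Settings of the command above (4.4B model, P=2, batch size 2, default input_length=64)
print(parscale_cost_inputs(batch_size=2, input_length=64, hidden_size=2560, P=2))
# batch 4, seq_len 112, 13,112,320 aggregation parameters (about 25 MiB in 16-bit)
```

Folding the streams into the batch dimension lets llm-analysis account for the extra KV cache and activations. The script then subtracts the embedding latency that this overcounts, because the P streams do not multiply the embedding computation.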
---
## 🔥 Models
Models marked with ✨ are our recommendations for strong models!
### Base models for scaling training data to 1T tokens
These models demonstrate strong competitiveness among existing small models, including SmolLM, Gemma, and Llama-3.2.
|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-1.8B-P1|✨ Baseline $P=1$|[🤗 ParScale/ParScale-1.8B-P1](https://huggingface.co/ParScale/ParScale-1.8B-P1)|
|ParScale-1.8B-P2|✨ ParScale $P=2$|[🤗 ParScale/ParScale-1.8B-P2](https://huggingface.co/ParScale/ParScale-1.8B-P2)|
|ParScale-1.8B-P4|✨ ParScale $P=4$|[🤗 ParScale/ParScale-1.8B-P4](https://huggingface.co/ParScale/ParScale-1.8B-P4)|
|ParScale-1.8B-P8|✨ ParScale $P=8$|[🤗 ParScale/ParScale-1.8B-P8](https://huggingface.co/ParScale/ParScale-1.8B-P8)|
### Instruct models for scaling training data to 1T tokens
We post-trained the aforementioned base models on SmolTalk-1M to enable conversational capabilities.
|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-1.8B-P1-Inst|✨ Baseline $P=1$|[🤗 ParScale/ParScale-1.8B-P1-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P1-Inst)|
|ParScale-1.8B-P2-Inst|✨ ParScale $P=2$|[🤗 ParScale/ParScale-1.8B-P2-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P2-Inst)|
|ParScale-1.8B-P4-Inst|✨ ParScale $P=4$|[🤗 ParScale/ParScale-1.8B-P4-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P4-Inst)|
|ParScale-1.8B-P8-Inst|✨ ParScale $P=8$|[🤗 ParScale/ParScale-1.8B-P8-Inst](https://huggingface.co/ParScale/ParScale-1.8B-P8-Inst)|
### Continual Pretraining of Qwen-2.5-3B
We froze the parameters of Qwen-2.5-3B and only fine-tuned the newly introduced parameters on Stack-V2-Python. Since the following models share the same backbone parameters as Qwen-2.5-3B, they have the potential for dynamic ParScale: switching P to adapt model capabilities during inference.
|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-Qwen-3B-P2-Python|✨ ParScale $P=2$|[🤗 ParScale/ParScale-Qwen-3B-P2-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P2-Python)|
|ParScale-Qwen-3B-P4-Python|✨ ParScale $P=4$|[🤗 ParScale/ParScale-Qwen-3B-P4-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P4-Python)|
|ParScale-Qwen-3B-P8-Python|✨ ParScale $P=8$|[🤗 ParScale/ParScale-Qwen-3B-P8-Python](https://huggingface.co/ParScale/ParScale-Qwen-3B-P8-Python)|
- For full continual pretraining on Stack-V2-Python
|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-QwenInit-3B-P1-Python|Baseline $P=1$|[🤗 ParScale/ParScale-QwenInit-3B-P1-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P1-Python)|
|ParScale-QwenInit-3B-P2-Python|ParScale $P=2$|[🤗 ParScale/ParScale-QwenInit-3B-P2-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P2-Python)|
|ParScale-QwenInit-3B-P4-Python|ParScale $P=4$|[🤗 ParScale/ParScale-QwenInit-3B-P4-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P4-Python)|
|ParScale-QwenInit-3B-P8-Python|ParScale $P=8$|[🤗 ParScale/ParScale-QwenInit-3B-P8-Python](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P8-Python)|
- For full continual pretraining on Pile
|Model|Description|Download|
|:-:|:-:|:-:|
|ParScale-QwenInit-3B-P1-Pile|Baseline $P=1$|[🤗 ParScale/ParScale-QwenInit-3B-P1-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P1-Pile)|
|ParScale-QwenInit-3B-P2-Pile|ParScale $P=2$|[🤗 ParScale/ParScale-QwenInit-3B-P2-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P2-Pile)|
|ParScale-QwenInit-3B-P4-Pile|ParScale $P=4$|[🤗 ParScale/ParScale-QwenInit-3B-P4-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P4-Pile)|
|ParScale-QwenInit-3B-P8-Pile|ParScale $P=8$|[🤗 ParScale/ParScale-QwenInit-3B-P8-Pile](https://huggingface.co/ParScale/ParScale-QwenInit-3B-P8-Pile)|
### Checkpoints Used to Fit the Scaling Law
Download link: https://huggingface.co/ParScale/ParScale-{size}-{P}-{dataset}
- {size}: model size, from {0.7B, 0.9B, 1.3B, 1.8B, 3B, 4.7B}
- {P}: number of parallel streams, from {P1, P2, P4, P8}
- {dataset}: training dataset, from {Python, Pile}
- $6\times 4 \times 2=48$ checkpoints in total.
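The naming template above can be expanded programmatically, for example to script downloads. This sketch only builds the repository IDs from the template. It assumes that every combination exists on the Hub, as stated above, and it does not contact the Hub. The helper name is hypothetical.

```python
from itertools import product

SIZES = ["0.7B", "0.9B", "1.3B", "1.8B", "3B", "4.7B"]
PS = ["P1", "P2", "P4", "P8"]
DATASETS = ["Python", "Pile"]

def scaling_law_checkpoints():
    """Hypothetical helper: expand ParScale-{size}-{P}-{dataset} into Hub repository IDs."""
    return [f"ParScale/ParScale-{size}-{p}-{ds}" for size, p, ds in product(SIZES, PS, DATASETS)]

repos = scaling_law_checkpoints()
print(len(repos))  # 48
print(repos[0])    # ParScale/ParScale-0.7B-P1-Python
```

Each ID can then be passed to `AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)`, as in the usage example.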
### Usage Example with 🤗 Hugging Face
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
name = "ParScale/ParScale-1.8B-P8" # or any other checkpoint listed above
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(name)
inputs = tokenizer.encode("Hello, how are you today?", return_tensors="pt").to("cuda")
outputs = model.generate(inputs, max_new_tokens=128)[0]
print(tokenizer.decode(outputs))
```
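When `parscale_n > 1`, `generate()` runs the P streams as one enlarged batch. Judging from `ParscaleCache` in `modeling_qwen2_parscale.py`, the layout is stream-major, `(n_parscale b)`: rows `i*B` to `(i+1)*B - 1` belong to stream `i`. This is why `reorder_cache` offsets the beam indices by `B * i` for every stream. The NumPy sketch below reproduces these two index manipulations on toy arrays. The function names are hypothetical, and `np.repeat` stands in for the `einops.repeat` call used in the code.

```python
import numpy as np

def expand_prefix(prefix, batch):
    """Hypothetical helper: broadcast per-stream prefixes (P, ...) to the layout (P * batch, ...).

    Same row order as einops.repeat(prefix, 'n_parscale ... -> (n_parscale b) ...', b=batch).
    """
    return np.repeat(prefix, batch, axis=0)

def expand_beam_idx(beam_idx, batch, P):
    """Hypothetical helper: apply one beam reordering to every stream, as reorder_cache does."""
    return np.concatenate([beam_idx + batch * i for i in range(P)])

P, B = 2, 3
prefix = np.arange(P)[:, None]  # the stream id stands in for each stream's prefix
print(expand_prefix(prefix, B).ravel())              # [0 0 0 1 1 1]
print(expand_beam_idx(np.array([2, 0, 1]), B, P))    # [2 0 1 5 3 4]
```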
## 📚 Citation
```bibtex
@article{ParScale,
title={Parallel Scaling Law for Language Models},
author={Mouxiang Chen and Binyuan Hui and Zeyu Cui and Jiaxi Yang and Dayiheng Liu and Jianling Sun and Junyang Lin and Zhongxin Liu},
year={2025},
eprint={2505.10475},
archivePrefix={arXiv},
primaryClass={cs.LG},
journal={arXiv preprint arXiv:2505.10475},
url={https://arxiv.org/abs/2505.10475},
}
```
================================================
FILE: configuration_qwen2_parscale.py
================================================
"""Qwen2 model configuration, with support for ParScale"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Qwen2ParScaleConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`Qwen2Model`].
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22016):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 32):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of bottom layers that use full attention. When `use_sliding_window=True`, layers with index
            `>= max_window_layers` use SWA (Sliding Window Attention), while the layers below them use full attention.
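        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        parscale_n (`int`, *optional*, defaults to 1):
            Number of parallel streams P. With the default of 1, the model behaves as a standard Qwen2 model.
        parscale_n_tokens (`int`, *optional*, defaults to 48):
            Number of learnable prefix tokens per stream, stored as per-layer key/value prefixes. Only used when
            `parscale_n > 1`.
        parscale_attn_smooth (`float`, *optional*, defaults to 0.01):
            Smoothing coefficient for the weights that aggregate the outputs of the parallel streams.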
    ```python
    >>> from configuration_qwen2_parscale import Qwen2ParScaleConfig

    >>> # Initializing a ParScale configuration with P=4 parallel streams
    >>> configuration = Qwen2ParScaleConfig(parscale_n=4)

    >>> configuration.parscale_n
    4
    ```"""
model_type = "qwen2_parscale"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
parscale_n=1,
parscale_n_tokens=48,
parscale_attn_smooth=0.01,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if use_sliding_window else None
self.max_window_layers = max_window_layers
self.parscale_n = parscale_n
self.parscale_n_tokens = parscale_n_tokens
self.parscale_attn_smooth = parscale_attn_smooth
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
================================================
FILE: cost_analysis.py
================================================
import numpy as np
import json
import os
from llm_analysis.analysis import LLMAnalysis, get_gpu_config_by_name, ModelConfig, ActivationRecomputation, BYTES_FP16
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
# General model config
parser.add_argument('--hidden_size', type=int, required=True)
parser.add_argument('--intermediate_size', type=int, required=True)
parser.add_argument('--num_hidden_layers', type=int, default=36)
parser.add_argument('--num_attention_heads', type=int, default=16)
parser.add_argument('--max_position_embeddings', type=int, default=2048)
parser.add_argument('--num_key_value_heads', type=int, default=2)
parser.add_argument('--vocab_size', type=int, default=151936)
# Parscale config
parser.add_argument('--P', type=int, default=1) # Number of parallel streams
parser.add_argument('--parscale_prefix_tokens', type=int, default=48) # Number of prefix tokens
# Data config
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--input_length', type=int, default=64)
parser.add_argument('--output_length', type=int, default=64)
# GPU config
parser.add_argument('--gpu_config', type=str, default="a100-sxm-80gb")
parser.add_argument('--flops_efficiency', type=float, default=0.7) # Recommended by llm-analysis
parser.add_argument('--hbm_memory_efficiency', type=float, default=0.9) # Recommended by llm-analysis
args = parser.parse_args()
p = args.P
model_config = ModelConfig(
name="",
num_layers=args.num_hidden_layers,
n_head=args.num_attention_heads,
hidden_dim=args.hidden_size, vocab_size=args.vocab_size,
max_seq_len=args.max_position_embeddings + (args.parscale_prefix_tokens if p > 1 else 0),
num_key_value_heads=args.num_key_value_heads,
ffn_embed_dim=args.intermediate_size,
mlp_gated_linear_units=True
)
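    # Honor the GPU and efficiency arguments instead of hard-coding their default values
    gpu_config = get_gpu_config_by_name(args.gpu_config)
    # Lift the per-GPU memory cap so that memory limits do not constrain the analysis
    gpu_config.mem_per_GPU_in_GB = 10000
    analysis = LLMAnalysis(
        model_config,
        gpu_config,
        flops_efficiency=args.flops_efficiency,
        hbm_memory_efficiency=args.hbm_memory_efficiency,
    )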
seq_len = args.input_length + (args.parscale_prefix_tokens if p > 1 else 0)
summary_dict = analysis.inference(
batch_size_per_gpu=args.batch_size * p,
seq_len=seq_len,
num_tokens_to_generate=args.output_length,
)
# We consider the influence of the aggregation layer.
aggregate_param = (args.hidden_size + 1) * args.hidden_size * p if p > 1 else 0
aggregate_param_vs_fwd_param = aggregate_param / analysis.get_num_params_per_layer_mlp()
aggregate_latency = aggregate_param_vs_fwd_param * analysis.get_latency_fwd_per_layer_mlp(args.batch_size, args.input_length + args.output_length) if p > 1 else 0
aggregate_memory = aggregate_param * analysis.dtype_config.weight_bits / 8
prefill_activation_memory_per_gpu = max(
# Each layer's activation memory will increase by P times
analysis.get_activation_memory_per_layer(
args.batch_size * p,
seq_len,
is_inference=True,
layernorm_dtype_bytes=BYTES_FP16,
),
        # The output embedding's activation memory is not replicated across the P streams, so it is independent of P.
analysis.get_activation_memory_output_embedding(
args.batch_size, seq_len
)
)
    # Because the P streams are modeled as a batch of batch_size * p, llm-analysis also estimates the embedding
    # latency at that enlarged batch size. ParScale does not increase the embedding computation, so we subtract
    # that estimate and add back the embedding latency at the real batch size.
embedding_latency_estimate_for_embedding = (
analysis.get_latency_fwd_input_embedding(args.batch_size * p, args.input_length + args.output_length, dtype_bytes=analysis.dtype_config.embedding_bits) +
analysis.get_latency_fwd_output_embedding_loss(args.batch_size * p, args.input_length + args.output_length)
)
embedding_latency_real_for_embedding = (
analysis.get_latency_fwd_input_embedding(args.batch_size, args.input_length + args.output_length, dtype_bytes=analysis.dtype_config.embedding_bits) +
analysis.get_latency_fwd_output_embedding_loss(args.batch_size, args.input_length + args.output_length)
)
total_memory = (
summary_dict['kv_cache_memory_per_gpu'] +
summary_dict['weight_memory_per_gpu'] +
aggregate_memory +
prefill_activation_memory_per_gpu
)
total_latency = (
summary_dict['total_latency'] + aggregate_latency
- embedding_latency_estimate_for_embedding
+ embedding_latency_real_for_embedding
)
print(f"Memory: {total_memory / 2**30:.3f}GB; Latency: {total_latency:.3f}s")
================================================
FILE: modeling_qwen2_parscale.py
================================================
"""
This is the inference code for ParScale, based on Qwen2. It can also load existing Qwen2 models directly (parscale_n = 1 by default).
All ParScale-specific modifications are wrapped within the condition 'parscale_n > 1'.
If you are interested in how ParScale is implemented, search for "parscale_n" in this file.
"""
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import nn
from einops import repeat, rearrange
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import (
LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_qwen2_parscale import Qwen2ParScaleConfig
from typing import Any, Dict, List, Optional, Tuple, Union
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
_CONFIG_FOR_DOC = "Qwen2ParScaleConfig"
class Qwen2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class ParscaleCache(DynamicCache):
def __init__(self, prefix_k, prefix_v) -> None:
super().__init__()
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self.key_cache: List[torch.Tensor] = prefix_k
self.value_cache: List[torch.Tensor] = prefix_v
self.parscale_n = prefix_k[0].size(0)
self.n_prefix_tokens = prefix_k[0].size(2)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.key_cache[layer_idx].size(0) != key_states.size(0):
            # First update for this layer: broadcast each stream's prefix over the batch (stream-major layout)
self.key_cache[layer_idx] = repeat(self.key_cache[layer_idx], 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.parscale_n)
self.value_cache[layer_idx] = repeat(self.value_cache[layer_idx], 'n_parscale ... -> (n_parscale b) ...', b=key_states.size(0) // self.parscale_n)
return super().update(key_states, value_states, layer_idx, cache_kwargs)
def get_seq_length(self, layer_idx = 0):
seq_len = super().get_seq_length(layer_idx)
if seq_len != 0:
seq_len -= self.n_prefix_tokens
return seq_len
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
b = self.key_cache[0].size(0) // self.parscale_n
beam_idx = torch.cat([beam_idx + b * i for i in range(self.parscale_n)])
super().reorder_cache(beam_idx)
class Qwen2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Qwen2ParScaleConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
if config.parscale_n > 1:
self.prefix_k = nn.Parameter(torch.empty((config.parscale_n, config.num_key_value_heads, config.parscale_n_tokens, self.head_dim)))
self.prefix_v = nn.Parameter(torch.empty((config.parscale_n, config.num_key_value_heads, config.parscale_n_tokens, self.head_dim)))
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if self.config.parscale_n > 1:
# Expand attention mask to contain the prefix tokens
n_virtual_tokens = self.config.parscale_n_tokens
if attention_mask is not None:
attention_mask = torch.cat([
torch.zeros((attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], self.config.parscale_n_tokens), dtype=attention_mask.dtype, device=attention_mask.device),
attention_mask
], dim=3)
if query_states.size(2) != 1:
query_states = torch.cat([torch.zeros([query_states.size(0), query_states.size(1), n_virtual_tokens, query_states.size(3)], dtype=query_states.dtype, device=query_states.device), query_states], dim=2)
if attention_mask is not None:
attention_mask = torch.cat([
torch.zeros((attention_mask.shape[0], attention_mask.shape[1], self.config.parscale_n_tokens, attention_mask.shape[3]), dtype=attention_mask.dtype, device=attention_mask.device),
attention_mask
], dim=2)
sliding_window = None
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=sliding_window, # main diff with Llama
# is_causal=True,
**kwargs,
)
if self.config.parscale_n > 1 and query_states.size(2) != 1:
# Drop the outputs of the dummy prefix query rows prepended above (prefill only)
attn_output = attn_output[:, n_virtual_tokens:]
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
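A minimal pure-Python sketch of what `forward` computes for one hidden vector, for
illustration only (`rms_norm_reference` is a hypothetical helper): scale by the
reciprocal root-mean-square, then by the learned `weight`. Unlike LayerNorm, there is
no mean subtraction and no bias.

```python
import math


def rms_norm_reference(x, weight, eps=1e-6):
    # Mean of squares over the hidden dimension (no centering, unlike LayerNorm).
    variance = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    # Normalize, then apply the learned per-channel scale.
    return [w * v * inv_rms for v, w in zip(x, weight)]


normalized = rms_norm_reference([3.0, 4.0], [1.0, 1.0], eps=0.0)
```

The module itself upcasts to float32 for the normalization and casts back to the
input dtype before multiplying by `weight`.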
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen2DecoderLayer(nn.Module):
def __init__(self, config: Qwen2ParScaleConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.sliding_window and config._attn_implementation != "flash_attention_2":
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2ParScaleConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
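For illustration, the decision logic restated as a hypothetical pure-Python helper
(`dynamic_rope_action` is not part of this module):

```python
def dynamic_rope_action(seq_len, max_seq_len_cached, original_max_seq_len):
    # 1 - growth: the request is longer than anything cached so far, so
    #     inv_freq is recomputed for the new length.
    if seq_len > max_seq_len_cached:
        return 'grow'
    # 2 - reset: a short request after a previous growth restores the
    #     original inv_freq to avoid losing precision.
    if seq_len < original_max_seq_len and max_seq_len_cached > original_max_seq_len:
        return 'reset'
    return 'keep'
```

Because `max_seq_len_cached` never drops below `original_max_seq_len`, the two
branches cannot both fire for the same call.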
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
QWEN2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Qwen2ParScaleConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
QWEN2_START_DOCSTRING,
)
class Qwen2PreTrainedModel(PreTrainedModel):
config_class = Qwen2ParScaleConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
QWEN2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Unlike `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
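When `cache_position` and `position_ids` are both omitted, `Qwen2Model.forward` derives
them from the number of tokens already in the cache. A pure-Python sketch of those
defaults, plus one common way to build `position_ids` for left-padded batches
(cumulative sum of the mask minus one); both helper names are hypothetical:

```python
from itertools import accumulate


def default_positions(past_seen_tokens, num_new_tokens):
    # cache_position: absolute cache slots of the new tokens (an arange in the model).
    cache_position = list(range(past_seen_tokens, past_seen_tokens + num_new_tokens))
    # position_ids falls back to the same values with a leading batch dimension of 1.
    position_ids = [cache_position]
    return cache_position, position_ids


def position_ids_from_mask(attention_mask):
    # Count only real tokens; padded slots get a placeholder value of 1.
    return [
        total - 1 if keep == 1 else 1
        for total, keep in zip(accumulate(attention_mask), attention_mask)
    ]
```

`cache_position` ignores padding, whereas the mask-derived `position_ids` start at 0
on the first real token of each row.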
"""
@add_start_docstrings(
"The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
QWEN2_START_DOCSTRING,
)
class Qwen2Model(Qwen2PreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
Args:
config: Qwen2ParScaleConfig
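When `config.parscale_n > 1`, the inputs are replicated into `parscale_n` streams
stacked along the batch axis (stream-major, as in the einops pattern `(n_parscale b)`),
each stream sees its own learned prefix in every attention layer, and the hidden
states after the final norm are merged by a token-wise softmax-weighted sum. A
pure-Python sketch of the replication layout and of that merge for a single token;
the helper names are hypothetical, and the scores that `aggregate_layer`
(Linear -> SiLU -> Linear over the concatenated streams) would produce are passed in
directly:

```python
import math


def replicate_streams(batch, parscale_n):
    # Same layout as repeat(x, 'b ... -> (n_parscale b) ...'): all of stream 0,
    # then all of stream 1, and so on, so stream k of example j sits at k * b + j.
    return [example for _ in range(parscale_n) for example in batch]


def aggregate_streams(stream_states, scores, smoothing=0.0):
    # stream_states: one hidden vector per stream for the same token.
    # scores: one unnormalized score per stream (the aggregate_layer output).
    parscale_n = len(stream_states)
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]  # numerically stable softmax
    total = sum(exps)
    weights = [e / total for e in exps]
    if smoothing != 0.0:
        # Blend with the uniform distribution, as parscale_attn_smooth does.
        weights = [w * (1 - smoothing) + smoothing / parscale_n for w in weights]
    hidden_size = len(stream_states[0])
    merged = [
        sum(w * state[i] for w, state in zip(weights, stream_states))
        for i in range(hidden_size)
    ]
    return merged, weights
```

In the model the weighted sum runs over the `b s n_parscale h` view of the hidden
states, so the output batch size returns to the original `b`.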
"""
def __init__(self, config: Qwen2ParScaleConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.parscale_n = config.parscale_n
if config.parscale_n > 1:
self.aggregate_layer = torch.nn.Sequential(
torch.nn.Linear(config.parscale_n * config.hidden_size, config.hidden_size),
torch.nn.SiLU(),
torch.nn.Linear(config.hidden_size, config.parscale_n)
)
self.parscale_aggregate_attn_smoothing = config.parscale_attn_smooth
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if self.parscale_n > 1:
# Input transformation: replicate the input parscale_n times along the batch axis.
# The stream-specific transformation comes from the learned prefixes stored in ParscaleCache.
inputs_embeds = repeat(inputs_embeds, "b s h -> (n_parscale b) s h", n_parscale=self.parscale_n)
if attention_mask is not None:
attention_mask = repeat(attention_mask, "b s -> (n_parscale b) s", n_parscale=self.parscale_n)
if position_ids is not None:
position_ids = repeat(position_ids, "b s -> (n_parscale b) s", n_parscale=self.parscale_n)
# The trained prefix is saved in layer.self_attn.prefix_k / layer.self_attn.prefix_v
# We extract them to construct ParscaleCache.
if past_key_values is None or past_key_values.get_seq_length() == 0:
past_key_values = ParscaleCache([layer.self_attn.prefix_k for layer in self.layers], [layer.self_attn.prefix_v for layer in self.layers])
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if self.parscale_n > 1:
# Output aggregation: a dynamic, token-wise softmax-weighted sum over the parscale_n streams.
attn = torch.unsqueeze(torch.softmax(self.aggregate_layer(
rearrange(hidden_states, "(n_parscale b) s h -> b s (h n_parscale)", n_parscale=self.parscale_n)
).float(), dim=-1), dim=-1) # [b s n_parscale 1]
if self.parscale_aggregate_attn_smoothing != 0.0:
attn = attn * (1 - self.parscale_aggregate_attn_smoothing) + (self.parscale_aggregate_attn_smoothing / self.parscale_n)
hidden_states = torch.sum(
rearrange(hidden_states, "(n_parscale b) s h -> b s n_parscale h", n_parscale=self.parscale_n) * attn,
dim=2, keepdim=False
).to(hidden_states.dtype)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
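A pure-Python sketch of the 2D (query x key) slice this function builds, for
illustration: query `i` sits at absolute position `cache_position[i]` and may not see
keys at later slots, and keys whose 2D `attention_mask` entry is 0 are blocked as well.
It assumes `cache_position[i] >= i`, as holds for the contiguous positions the model
creates (then `triu(diagonal=1)` adds nothing beyond the position check).
`causal_mask_2d` and `MIN_FLOAT32` are hypothetical names; the real code uses
`torch.finfo(dtype).min`:

```python
MIN_FLOAT32 = -3.4028234663852886e38  # torch.finfo(torch.float32).min


def causal_mask_2d(cache_position, target_length, attention_mask=None):
    # 0.0 = may attend, MIN_FLOAT32 = blocked (the mask is added to attention scores).
    mask = [
        [MIN_FLOAT32 if key > query_pos else 0.0 for key in range(target_length)]
        for query_pos in cache_position
    ]
    if attention_mask is not None:
        # Block padded key slots (attention_mask == 0) in every query row.
        for row in mask:
            for key, keep in enumerate(attention_mask):
                if keep == 0:
                    row[key] = MIN_FLOAT32
    return mask
```

With left padding the first query rows can end up fully blocked, which is why
`_update_causal_mask` calls `AttentionMaskConverter._unmask_unattended` for SDPA on CUDA.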
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Qwen2ParScaleForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
self.model = Qwen2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
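The special case follows from Python slicing: `-0` equals `0`, so `[-0:]` keeps the
whole sequence. A minimal sketch on a plain list (`keep_last_positions` is a
hypothetical helper mirroring `hidden_states[:, -num_logits_to_keep:, :]` along the
sequence axis):

```python
def keep_last_positions(sequence, num_logits_to_keep):
    # -0 == 0, so num_logits_to_keep=0 selects every position.
    return sequence[-num_logits_to_keep:]
```
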
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
>>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Decoder outputs consist of (last_hidden_state, past_key_values, hidden_states, attentions)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
The Qwen2 Model transformer with a sequence classification head on top (linear layer).
[`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
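A pure-Python sketch of the index lookup done in `forward` when `input_ids` are given
(`last_token_index` is a hypothetical helper): take the position of the first padding
token minus one, wrapped modulo the sequence length. When a row contains no padding,
`argmax` returns 0, so the wrap turns `-1` into the last index. This assumes the
padding token never occurs inside the real content.

```python
def last_token_index(input_ids, pad_token_id):
    # Position of the first pad token; 0 when there is none (like argmax on all-False).
    first_pad = next((i for i, tok in enumerate(input_ids) if tok == pad_token_id), 0)
    # Modulo instead of negative indexing, as the model does for ONNX compatibility.
    return (first_pad - 1) % len(input_ids)
```

With left padding the first pad sits at index 0, so the result is also the last position.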
""",
QWEN2_START_DOCSTRING,
)
class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Qwen2Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)


@add_start_docstrings(
    """
    The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    QWEN2_START_DOCSTRING,
)
class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen2Model(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    QWEN2_START_DOCSTRING,
)
class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = Qwen2Model(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
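In `Qwen2ForSequenceClassification.forward` above, the position whose logits are pooled is computed as `torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1`, followed by `% input_ids.shape[-1]`. The sketch below mirrors that index arithmetic in NumPy so it can be checked without loading a model; it is an illustration, not the model's own code path, and `last_non_pad_index`, `pool_logits`, and the pad id `0` are hypothetical names and values chosen for the example. It shows why the modulo matters: a row with no padding makes `argmax` return 0, the `- 1` turns that into -1, and the modulo maps -1 to the last position without relying on negative indexing. One caveat the sketch shares with the original: if the pad id also occurs as a genuine token, its first occurrence is the one that counts.

```python
import numpy as np


def last_non_pad_index(input_ids: np.ndarray, pad_token_id: int) -> np.ndarray:
    """NumPy mirror of the torch index arithmetic used for sequence-classification pooling."""
    seq_len = input_ids.shape[-1]
    # argmax over a 0/1 mask returns the position of the FIRST pad token, or 0 if there is none
    first_pad = (input_ids == pad_token_id).astype(np.int64).argmax(-1)
    # Step back one position; a row without padding yields -1 here ...
    idx = first_pad - 1
    # ... which the modulo wraps to seq_len - 1, i.e. the last position
    return idx % seq_len


def pool_logits(logits: np.ndarray, input_ids: np.ndarray, pad_token_id: int) -> np.ndarray:
    """Select one logit vector per row; logits has shape (batch, seq_len, num_labels)."""
    batch_size = logits.shape[0]
    idx = last_non_pad_index(input_ids, pad_token_id)
    return logits[np.arange(batch_size), idx]


if __name__ == "__main__":
    PAD = 0  # hypothetical pad token id for this example
    ids = np.array([
        [5, 6, 7, PAD, PAD],  # right-padded: last real token is at index 2
        [5, 6, 7, 8, 9],      # no padding: -1 wraps to index 4
        [PAD, PAD, 7, 8, 9],  # left-padded: first pad at 0, so -1 wraps to index 4
    ])
    print(last_non_pad_index(ids, PAD))
    logits = np.arange(3 * 5 * 2).reshape(3, 5, 2)
    print(pool_logits(logits, ids, PAD).shape)
```

For left-padded batches the wrap-around happens to select the last (real) token as well, which is why the same expression works for both padding sides as long as the pad id never appears inside the real text.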
================================================
FILE: parametric_fit.py
================================================
import numpy as np
from scipy.optimize import minimize
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
import json
import os
import pandas as pd


def parametric_fit(param_list, p_list, loss_list):
    param_list = np.asarray(param_list).reshape((-1, ))
    loss_list = np.asarray(loss_list).reshape((-1, ))
    p_list = np.asarray(p_list).reshape((-1, ))

    # Huber loss summed over all (N, P) points; less sensitive to outlying runs than squared error
    def huber_loss(y_true, y_pred, delta=0.001):
        error = y_true - y_pred
        is_small_error = np.abs(error) <= delta
        squared_loss = np.square(error) / 2
        linear_loss = delta * (np.abs(error) - delta / 2)
        return np.where(is_small_error, squared_loss, linear_loss).sum()

    # Parametric form: L(N, P) = E + (A * 1e9 / (N * (k * ln P + 1))) ** alpha,
    # i.e. P parallel streams act like multiplying the parameter count N by (k * ln P + 1)
    def pred_loss(params):
        E, A, alpha, k = params
        return E + (A * 1e9 / (param_list * (np.log(p_list) * k + 1))) ** alpha

    # The fit is performed on log-loss values
    def objective_function(params):
        pred = pred_loss(params)
        return huber_loss(np.log(loss_list), np.log(pred))

    # Grid of initializations for L-BFGS-B; keep the best local optimum found
    best_param = None
    best_func = 1000000
    for E in [-1, -0.5, 0]:
        for log_A in [-4, -2, 0, 2, 4]:
            for alpha in [0, 0.5, 1, 1.5, 2]:
                for k in [0.2, 0.4, 0.6, 0.8]:
                    initial_params = [np.exp(E), np.exp(log_A), alpha, k]
                    bounds = [(1e-8, None), (1e-8, None), (1e-8, None), (1e-8, None)]
                    result = minimize(objective_function, initial_params, method='L-BFGS-B', bounds=bounds)
                    if result.fun < best_func:
                        best_param = result.x
                        best_func = result.fun
                        print(f"{result = }")
    print(f"{best_param = }")
    print(f"{best_func = }")

    pred_key = "$\\mathcal L_{\\text{pred}}$"
    true_key = "$\\mathcal L_{\\text{true}}$"
    df = pd.DataFrame({
        "$P$": p_list,
        "Parameters (Non-Embedding)": param_list,
        pred_key: pred_loss(best_param),
        true_key: loss_list,
        "Error": pred_loss(best_param) - loss_list
    })
    df['Parameters (Non-Embedding)'] = df['Parameters (Non-Embedding)'].apply(lambda x: f"{x:,}")
    r2 = r2_score(df[true_key].to_numpy().reshape(-1, 1), df[pred_key].to_numpy().reshape(-1, 1))
    print(df.to_latex(float_format=lambda x: f"{x:.4f}", index=False, column_format='rrrrr'))
    print(f"{r2 = }")


if __name__ == "__main__":
    params = [
        [535813376, 693753856, 1088376320, 1571472384, 2774773760, 4353203200],
        [538195842, 696738818, 1092762882, 1577522690, 2784937986, 4368529922],
        [540577412, 699722756, 1097148164, 1583571460, 2795100164, 4383854084],
        [545340552, 705690632, 1105918728, 1595669000, 2815424520, 4414502408],
    ]

    stack_loss = [
        [1.1722, 1.1496, 1.1131, 1.0817, 1.0451, 1.0213],  # 1.0006], # P1
        [1.1507, 1.1262, 1.094, 1.0623, 1.0244, 1.0025],  # P2
        [1.1354, 1.1124, 1.0808, 1.049, 1.0126, 0.9906],  # P4
        [1.1231, 1.0997, 1.0688, 1.0383, 1.0016, 0.9794],  # P8
    ]

    pile_loss = [
        [2.1113, 2.0671, 2.0027, 1.9539, 1.8876, 1.8451],  # P1
        [2.0772, 2.0363, 1.973, 1.9266, 1.861, 1.8137],  # P2
        [2.0544, 2.0128, 1.9509, 1.904, 1.8394, 1.7938],  # P4
        [2.0364, 1.9933, 1.9318, 1.8856, 1.8218, 1.7772],  # P8
    ]

    p = [
        [1] * 6,
        [2] * 6,
        [4] * 6,
        [8] * 6,
    ]

    print("=" * 10 + " Stack-V2 Python " + "=" * 10)
    parametric_fit(params, p, stack_loss)
    print("=" * 10 + " Pile " + "=" * 10)
    parametric_fit(params, p, pile_loss)
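`pred_loss` inside `parametric_fit()` fits `L(N, P) = E + (A * 1e9 / (N * (k * ln P + 1))) ** alpha`, where `N` is the non-embedding parameter count and `P` the number of parallel streams. Read this way, `P` streams act like multiplying `N` by `(k * ln P + 1)`, so every doubling of `P` adds the same `k * ln 2` to that multiplier, which is the O(log P) equivalence stated in the README. The sketch below evaluates this form for a single `(N, P)` pair once coefficients are known. `effective_params` and `predicted_loss` are hypothetical helper names, and the coefficients in the `__main__` block are placeholders for illustration, not values reported by the authors.

```python
import math


def effective_params(n_params: float, p: int, k: float) -> float:
    """N * (k * ln P + 1): the parameter count a P=1 model would need under the fitted form."""
    return n_params * (k * math.log(p) + 1.0)


def predicted_loss(n_params: float, p: int, E: float, A: float, alpha: float, k: float) -> float:
    """Same expression as pred_loss() in parametric_fit.py, for a single (N, P) pair."""
    return E + (A * 1e9 / effective_params(n_params, p, k)) ** alpha


if __name__ == "__main__":
    # Hypothetical coefficients for illustration only; NOT fitted values from the paper.
    E, A, alpha, k = 0.5, 1.0, 0.3, 0.4
    N = 1.6e9
    for p in (1, 2, 4, 8):
        multiplier = effective_params(N, p, k) / N
        loss = predicted_loss(N, p, E, A, alpha, k)
        print(f"P={p}: multiplier={multiplier:.3f}, predicted loss={loss:.4f}")
```

Because the multiplier grows with `ln P` rather than with `P`, each additional doubling of the stream count buys the same fixed increase in effective parameters, so the returns diminish when measured per stream.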