Repository: qiqiApink/MotionGPT
Branch: main
Commit: a1c939b34b8f
Files: 72
Total size: 2.6 MB
Directory structure:
gitextract_f4zctl15/
├── README.md
├── dataloader/
│ ├── eval_loader.py
│ ├── tokenizer_loader.py
│ └── vqvae_loader.py
├── environment.yml
├── eval.py
├── eval_vqvae.py
├── finetune_motion.py
├── generate.py
├── generate_batch.py
├── generate_motion.py
├── index.html
├── lit_llama/
│ ├── __init__.py
│ ├── adapter.py
│ ├── indexed_dataset.py
│ ├── lora.py
│ ├── model.py
│ ├── quantization.py
│ ├── tokenizer.py
│ └── utils.py
├── models/
│ ├── encdec.py
│ ├── evaluator_wrapper.py
│ ├── modules.py
│ ├── quantize_cnn.py
│ ├── resnet.py
│ ├── rotation2xyz.py
│ ├── smpl.py
│ └── vqvae.py
├── options/
│ ├── get_eval_option.py
│ ├── option.py
│ └── option_vqvae.py
├── prepare/
│ ├── download_evaluators.sh
│ ├── download_glove.sh
│ ├── download_lora.sh
│ ├── download_smpl.sh
│ └── download_vqvae.sh
├── scripts/
│ ├── convert_checkpoint.py
│ ├── convert_hf_checkpoint.py
│ ├── download.py
│ ├── generate_dataset.py
│ ├── prepare_data.py
│ └── prepare_motion.py
├── sitemap.xml
├── static/
│ ├── css/
│ │ ├── bulma.css.map.txt
│ │ └── index.css
│ └── js/
│ ├── bulma-carousel.js
│ ├── bulma-slider.js
│ └── index.js
├── train_vqvae.py
├── utils/
│ ├── config.py
│ ├── evaluate.py
│ ├── losses.py
│ ├── motion_process.py
│ ├── paramUtil.py
│ ├── quaternion.py
│ ├── rotation_conversions.py
│ ├── skeleton.py
│ ├── utils_model.py
│ └── word_vectorizer.py
├── visualization/
│ ├── plot_3d_global.py
│ └── render.py
└── visualize/
├── joints2smpl/
│ ├── smpl_models/
│ │ ├── SMPL_downsample_index.pkl
│ │ ├── gmm_08.pkl
│ │ ├── neutral_smpl_mean_params.h5
│ │ └── smplx_parts_segm.pkl
│ └── src/
│ ├── config.py
│ ├── customloss.py
│ ├── prior.py
│ └── smplify.py
├── render_mesh.py
├── simplify_loc2rot.py
└── vis_utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# MotionGPT: Finetuned LLMs are General-Purpose Motion Generators
[arXiv](https://arxiv.org/abs/2306.10900)
The official PyTorch implementation of the paper "[MotionGPT: Finetuned LLMs are General-Purpose Motion Generators](https://arxiv.org/abs/2306.10900)".
Please visit our [Project Page](https://qiqiapink.github.io/MotionGPT) for more details.

If you find MotionGPT useful for your work, please cite:
```
@article{zhang2023motiongpt,
title={MotionGPT: Finetuned LLMs are General-Purpose Motion Generators},
author={Zhang, Yaqi and Huang, Di and Liu, Bin and Tang, Shixiang and Lu, Yan and Chen, Lu and Bai, Lei and Chu, Qi and Yu, Nenghai and Ouyang, Wanli},
journal={arXiv preprint arXiv:2306.10900},
year={2023}
}
```
## Table of Contents
* [Installation](#installation)
* [Demo](#demo)
* [Train](#train)
* [Evaluation](#evaluation)
* [Visualization](#visualization)
* [Acknowledgement](#acknowledgement)
## Installation
### 1. Environment
```
conda env create -f environment.yml
conda activate motiongpt
```
### 2. Dependencies
For text-to-motion evaluation
```
bash prepare/download_evaluators.sh
bash prepare/download_glove.sh
```
For SMPL mesh rendering
```
bash prepare/download_smpl.sh
```
To use the LLaMA model weights, follow [pyllama](https://github.com/juncongmoo/pyllama) to download the original LLaMA model, then follow [Lit-LLaMA](https://github.com/Lightning-AI/lit-llama) to convert the weights to the Lit-LLaMA format. After this process, move the `lit-llama/` directory under the `checkpoints/` directory.
Once downloaded, you should have a folder like this:
```
MotionGPT
├── checkpoints
│ ├── kit
│ │ ├── Comp_v6_KLD005
│ │ ├── Decomp_SP001_SM001_H512
│ │ ├── length_est_bigru
│ │ ├── text_mot_match
│ │ └── VQVAEV3_CB1024_CMT_H1024_NRES3
│ ├── lit-llama
│ │ ├── 7B
│ │ │ └── lit-llama.pth
│ │ ├── 13B
│ │ └── tokenizer.model
│ └── t2m
│ ├── Comp_v6_KLD005
│ ├── M2T_EL4_DL4_NH8_PS
│ ├── T2M_Seq2Seq_NML1_Ear_SME0_N
│ ├── text_mot_match
│ └── VQVAEV3_CB1024_CMT_H1024_NRES3
├── body_models
│ └── smpl
│ ├── J_regressor_extra.npy
│ ├── kintree_table.pkl
│ ├── smplfaces.npy
│ └── SMPL_NEUTRAL.pkl
└── glove
├── our_vab_data.npy
├── our_vab_idx.pkl
└── our_vab_words.pkl
```
### 3. Pretrained Models
For the pretrained VQ-VAE models
```
bash prepare/download_vqvae.sh
```
For the finetuned LLaMA model
```
bash prepare/download_lora.sh
```
Once downloaded, you should have a folder like this:
```
MotionGPT/checkpoints
├── pretrained_vqvae
│ ├── kit.pth
│ └── t2m.pth
└── pretrained_lora
└── pretrained.pth
```
### 4. Dataset
Please follow [HumanML3D](https://github.com/EricGuo5513/HumanML3D) to download HumanML3D and KIT-ML datasets and put them under the directory `dataset` like:
```
MotionGPT/dataset
├── HumanML3D
└── KIT-ML
```
To prepare the dataset used for finetuning LLaMA, please follow the instructions below (taking HumanML3D as an example):
```python
# Encode the motions into tokens with the pretrained VQ-VAE and save the token sequences under `./dataset/HumanML3D/VQVAE/`
# You can use the provided pretrained VQ-VAE or train one yourself following the training instructions.
python scripts/prepare_data.py --dataname t2m
# Generate the dataset on train split and validation split in the format of {instruction, input, output}
# Results saved as `./data/train.json` and `./data/val.json`
python scripts/generate_dataset.py --dataname t2m
# Generate corresponding instruction tuning dataset
# Results saved as `./data/train.pt` and `./data/val.pt`
python scripts/prepare_motion.py --dataname t2m
```
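For reference, each record pairs an instruction and its input with the target motion-token string. The sketch below shows the rough shape of such a record and an assumed Alpaca-style template for `generate_prompt`; the actual template lives in `scripts/prepare_motion.py`, and the `### Response:` marker is inferred from how `generate_motion.py` parses the model output. Token ids are illustrative.
```python
# Hypothetical record mirroring the {instruction, input, output} format
# produced by scripts/generate_dataset.py (token ids are illustrative).
sample = {
    "instruction": "Generate a sequence of motion tokens matching the following human motion description.",
    "input": "a person walks forward.",
    "output": "315,91,406",  # comma-separated VQ-VAE codebook indices
}

def generate_prompt(sample: dict) -> str:
    # Assumed Alpaca-style template; the real one lives in scripts/prepare_motion.py.
    return (
        "Below is an instruction that describes a task, paired with an input that "
        "provides further context. Write a response that appropriately completes "
        "the request.\n\n"
        f"### Instruction:\n{sample['instruction']}\n\n"
        f"### Input:\n{sample['input']}\n\n"
        "### Response:"
    )

print(generate_prompt(sample))
```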
## Demo
Give a task description (`--prompt`) and conditions (`--input`) to generate the corresponding motion. The motion in `npy` format (`demo.npy`) and the skeleton visualization (`demo.gif`) will be saved under `{output_dir}`.
Set `--render` if you want to render the SMPL mesh.
```python
# text-to-motion
python generate_motion.py --prompt "Generate a sequence of motion tokens matching the following human motion description." --input "a person walks forward." --lora_path ./checkpoints/pretrained_lora/pretrained.pth --out_dir {output_dir} --render
# (text, init pose)-to-motion
python generate_motion.py --prompt "Generate a sequence of motion tokens matching the following human motion description given the initial token." --input "a person walks forward.<Motion Token>315</Motion Token>" --lora_path ./checkpoints/pretrained_lora/pretrained.pth --out_dir {output_dir} --render
# (text, last pose)-to-motion
python generate_motion.py --prompt "Generate a sequence of motion tokens matching the following human motion description given the last token." --input "a person walks forward.<Motion Token>406</Motion Token>" --lora_path ./checkpoints/pretrained_lora/pretrained.pth --out_dir {output_dir} --render
# (text, key poses)-to-motion
python generate_motion.py --prompt "Generate a sequence of motion tokens matching the following human motion description given several key tokens." --input "a person walks forward.<Motion Token>315,91,406</Motion Token>" --lora_path ./checkpoints/pretrained_lora/pretrained.pth --out_dir {output_dir} --render
```
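Under the hood, the demo asks the finetuned LLaMA for a comma-separated string of motion-token ids and decodes them with the pretrained VQ-VAE. The following sketch condenses the post-processing from `generate_motion.py` (`net` is the loaded `HumanVQVAE` and `output` the decoded model reply):
```python
# Condensed from generate_motion.py: parse the reply, decode with the VQ-VAE
# (via utils.evaluate.plot) and save the joint positions and skeleton GIF.
import os
import imageio
import numpy as np
import torch
from utils.evaluate import plot

def tokens_to_motion(output: str, net, dataname: str, out_dir: str):
    response = output.split("### Response:")[1].strip()      # text after the marker
    tokens = torch.tensor([int(t) for t in response.split(",")]).cuda()
    generated_pose, img = plot(tokens, net, dataname)        # VQ-VAE decode + frames
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, "demo.npy"), generated_pose)
    imageio.mimsave(os.path.join(out_dir, "demo.gif"), np.array(img), fps=20)
    return generated_pose
```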
## Train
For VQ-VAE training
```
python train_vqvae.py --out_dir {output_dir} --dataname t2m
```
For finetuning LLaMA with LoRA
```
python finetune_motion.py --out_dir {output_dir} --dataname t2m
```
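Finetuning accumulates gradients over `batch_size // micro_batch_size` micro-steps per optimizer step and uses a linear learning-rate warmup. A minimal sketch of the warmup schedule as implemented in `finetune_motion.py`:
```python
# Linear warmup from finetune_motion.py: the learning rate ramps from 0 to
# args.learning_rate over the first args.warmup_steps optimizer steps,
# then stays constant.
def warmup_lr(step_count: int, base_lr: float, warmup_steps: int) -> float:
    if step_count <= warmup_steps:
        return base_lr * step_count / warmup_steps
    return base_lr
```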
## Evaluation
For VQ-VAE
```
python eval_vqvae.py --out_dir {output_dir} --resume_pth {vqvae_model_path} --dataname t2m
```
For LLaMA
```
python eval.py --vqvae_pth {vqvae_model_path} --lora_path {finetuned_model_path} --out_dir {output_dir} --dataname t2m
```
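Both scripts repeat the evaluation (3 times for the LLaMA pipeline, 20 for the VQ-VAE) and report each metric as a mean with a normal-approximation 95% confidence half-width, as in this sketch (the numbers are illustrative):
```python
import numpy as np

def summarize(values):
    """Mean and 95% confidence half-width (1.96 * std / sqrt(n)), as in eval.py."""
    values = np.asarray(values, dtype=np.float64)
    return np.mean(values), np.std(values) * 1.96 / np.sqrt(len(values))

mean, conf = summarize([0.141, 0.138, 0.145])  # e.g. FID over 3 repeats
print(f"FID. {mean:.3f}, conf. {conf:.3f}")
```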
## Visualization
The generated poses are all saved in `npy` format with the shape `[seq_len, joint_num, 3]`.
The visualization results are saved in `gif` format under the same directory, with the corresponding filenames.
For visualization in skeleton format
```python
# To visualize all the poses saved in {saved_pose_dir}
python visualization/plot_3d_global.py --dir {saved_pose_dir}
# To visualize selected poses in {saved_pose_dir}
python visualization/plot_3d_global.py --dir {saved_pose_dir} --motion-list {fname1} {fname2} ...
```
For SMPL mesh rendering
```python
# To visualize all the poses saved in {saved_pose_dir}
python visualization/render.py --dir {saved_pose_dir}
# To visualize selected poses in {saved_pose_dir}
python visualization/render.py --dir {saved_pose_dir} --motion-list {fname1} {fname2} ...
```
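To sanity-check a saved result before rendering, you can load it directly (a minimal sketch; the path is illustrative):
```python
import numpy as np

pose = np.load("output/demo.npy")  # shape: [seq_len, joint_num, 3]
seq_len, joint_num, _ = pose.shape
print(f"{seq_len} frames, {joint_num} joints")  # 22 joints for HumanML3D, 21 for KIT-ML
```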
## Acknowledgement
Thanks to [HumanML3D](https://github.com/EricGuo5513/HumanML3D), [T2M-GPT](https://github.com/Mael-zys/T2M-GPT) and [Lit-LLaMA](https://github.com/Lightning-AI/lit-llama); our code partially borrows from them.
================================================
FILE: dataloader/eval_loader.py
================================================
import torch
from torch.utils import data
import numpy as np
from os.path import join as pjoin
import random
import codecs as cs
from tqdm import tqdm
import utils.paramUtil as paramUtil
from torch.utils.data._utils.collate import default_collate
from random import sample
def collate_fn(batch):
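# sort each batch by caption length (tuple index 3 = sent_len), longest first, so the evaluator's text encoder can consume length-ordered batches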
batch.sort(key=lambda x: x[3], reverse=True)
return default_collate(batch)
'''For use of training text-2-motion generative model'''
class Text2MotionDataset(data.Dataset):
def __init__(self, dataset_name, split, w_vectorizer, feat_bias = 5, max_text_len = 20, unit_length = 4):
self.max_length = 20
self.pointer = 0
self.dataset_name = dataset_name
self.max_text_len = max_text_len
self.unit_length = unit_length
self.w_vectorizer = w_vectorizer
if dataset_name == 't2m':
self.data_root = './dataset/HumanML3D'
self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
self.text_dir = pjoin(self.data_root, 'texts')
self.joints_num = 22
radius = 4
fps = 20
self.max_motion_length = 196
dim_pose = 263
kinematic_chain = paramUtil.t2m_kinematic_chain
self.meta_dir = './checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
elif dataset_name == 'kit':
self.data_root = './dataset/KIT-ML'
self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
self.text_dir = pjoin(self.data_root, 'texts')
self.joints_num = 21
radius = 240 * 8
fps = 12.5
dim_pose = 251
self.max_motion_length = 196
kinematic_chain = paramUtil.kit_kinematic_chain
self.meta_dir = './checkpoints/kit/Decomp_SP001_SM001_H512/meta'
mean = np.load(pjoin(self.meta_dir, 'mean.npy'))
std = np.load(pjoin(self.meta_dir, 'std.npy'))
split_file = pjoin(self.data_root, f'{split}.txt')
min_motion_len = 40 if self.dataset_name =='t2m' else 24
joints_num = self.joints_num
data_dict = {}
id_list = []
with cs.open(split_file, 'r') as f:
for line in f.readlines():
id_list.append(line.strip())
new_name_list = []
length_list = []
for name in tqdm(id_list):
try:
motion = np.load(pjoin(self.motion_dir, name + '.npy'))
if (len(motion)) < min_motion_len or (len(motion) >= 200):
continue
text_data = []
flag = False
with cs.open(pjoin(self.text_dir, name + '.txt')) as f:
for line in f.readlines():
text_dict = {}
line_split = line.strip().split('#')
caption = line_split[0]
tokens = line_split[1].split(' ')
f_tag = float(line_split[2])
to_tag = float(line_split[3])
f_tag = 0.0 if np.isnan(f_tag) else f_tag
to_tag = 0.0 if np.isnan(to_tag) else to_tag
text_dict['caption'] = caption
text_dict['tokens'] = tokens
if f_tag == 0.0 and to_tag == 0.0:
flag = True
text_data.append(text_dict)
else:
try:
n_motion = motion[int(f_tag*fps) : int(to_tag*fps)]
if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
continue
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
while new_name in data_dict:
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
data_dict[new_name] = {'motion': n_motion,
'length': len(n_motion),
'text':[text_dict]}
new_name_list.append(new_name)
length_list.append(len(n_motion))
except:
print(line_split)
print(line_split[2], line_split[3], f_tag, to_tag, name)
# break
if flag:
data_dict[name] = {'motion': motion,
'length': len(motion),
'text': text_data}
new_name_list.append(name)
length_list.append(len(motion))
except Exception as e:
# print(e)
pass
name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
self.mean = mean
self.std = std
self.length_arr = np.array(length_list)
self.data_dict = data_dict
self.name_list = name_list
self.reset_max_len(self.max_length)
def reset_max_len(self, length):
assert length <= self.max_motion_length
self.pointer = np.searchsorted(self.length_arr, length)
print("Pointer Pointing at %d"%self.pointer)
self.max_length = length
def inv_transform(self, data):
return data * self.std + self.mean
def forward_transform(self, data):
return (data - self.mean) / self.std
def __len__(self):
return len(self.data_dict) - self.pointer
def __getitem__(self, item):
idx = self.pointer + item
name = self.name_list[idx]
data = self.data_dict[name]
motion, m_length, text_list = data['motion'], data['length'], data['text']
# Randomly select a caption
text_data = random.choice(text_list)
caption, tokens = text_data['caption'], text_data['tokens']
if len(tokens) < self.max_text_len:
# pad with "unk"
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
sent_len = len(tokens)
tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len)
else:
# crop
tokens = tokens[:self.max_text_len]
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
sent_len = len(tokens)
pos_one_hots = []
word_embeddings = []
for token in tokens:
word_emb, pos_oh = self.w_vectorizer[token]
pos_one_hots.append(pos_oh[None, :])
word_embeddings.append(word_emb[None, :])
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
word_embeddings = np.concatenate(word_embeddings, axis=0)
if self.unit_length < 10:
coin2 = np.random.choice(['single', 'single', 'double'])
else:
coin2 = 'single'
if coin2 == 'double':
m_length = (m_length // self.unit_length - 1) * self.unit_length
elif coin2 == 'single':
m_length = (m_length // self.unit_length) * self.unit_length
idx = random.randint(0, len(motion) - m_length)
motion = motion[idx:idx+m_length]
"Z Normalization"
motion = (motion - self.mean) / self.std
if m_length < self.max_motion_length:
motion = np.concatenate([motion,
np.zeros((self.max_motion_length - m_length, motion.shape[1]))
], axis=0)
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens), name
def DATALoader(dataset_name, split,
batch_size, w_vectorizer,
num_workers = 8, unit_length = 4) :
val_loader = torch.utils.data.DataLoader(Text2MotionDataset(dataset_name, split, w_vectorizer, unit_length=unit_length),
batch_size,
shuffle = True,
num_workers=num_workers,
collate_fn=collate_fn,
drop_last = True)
return val_loader
def cycle(iterable):
while True:
for x in iterable:
yield x
================================================
FILE: dataloader/tokenizer_loader.py
================================================
import torch
from torch.utils import data
import numpy as np
from os.path import join as pjoin
import random
import codecs as cs
from tqdm import tqdm
class VQMotionDataset(data.Dataset):
def __init__(self, dataset_name, feat_bias = 5, window_size = 64, unit_length = 8):
self.window_size = window_size
self.unit_length = unit_length
self.feat_bias = feat_bias
self.dataset_name = dataset_name
min_motion_len = 40 if dataset_name =='t2m' else 24
if dataset_name == 't2m':
self.data_root = './dataset/HumanML3D'
self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
self.text_dir = pjoin(self.data_root, 'texts')
self.joints_num = 22
radius = 4
fps = 20
self.max_motion_length = 196
dim_pose = 263
self.meta_dir = './checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
elif dataset_name == 'kit':
self.data_root = './dataset/KIT-ML'
self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
self.text_dir = pjoin(self.data_root, 'texts')
self.joints_num = 21
radius = 240 * 8
fps = 12.5
dim_pose = 251
self.max_motion_length = 196
self.meta_dir = './checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
joints_num = self.joints_num
mean = np.load(pjoin(self.meta_dir, 'mean.npy'))
std = np.load(pjoin(self.meta_dir, 'std.npy'))
split_file = pjoin(self.data_root, 'train_val.txt')
data_dict = {}
id_list = []
with cs.open(split_file, 'r') as f:
for line in f.readlines():
id_list.append(line.strip())
new_name_list = []
length_list = []
for name in tqdm(id_list):
try:
motion = np.load(pjoin(self.motion_dir, name + '.npy'))
if (len(motion)) < min_motion_len or (len(motion) >= 200):
continue
data_dict[name] = {'motion': motion,
'length': len(motion),
'name': name}
new_name_list.append(name)
length_list.append(len(motion))
except:
# Some motion may not exist in KIT dataset
pass
self.mean = mean
self.std = std
self.length_arr = np.array(length_list)
self.data_dict = data_dict
self.name_list = new_name_list
def inv_transform(self, data):
return data * self.std + self.mean
def __len__(self):
return len(self.data_dict)
def __getitem__(self, item):
name = self.name_list[item]
data = self.data_dict[name]
motion, m_length = data['motion'], data['length']
m_length = (m_length // self.unit_length) * self.unit_length
idx = random.randint(0, len(motion) - m_length)
motion = motion[idx:idx+m_length]
"Z Normalization"
motion = (motion - self.mean) / self.std
return motion, name
def DATALoader(dataset_name,
batch_size = 1,
num_workers = 8, unit_length = 4) :
train_loader = torch.utils.data.DataLoader(VQMotionDataset(dataset_name, unit_length=unit_length),
batch_size,
shuffle=True,
num_workers=num_workers,
drop_last = True)
return train_loader
def cycle(iterable):
while True:
for x in iterable:
yield x
================================================
FILE: dataloader/vqvae_loader.py
================================================
import torch
from torch.utils import data
import numpy as np
from os.path import join as pjoin
import random
import codecs as cs
from tqdm import tqdm
class VQMotionDataset(data.Dataset):
def __init__(self, dataset_name, window_size = 64, unit_length = 4):
self.window_size = window_size
self.unit_length = unit_length
self.dataset_name = dataset_name
if dataset_name == 't2m':
self.data_root = './dataset/HumanML3D'
self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
self.text_dir = pjoin(self.data_root, 'texts')
self.joints_num = 22
self.max_motion_length = 196
self.meta_dir = './checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
elif dataset_name == 'kit':
self.data_root = './dataset/KIT-ML'
self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
self.text_dir = pjoin(self.data_root, 'texts')
self.joints_num = 21
self.max_motion_length = 196
self.meta_dir = './checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
joints_num = self.joints_num
mean = np.load(pjoin(self.meta_dir, 'mean.npy'))
std = np.load(pjoin(self.meta_dir, 'std.npy'))
split_file = pjoin(self.data_root, 'train.txt')
self.data = []
self.lengths = []
id_list = []
with cs.open(split_file, 'r') as f:
for line in f.readlines():
id_list.append(line.strip())
for name in tqdm(id_list):
try:
motion = np.load(pjoin(self.motion_dir, name + '.npy'))
if motion.shape[0] < self.window_size:
continue
self.lengths.append(motion.shape[0] - self.window_size)
self.data.append(motion)
except:
# Some motion may not exist in KIT dataset
pass
self.mean = mean
self.std = std
print("Total number of motions {}".format(len(self.data)))
def inv_transform(self, data):
return data * self.std + self.mean
def compute_sampling_prob(self) :
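# weight each motion by its number of valid window positions (length - window_size), so longer clips are sampled proportionally more often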
prob = np.array(self.lengths, dtype=np.float32)
prob /= np.sum(prob)
return prob
def __len__(self):
return len(self.data)
def __getitem__(self, item):
motion = self.data[item]
idx = random.randint(0, len(motion) - self.window_size)
motion = motion[idx:idx+self.window_size]
"Z Normalization"
motion = (motion - self.mean) / self.std
return motion
def DATALoader(dataset_name,
batch_size,
num_workers = 8,
window_size = 64,
unit_length = 4):
train_loader = torch.utils.data.DataLoader(VQMotionDataset(dataset_name, unit_length=unit_length),
batch_size,
shuffle=True,
num_workers=num_workers,
drop_last = True)
return train_loader
def cycle(iterable):
while True:
for x in iterable:
yield x
================================================
FILE: environment.yml
================================================
name: motiongpt
channels:
- menpo
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotli-python=1.0.9=py39h5a03fae_7
- bzip2=1.0.8=h7f98852_4
- ca-certificates=2023.5.7=hbcca054_0
- certifi=2023.5.7=pyhd8ed1ab_0
- dataclasses=0.8=pyhc8e2a94_3
- ffmpeg=4.3.2=hca11adc_0
- freetype=2.10.4=h0708190_1
- freetype-py=2.4.0=pyhd8ed1ab_0
- future=0.18.3=pyhd8ed1ab_0
- geos=3.10.2=h9c3ff4c_0
- gmp=6.2.1=h58526e2_0
- gnutls=3.6.13=h85f3911_1
- h5py=3.7.0=py39h737f45e_0
- hdf5=1.10.6=h3ffc7dd_1
- idna=3.4=pyhd8ed1ab_0
- intel-openmp=2023.1.0=hdb19cb5_46305
- jbig=2.1=h7f98852_2003
- jpeg=9e=h166bdaf_1
- lame=3.100=h7f98852_1001
- lcms2=2.12=hddcbb42_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=2.2.1=h9c3ff4c_0
- libdeflate=1.7=h7f98852_5
- libffi=3.4.2=h6a678d5_6
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libpng=1.6.37=h21135ba_2
- libstdcxx-ng=11.2.0=h1234567_1
- libtiff=4.3.0=hf544144_1
- libwebp-base=1.2.2=h7f98852_1
- lz4-c=1.9.3=h9c3ff4c_1
- mapbox_earcut=1.0.0=py39hf939315_3
- mkl=2023.1.0=h6d00ec8_46342
- mkl-service=2.4.0=py39h5eee18b_1
- mkl_fft=1.3.6=py39h417a72b_1
- mkl_random=1.2.2=py39h417a72b_1
- ncurses=6.4=h6a678d5_0
- nettle=3.6=he412f7d_0
- networkx=3.1=pyhd8ed1ab_0
- olefile=0.46=pyh9f0ad1d_1
- openh264=2.1.1=h780b84a_0
- openjpeg=2.4.0=hb52868f_1
- openssl=1.1.1u=h7f8727e_0
- osmesa=12.2.2.dev=0
- packaging=23.1=pyhd8ed1ab_0
- pip=23.0.1=py39h06a4308_0
- pooch=1.7.0=pyha770c72_3
- pyglet=1.5.27=py39hf3d152e_3
- pyopengl=3.1.6=pyhd8ed1ab_1
- pyrender=0.1.45=pyh8a188c0_3
- pysocks=1.7.1=pyha2e5f31_6
- python=3.9.16=h7a1cb2a_2
- python_abi=3.9=2_cp39
- readline=8.2=h5eee18b_0
- setuptools=66.0.0=py39h06a4308_0
- shapely=1.8.2=py39h73b9895_1
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.41.2=h5eee18b_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.12=h1ccaba5_0
- trimesh=3.22.3=pyhd8ed1ab_0
- typing_extensions=4.7.1=pyha770c72_0
- wheel=0.38.4=py39h06a4308_0
- x264=1!161.3030=h7f98852_1
- xz=5.2.10=h5eee18b_1
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.0=ha95c52a_0
- pip:
- absl-py==1.4.0
- accelerate==0.18.0
- aiobotocore==2.5.0
- aiohttp==3.8.4
- aioitertools==0.11.0
- aiosignal==1.3.1
- anyio==3.6.2
- argon2-cffi==21.3.0
- argon2-cffi-bindings==21.2.0
- arrow==1.2.3
- asttokens==2.2.1
- async-timeout==4.0.2
- attrs==23.1.0
- backcall==0.2.0
- beautifulsoup4==4.12.2
- bitsandbytes==0.38.1
- bleach==6.0.0
- blessed==1.20.0
- botocore==1.29.76
- cachetools==5.3.0
- cffi==1.15.1
- charset-normalizer==3.1.0
- chumpy==0.70
- click==8.1.3
- cmake==3.26.3
- comm==0.1.3
- croniter==1.3.14
- cycler==0.11.0
- datasets==2.11.0
- dateutils==0.6.12
- debugpy==1.6.7
- decorator==5.1.1
- deepdiff==6.3.0
- defusedxml==0.7.1
- dill==0.3.6
- docstring-parser==0.15
- executing==1.2.0
- fairscale==0.4.13
- fastapi==0.88.0
- fastjsonschema==2.16.3
- filelock==3.12.0
- fire==0.5.0
- fqdn==1.5.1
- frozenlist==1.3.3
- fsspec==2023.4.0
- ftfy==6.1.1
- google-auth==2.17.3
- google-auth-oauthlib==1.0.0
- grpcio==1.54.0
- h11==0.14.0
- hiq-python==1.1.12
- huggingface-hub==0.13.4
- imageio==2.9.0
- importlib-metadata==6.5.0
- importlib-resources==5.12.0
- inquirer==3.1.3
- ipykernel==6.22.0
- ipython==8.13.2
- ipython-genutils==0.2.0
- ipywidgets==8.0.6
- isoduration==20.11.0
- itsdangerous==2.1.2
- jedi==0.18.2
- jinja2==3.1.2
- jmespath==1.0.1
- jsonargparse==4.20.1
- jsonpointer==2.3
- jsonschema==4.17.3
- jupyter==1.0.0
- jupyter-client==8.2.0
- jupyter-console==6.6.3
- jupyter-core==5.3.0
- jupyter-events==0.6.3
- jupyter-server==2.5.0
- jupyter-server-terminals==0.4.4
- jupyterlab-pygments==0.2.2
- jupyterlab-widgets==3.0.7
- kiwisolver==1.4.4
- lightning==2.0.0
- lightning-cloud==0.5.33
- lightning-fabric==2.0.1.post0
- lightning-utilities==0.8.0
- lit==16.0.1
- markdown==3.4.3
- markdown-it-py==2.2.0
- markupsafe==2.1.2
- matplotlib==3.4.3
- matplotlib-inline==0.1.6
- mdurl==0.1.2
- mistune==2.0.5
- mpmath==1.3.0
- multidict==6.0.4
- multiprocess==0.70.14
- nbclassic==1.0.0
- nbclient==0.7.4
- nbconvert==7.3.1
- nbformat==5.8.0
- nest-asyncio==1.5.6
- notebook==6.5.4
- notebook-shim==0.2.3
- numpy==1.24.2
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cuda-cupti-cu11==11.7.101
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cudnn-cu11==8.5.0.96
- nvidia-cufft-cu11==10.9.0.58
- nvidia-curand-cu11==10.2.10.91
- nvidia-cusolver-cu11==11.4.0.1
- nvidia-cusparse-cu11==11.7.4.91
- nvidia-nccl-cu11==2.14.3
- nvidia-nvtx-cu11==11.7.91
- oauthlib==3.2.2
- ordered-set==4.1.0
- pandas==2.0.0
- pandocfilters==1.5.0
- parso==0.8.3
- pexpect==4.8.0
- pickleshare==0.7.5
- pillow==9.5.0
- platformdirs==3.5.0
- prometheus-client==0.16.0
- prompt-toolkit==3.0.38
- protobuf==3.20.0
- psutil==5.9.5
- ptyprocess==0.7.0
- pure-eval==0.2.2
- py-itree==0.0.19
- pyarrow==11.0.0
- pyasn1==0.5.0
- pyasn1-modules==0.3.0
- pycparser==2.21
- pydantic==1.10.7
- pygments==2.15.1
- pyjwt==2.6.0
- pyllama==0.0.9
- pyparsing==3.0.9
- pyrsistent==0.19.3
- python-dateutil==2.8.2
- python-editor==1.0.4
- python-json-logger==2.0.7
- python-multipart==0.0.6
- pytorch-lightning==2.0.1.post0
- pytz==2023.3
- pyyaml==6.0
- pyzmq==25.0.2
- qtconsole==5.4.2
- qtpy==2.3.1
- readchar==4.0.5
- regex==2023.3.23
- requests==2.28.2
- requests-oauthlib==1.3.1
- responses==0.18.0
- rfc3339-validator==0.1.4
- rfc3986-validator==0.1.1
- rich==13.3.4
- rsa==4.9
- s3fs==2023.4.0
- scipy==1.10.1
- send2trash==1.8.2
- sentencepiece==0.1.97
- smplx==0.1.28
- sniffio==1.3.0
- soupsieve==2.4.1
- stack-data==0.6.2
- starlette==0.22.0
- starsessions==1.3.0
- sympy==1.11.1
- tensorboard==2.12.2
- tensorboard-data-server==0.7.0
- tensorboard-plugin-wit==1.8.1
- termcolor==2.2.0
- terminado==0.17.1
- tinycss2==1.2.1
- tokenizers==0.13.3
- torch==2.0.0
- torchmetrics==0.11.4
- torchvision==0.15.1
- tornado==6.3.1
- tqdm==4.65.0
- traitlets==5.9.0
- transformers==4.28.1
- triton==2.0.0
- typeshed-client==2.2.0
- typing-extensions==4.5.0
- tzdata==2023.3
- uri-template==1.2.0
- urllib3==1.26.15
- uvicorn==0.21.1
- wcwidth==0.2.6
- webcolors==1.13
- webencodings==0.5.1
- websocket-client==1.5.1
- websockets==11.0.2
- werkzeug==2.2.3
- widgetsnbextension==4.0.7
- wrapt==1.15.0
- xxhash==3.2.0
- yarl==1.8.2
- zipp==3.15.0
- zstandard==0.21.0
================================================
FILE: eval.py
================================================
import os
import torch
import numpy as np
import json
import clip
from options import option
import models.vqvae as vqvae
import utils.utils_model as utils_model
from utils.evaluate import evaluation
from dataloader.eval_loader import DATALoader
from options.get_eval_option import get_opt
from models.evaluator_wrapper import EvaluatorModelWrapper
import warnings
warnings.filterwarnings('ignore')
import sys
import time
from pathlib import Path
from typing import Optional
import lightning as L
from lit_llama import LLaMA, LLaMAConfig
from lit_llama.lora import lora
from lit_llama.utils import EmptyInitOnDevice, lazy_load
from lit_llama.tokenizer import Tokenizer
args = option.get_args_parser()
def main(
quantize: Optional[str] = None,
dtype: str = "float32",
accelerator: str = "auto"
) -> None:
os.makedirs(args.out_dir, exist_ok = True)
##### ---- Logger ---- #####
logger = utils_model.get_logger(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
from utils.word_vectorizer import WordVectorizer
w_vectorizer = WordVectorizer('./glove', 'our_vab')
val_loader = DATALoader(args.dataname, 'test', 32, w_vectorizer, unit_length=2**args.down_t)
if args.dataname == 'kit' :
dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt'
args.nb_joints = 21
else :
dataset_opt_path = './checkpoints/t2m/Comp_v6_KLD005/opt.txt'
args.nb_joints = 22
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
##### ---- Network ---- #####
## load clip model and datasets
clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False) # Must set jit=False for training
clip.model.convert_weights(clip_model) # This line is actually unnecessary since CLIP weights are already float16 by default
clip_model.eval()
for p in clip_model.parameters():
p.requires_grad = False
print('Loading VAE')
vae = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
512,
args.code_dim,
args.output_emb_width,
2,
args.stride_t,
args.width,
3,
args.dilation_growth_rate)
resume_pth = f"./checkpoints/pretrained_vqvae/{args.dataname}.pth"
ckpt = torch.load(resume_pth, map_location='cpu')
vae.load_state_dict(ckpt['net'], strict=True)
vae = vae.cuda().eval()
print('Loading VAE Done')
lora_path = Path(args.lora_path)
print('Load finetuned model from:', lora_path)
pretrained_path = Path(f"./checkpoints/lit-llama/{args.pretrained_llama}/lit-llama.pth")
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert lora_path.is_file()
assert pretrained_path.is_file()
assert tokenizer_path.is_file()
if quantize is not None:
raise NotImplementedError("Quantization in LoRA is not supported yet")
fabric = L.Fabric(accelerator=accelerator, devices=1)
dt = getattr(torch, dtype, None)
if not isinstance(dt, torch.dtype):
raise ValueError(f"{dtype} is not a valid dtype.")
dtype = dt
print("Loading model ...", file=sys.stderr)
t0 = time.time()
with EmptyInitOnDevice(
device=fabric.device, dtype=dtype, quantization_mode=quantize
), lora(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout, enabled=True):
# model = LLaMA(LLaMAConfig()) # TODO: Support different model sizes
config = LLaMAConfig.from_name(args.pretrained_llama)
model = LLaMA(config)
# 1. Load the pretrained weights
pretrained_checkpoint = lazy_load(pretrained_path)
model.load_state_dict(pretrained_checkpoint, strict=False)
# 2. Load the fine-tuned LoRA weights
lora_checkpoint = lazy_load(lora_path)
model.load_state_dict(lora_checkpoint, strict=False)
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
model.eval()
model = fabric.setup_module(model)
tokenizer = Tokenizer(tokenizer_path)
fid = []
div = []
top1 = []
top2 = []
top3 = []
matching = []
repeat_time = 3
for _ in range(repeat_time):
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, logger = evaluation(val_loader, vae, model, logger, tokenizer, eval_wrapper=eval_wrapper, instruction=args.prompt)
fid.append(best_fid)
div.append(best_div)
top1.append(best_top1)
top2.append(best_top2)
top3.append(best_top3)
matching.append(best_matching)
print('final result:')
print('fid: ', sum(fid)/repeat_time)
print('div: ', sum(div)/repeat_time)
print('top1: ', sum(top1)/repeat_time)
print('top2: ', sum(top2)/repeat_time)
print('top3: ', sum(top3)/repeat_time)
print('matching: ', sum(matching)/repeat_time)
fid = np.array(fid)
div = np.array(div)
top1 = np.array(top1)
top2 = np.array(top2)
top3 = np.array(top3)
matching = np.array(matching)
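# report the mean of each metric with a 95% confidence half-width: 1.96 * std / sqrt(repeat_time)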
msg_final = f"FID. {np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}, Diversity. {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}, TOP1. {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}, Matching. {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}"
logger.info(msg_final)
if __name__ == "__main__":
torch.set_float32_matmul_precision("high")
warnings.filterwarnings(
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
"ignore",
message="ComplexHalf support is experimental and many operators don't support it yet"
)
main()
================================================
FILE: eval_vqvae.py
================================================
# This code is based on https://github.com/Mael-zys/T2M-GPT.git
import os
import json
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import models.vqvae as vqvae
import options.option_vqvae as option_vq
import utils.utils_model as utils_model
from dataloader.eval_loader import DATALoader
from utils.evaluate import vqvae_evaluation
from options.get_eval_option import get_opt
from models.evaluator_wrapper import EvaluatorModelWrapper
from utils.word_vectorizer import WordVectorizer
import warnings
warnings.filterwarnings('ignore')
args = option_vq.get_args_parser()
torch.manual_seed(args.seed)
os.makedirs(args.out_dir, exist_ok = True)
def main():
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
w_vectorizer = WordVectorizer('./glove', 'our_vab')
dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataname == 'kit' else './checkpoints/t2m/Comp_v6_KLD005/opt.txt'
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
args.nb_joints = 21 if args.dataname == 'kit' else 22
val_loader = DATALoader(args.dataname, 'test', 32, w_vectorizer, unit_length=2**args.down_t)
net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
args.nb_code,
args.code_dim,
args.output_emb_width,
args.down_t,
args.stride_t,
args.width,
args.depth,
args.dilation_growth_rate,
args.vq_act,
args.vq_norm)
if args.resume_pth :
logger.info('loading checkpoint from {}'.format(args.resume_pth))
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.cuda()
fid = []
div = []
top1 = []
top2 = []
top3 = []
matching = []
repeat_time = 20
for _ in range(repeat_time):
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = vqvae_evaluation(args.out_dir, val_loader, net, logger, writer, eval_wrapper, 0)
fid.append(best_fid)
div.append(best_div)
top1.append(best_top1)
top2.append(best_top2)
top3.append(best_top3)
matching.append(best_matching)
print('final result:')
print('fid: ', sum(fid)/repeat_time)
print('div: ', sum(div)/repeat_time)
print('top1: ', sum(top1)/repeat_time)
print('top2: ', sum(top2)/repeat_time)
print('top3: ', sum(top3)/repeat_time)
print('matching: ', sum(matching)/repeat_time)
fid = np.array(fid)
div = np.array(div)
top1 = np.array(top1)
top2 = np.array(top2)
top3 = np.array(top3)
matching = np.array(matching)
msg_final = f"FID. {np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}, Diversity. {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}, TOP1. {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}, Matching. {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}"
logger.info(msg_final)
if __name__ == '__main__':
main()
================================================
FILE: finetune_motion.py
================================================
import os
import time
import lightning as L
import numpy as np
import torch
import clip
from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
from lit_llama.model import LLaMA, LLaMAConfig
from lit_llama.tokenizer import Tokenizer
from dataloader.eval_loader import DATALoader
from utils.evaluate import evaluation
from utils.word_vectorizer import WordVectorizer
from options.get_eval_option import get_opt
from models.evaluator_wrapper import EvaluatorModelWrapper
import models.vqvae as vqvae
from options import option
import utils.utils_model as utils_model
from torch.utils.tensorboard import SummaryWriter
import json
args = option.get_args_parser()
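# gradients are accumulated over several micro-batches, so the effective batch size stays at args.batch_size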
gradient_accumulation_steps = args.batch_size // args.micro_batch_size
max_iters = 50000 * 3 // args.micro_batch_size
def main():
fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-mixed")
fabric.launch()
fabric.seed_everything(1337 + fabric.global_rank)
if fabric.global_rank == 0:
os.makedirs(args.out_dir, exist_ok=True)
train_data, val_data = load_datasets()
w_vectorizer = WordVectorizer('./glove', 'our_vab')
val_loader = DATALoader(args.dataname, 'val', 32, w_vectorizer, unit_length=2**args.down_t)
if args.dataname == 'kit' :
dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt'
args.nb_joints = 21
else :
dataset_opt_path = './checkpoints/t2m/Comp_v6_KLD005/opt.txt'
args.nb_joints = 22
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
args.nb_code,
args.code_dim,
args.output_emb_width,
args.down_t,
args.stride_t,
args.width,
args.depth,
args.dilation_growth_rate)
print ('loading checkpoint from {}'.format(args.vqvae_pth))
ckpt = torch.load(args.vqvae_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.cuda()
clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False) # Must set jit=False for training
clip.model.convert_weights(clip_model) # This line is actually unnecessary since CLIP weights are already float16 by default
clip_model.eval()
for p in clip_model.parameters():
p.requires_grad = False
config = LLaMAConfig.from_name(args.pretrained_llama)
config.block_size = args.block_size
checkpoint = torch.load(f"./checkpoints/lit-llama/{args.pretrained_llama}/lit-llama.pth")
tokenizer = Tokenizer("./checkpoints/lit-llama/tokenizer.model")
with fabric.device, lora(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout, enabled=True):
torch.set_default_tensor_type(torch.HalfTensor)
model = LLaMA(config).bfloat16()
torch.set_default_tensor_type(torch.FloatTensor)
# strict=False because missing keys due to LoRA weights not contained in checkpoint state
model.load_state_dict(checkpoint, strict=False)
if args.resume_pth:
checkpoint = torch.load(args.resume_pth)
model.load_state_dict(checkpoint, strict=False)
mark_only_lora_as_trainable(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
model, optimizer = fabric.setup(model, optimizer)
train(fabric, model, optimizer, train_data, val_data, args.out_dir, logger, writer)
# Save the final LoRA checkpoint at the end of training
checkpoint = lora_state_dict(model)
fabric.save(os.path.join(args.out_dir, "lit-llama-lora-finetuned.pth"), checkpoint)
# Evaluation on validation set
evaluation(val_loader, net, model, logger, tokenizer, eval_wrapper=eval_wrapper, instruction=args.prompt)
def train(
fabric: L.Fabric,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
train_data: np.ndarray,
val_data: np.ndarray,
out_dir: str,
logger,
writer
) -> None:
"""The training loop.
Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
"""
step_count = 0
for iter_num in range(max_iters):
if step_count <= args.warmup_steps:
# linear warmup
lr = args.learning_rate * step_count / args.warmup_steps
for param_group in optimizer.param_groups:
param_group['lr'] = lr
t0 = time.time()
input_ids, targets = get_batch(fabric, train_data)
logits = model(input_ids)
loss = loss_fn(logits, targets)
fabric.backward(loss)
if (iter_num + 1) % gradient_accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
step_count += 1
if step_count % args.eval_interval == 0:
val_loss = validate(fabric, model, val_data)
writer.add_scalar('./Val', val_loss, step_count)
logger.info(f"step {iter_num}: val loss {val_loss:.4f}")
fabric.barrier()
if step_count % args.save_interval == 0:
print(f"Saving LoRA weights to {out_dir}")
# We are only saving the LoRA weights
# TODO: Provide a function/script to merge the LoRA weights with pretrained weights
checkpoint = lora_state_dict(model)
fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)
dt = time.time() - t0
if iter_num % args.log_interval == 0:
writer.add_scalar('./Train', loss, iter_num)
logger.info(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(args.eval_iters)
for k in range(args.eval_iters):
input_ids, targets = get_batch(fabric, val_data)
logits = model(input_ids)
loss = loss_fn(logits, targets)
losses[k] = loss.item()
out = losses.mean()
model.train()
return out.item()
def loss_fn(logits, targets):
# shift the targets such that output n predicts token n+1
logits = logits[..., :-1, :].contiguous()
targets = targets[..., 1:].contiguous()
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
return loss
def get_batch(fabric: L.Fabric, data: list):
ix = torch.randint(len(data), (args.micro_batch_size,))
input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
labels = [data[i]["labels"].type(torch.int64) for i in ix]
max_len = max(len(s) for s in input_ids)
def pad_left(x, pad_id):
# pad left based on the longest sequence
n = max_len - len(x)
return torch.cat((torch.full((n,), pad_id, dtype=x.dtype), x))
# def pad_right(x, pad_id):
# # pad right based on the longest sequence
# n = max_len - len(x)
# return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
x = torch.stack([pad_left(x, pad_id=0) for x in input_ids])
y = torch.stack([pad_left(x, pad_id=-1) for x in labels])
x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
return x, y
def load_datasets():
print('Load data from:', args.data_dir)
train_data = torch.load(os.path.join(args.data_dir, "train.pt"))
val_data = torch.load(os.path.join(args.data_dir, "val.pt"))
return train_data, val_data
if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")
# from jsonargparse.cli import CLI
# args = option_trans.get_args_parser()
# args.dataname = 't2m'
# args.out_dir = 'out/lora/mydataset_v3'
# logger = utils_model.get_logger(args.out_dir)
# writer = SummaryWriter(args.out_dir)
# logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
# CLI(main)
main()
================================================
FILE: generate.py
================================================
import sys
import time
import warnings
from pathlib import Path
from typing import Optional
import lightning as L
import torch
from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice, lazy_load
@torch.no_grad()
def generate(
model: torch.nn.Module,
idx: torch.Tensor,
max_new_tokens: int,
max_seq_length: int,
temperature: float = 1.0,
top_k: Optional[int] = None,
eos_id: Optional[int] = None,
) -> torch.Tensor:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
The implementation of this function is modified from A. Karpathy's nanoGPT.
Args:
model: The model to use.
idx: Tensor of shape (T) with indices of the prompt sequence.
max_new_tokens: The number of new tokens to generate.
max_seq_length: The maximum sequence length allowed.
temperature: Scales the predicted logits by 1 / temperature
top_k: If specified, only sample among the tokens with the k highest probabilities
eos_id: If specified, stop generating further tokens once the <eos> token is produced
"""
# create an empty tensor of the expected final shape and fill in the current tokens
T = idx.size(0)
T_new = T + max_new_tokens
empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
empty[:T] = idx
idx = empty
# generate max_new_tokens tokens
for t in range(T, T_new):
# ignore the not-filled-yet tokens
idx_cond = idx[:t]
# if the sequence context is growing too long we must crop it at max_seq_length
idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]
# forward
logits = model(idx_cond.view(1, -1))
logits = logits[0, -1] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[[-1]]] = -float("Inf")
probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# concatenate the new generation
idx[t] = idx_next
# if <eos> token is triggered, return the output (stop generation)
if idx_next == eos_id:
return idx[:t + 1] # include the EOS token
return idx
def main(
prompt: str = "Hello, my name is",
*,
num_samples: int = 1,
max_new_tokens: int = 50,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Optional[Path] = None,
tokenizer_path: Optional[Path] = None,
model_size: str = "7B",
quantize: Optional[str] = None,
) -> None:
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
Args:
prompt: The prompt string to use for generating the samples.
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
checkpoint_path: The checkpoint path to load.
tokenizer_path: The tokenizer path to load.
model_size: The model size to load.
quantize: Whether to quantize the model and, if so, which method to use:
``"llm.int8"``: LLM.int8() mode,
``"gptq.int4"``: GPTQ 4-bit mode.
"""
if not checkpoint_path:
checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
if not tokenizer_path:
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert checkpoint_path.is_file(), checkpoint_path
assert tokenizer_path.is_file(), tokenizer_path
fabric = L.Fabric(accelerator="cuda", devices=1)
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
print("Loading model ...", file=sys.stderr)
t0 = time.time()
with EmptyInitOnDevice(
device=fabric.device, dtype=dtype, quantization_mode=quantize
):
model = LLaMA.from_name(model_size)
checkpoint = lazy_load(checkpoint_path)
model.load_state_dict(checkpoint)
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
model.eval()
model = fabric.setup_module(model)
tokenizer = Tokenizer(tokenizer_path)
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
L.seed_everything(1234)
t0 = time.perf_counter()
for _ in range(num_samples):
y = generate(
model,
encoded_prompt,
max_new_tokens,
model.config.block_size, # type: ignore[union-attr,arg-type]
temperature=temperature,
top_k=top_k,
)
print(tokenizer.decode(y))
t = time.perf_counter() - t0
print(f"\n\nTime for inference: {t:.02f} sec total, {num_samples * max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
if __name__ == "__main__":
from jsonargparse import CLI
torch.set_float32_matmul_precision("high")
warnings.filterwarnings(
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
"ignore",
message="ComplexHalf support is experimental and many operators don't support it yet"
)
warnings.filterwarnings(
# Triggered in bitsandbytes/autograd/_functions.py:298
"ignore",
message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
)
CLI(main)
================================================
FILE: generate_batch.py
================================================
import sys
import time
import warnings
from pathlib import Path
from typing import Optional
import lightning as L
import torch
from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice, lazy_load
@torch.no_grad()
def generate(
model: torch.nn.Module,
idx: torch.Tensor,
max_new_tokens: int,
max_seq_length: int,
temperature: float = 1.0,
top_k: Optional[int] = None,
eos_id: Optional[int] = None,
) -> torch.Tensor:
"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
The implementation of this function is modified from A. Karpathy's nanoGPT.
Args:
model: The model to use.
idx: Tensor of shape (B, T) with indices of the prompt sequence.
max_new_tokens: The number of new tokens to generate.
max_seq_length: The maximum sequence length allowed.
temperature: Scales the predicted logits by 1 / temperature
top_k: If specified, only sample among the tokens with the k highest probabilities
eos_id: If specified, stop generating further tokens once the <eos> token is produced
"""
# create an empty tensor of the expected final shape and fill in the current tokens
B, T = idx.shape
T_new = T + max_new_tokens
empty = torch.empty(B, T_new, dtype=idx.dtype, device=idx.device)
empty[:, :T] = idx
idx = empty
# generate max_new_tokens tokens
for t in range(T, T_new):
# ignore the not-filled-yet tokens
idx_cond = idx[:, :t]
# if the sequence context is growing too long we must crop it at max_seq_length
idx_cond = idx_cond if t <= max_seq_length else idx_cond[:, -max_seq_length:]
# forward
logits = model(idx_cond)
logits = logits[:, -1] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("Inf")
probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# write the newly sampled tokens into column t
idx[:, t:t+1] = idx_next
# if <eos> token is triggered, return the output (stop generation)
# if idx_next == eos_id:
# return idx[:t + 1] # include the EOS token
return idx
def main(
prompt: str = "Hello, my name is",
*,
num_samples: int = 1,
max_new_tokens: int = 50,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Optional[Path] = None,
tokenizer_path: Optional[Path] = None,
model_size: str = "7B",
quantize: Optional[str] = None,
) -> None:
"""Generates text samples based on a pre-trained LLaMA model and tokenizer.
Args:
prompt: The prompt string to use for generating the samples.
num_samples: The number of text samples to generate.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
checkpoint_path: The checkpoint path to load.
tokenizer_path: The tokenizer path to load.
model_size: The model size to load.
quantize: Whether to quantize the model and, if so, which method to use:
``"llm.int8"``: LLM.int8() mode,
``"gptq.int4"``: GPTQ 4-bit mode.
"""
if not checkpoint_path:
checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth")
if not tokenizer_path:
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert checkpoint_path.is_file(), checkpoint_path
assert tokenizer_path.is_file(), tokenizer_path
fabric = L.Fabric(accelerator="cuda", devices=1)
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
print("Loading model ...", file=sys.stderr)
t0 = time.time()
with EmptyInitOnDevice(
device=fabric.device, dtype=dtype, quantization_mode=quantize
):
model = LLaMA.from_name(model_size)
checkpoint = lazy_load(checkpoint_path)
model.load_state_dict(checkpoint)
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
model.eval()
model = fabric.setup_module(model)
tokenizer = Tokenizer(tokenizer_path)
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
encoded_prompt = encoded_prompt[None, :] # add batch dimension
L.seed_everything(1234)
t0 = time.perf_counter()
for _ in range(num_samples):
y = generate(
model,
encoded_prompt,
max_new_tokens,
model.config.block_size, # type: ignore[union-attr,arg-type]
temperature=temperature,
top_k=top_k,
)[0] # unpack batch dimension
print(tokenizer.decode(y))
t = time.perf_counter() - t0
print(f"\n\nTime for inference: {t:.02f} sec total, {num_samples * max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
if __name__ == "__main__":
from jsonargparse import CLI
torch.set_float32_matmul_precision("high")
warnings.filterwarnings(
# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
"ignore",
message="ComplexHalf support is experimental and many operators don't support it yet"
)
warnings.filterwarnings(
# Triggered in bitsandbytes/autograd/_functions.py:298
"ignore",
message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
)
CLI(main)
================================================
FILE: generate_motion.py
================================================
import os
import sys
import time
import warnings
from pathlib import Path
from typing import Optional
import lightning as L
import torch
import numpy as np
import models.vqvae as vqvae
from generate import generate
from lit_llama import Tokenizer, LLaMA, LLaMAConfig
from lit_llama.lora import lora
from lit_llama.utils import EmptyInitOnDevice, lazy_load
from scripts.prepare_motion import generate_prompt
from options import option
import imageio
from utils.evaluate import plot
from visualization.render import render
warnings.filterwarnings('ignore')
args = option.get_args_parser()
def main(
quantize: Optional[str] = None,
dtype: str = "float32",
max_new_tokens: int = 200,
top_k: int = 200,
temperature: float = 0.8,
accelerator: str = "auto",
) -> None:
lora_path = Path(args.lora_path)
pretrained_path = Path(f"./checkpoints/lit-llama/{args.pretrained_llama}/lit-llama.pth")
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert lora_path.is_file()
assert pretrained_path.is_file()
assert tokenizer_path.is_file()
if quantize is not None:
raise NotImplementedError("Quantization in LoRA is not supported yet")
fabric = L.Fabric(accelerator=accelerator, devices=1)
dt = getattr(torch, dtype, None)
if not isinstance(dt, torch.dtype):
raise ValueError(f"{dtype} is not a valid dtype.")
dtype = dt
net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
args.nb_code,
args.code_dim,
args.output_emb_width,
args.down_t,
args.stride_t,
args.width,
args.depth,
args.dilation_growth_rate)
print ('loading checkpoint from {}'.format(args.vqvae_pth))
ckpt = torch.load(args.vqvae_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.cuda()
print("Loading model ...", file=sys.stderr)
t0 = time.time()
with EmptyInitOnDevice(
device=fabric.device, dtype=dtype, quantization_mode=quantize
), lora(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout, enabled=True):
config = LLaMAConfig.from_name(args.pretrained_llama)
model = LLaMA(config)
# model = LLaMA(LLaMAConfig()) # TODO: Support different model sizes
# 1. Load the pretrained weights
pretrained_checkpoint = lazy_load(pretrained_path)
model.load_state_dict(pretrained_checkpoint, strict=False)
# 2. Load the fine-tuned LoRA weights
lora_checkpoint = lazy_load(lora_path)
model.load_state_dict(lora_checkpoint, strict=False)
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
model.eval()
model = fabric.setup_module(model)
tokenizer = Tokenizer(tokenizer_path)
sample = {"instruction": args.prompt, "input": args.input}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
t0 = time.perf_counter()
output = generate(
model,
idx=encoded,
max_seq_length=max_new_tokens,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
eos_id=tokenizer.eos_id
)
output = tokenizer.decode(output)
output = output.split("### Response:")[1].strip()
t = time.perf_counter() - t0
print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
tokens = torch.tensor([int(token) for token in output.split(',')]).cuda()
generated_pose, img = plot(tokens, net, args.dataname)
os.makedirs(args.out_dir, exist_ok=True)
np.save(os.path.join(args.out_dir, 'demo.npy'), generated_pose)
imageio.mimsave(os.path.join(args.out_dir, 'demo.gif'), np.array(img), fps=20)
if args.render:
print("Rendering...")
render(generated_pose, 'demo', outdir=args.out_dir)
if __name__ == "__main__":
torch.set_float32_matmul_precision("high")
main()
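
# Example invocation (illustrative values; all flags are defined in options/option.py):
#   python generate_motion.py --pretrained_llama 7B \
#       --vqvae_pth ./checkpoints/vqvae.pth --lora_path ./checkpoints/lora.pth \
#       --prompt "<task instruction>" --input "<control conditions>" \
#       --dataname t2m --out_dir ./output --render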
================================================
FILE: index.html
================================================
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="google-site-verification" content="oZMIcPh6afVajpq9eSwxoKM79HITHoE3mZ46IXmt6D8" />
<meta name="description"
content="MotionGPT: Finetuned LLMs are General-Purpose Motion Generators.">
<meta name="keywords" content="MotionGPT, Motion Generation, LLM">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>MotionGPT: Finetuned LLMs are General-Purpose Motion Generators</title>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-PYVRSFMDRL"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-PYVRSFMDRL');
</script>
<link href="https://fonts.googleapis.com/css?family=Google+Sans|Noto+Sans|Castoro"
rel="stylesheet">
<link rel="stylesheet" href="./static/css/bulma.min.css">
<link rel="stylesheet" href="./static/css/bulma-carousel.min.css">
<link rel="stylesheet" href="./static/css/bulma-slider.min.css">
<link rel="stylesheet" href="./static/css/fontawesome.all.min.css">
<link rel="stylesheet"
href="https://cdn.jsdelivr.net/gh/jpswalsh/academicons@1/css/academicons.min.css">
<link rel="stylesheet" href="./static/css/index.css">
<link rel="icon" href="./static/images/dancing-motion-svgrepo-com.svg">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script defer src="./static/js/fontawesome.all.min.js"></script>
<script src="./static/js/bulma-carousel.min.js"></script>
<script src="./static/js/bulma-slider.min.js"></script>
<script src="./static/js/index.js"></script>
</head>
<body>
<section class="hero">
<div class="hero-body">
<div class="container is-max-desktop">
<div class="columns is-centered">
<div class="column has-text-centered">
<h1 class="title is-1 publication-title">MotionGPT: Finetuned LLMs are General-Purpose Motion Generators</h1>
<div class="is-size-5 publication-authors">
<span class="author-block">
Yaqi Zhang<sup>1,2</sup>,</span>
<span class="author-block">
Di Huang<sup>4</sup>,</span>
<span class="author-block">
Bin Liu<sup>1,2</sup>,
</span>
<span class="author-block">
Shixiang Tang<sup>4</sup>,
</span>
<span class="author-block">
Yan Lu<sup>4</sup>,
</span>
<span class="author-block">
Lu Chen<sup>5</sup>,
</span>
<span class="author-block">
Lei Bai<sup>3</sup>,
</span>
<span class="author-block">
Qi Chu<sup>1,2</sup>,
</span>
<span class="author-block">
Nenghai Yu<sup>1,2</sup>,
</span>
<span class="author-block">
Wanli Ouyang<sup>3</sup>
</span>
</div>
<div class="is-size-5 publication-authors">
<span class="author-block"><sup>1</sup>University of Science and Technology of China</span>
<span class="author-block"><sup>2</sup>CAS Key Laboratory of Electromagnetic Space Information</span>
<span class="author-block"><sup>3</sup>Shanghai AI Laboratory</span>
<span class="author-block"><sup>4</sup>The University of Sydney</span>
<span class="author-block"><sup>5</sup>Zhejiang University</span>
</div>
<div class="column has-text-centered">
<div class="publication-links">
<span class="link-block">
<a href="https://arxiv.org/abs/2306.10900"
class="external-link button is-normal is-rounded is-dark">
<span class="icon">
<i class="ai ai-arxiv"></i>
</span>
<span>arXiv</span>
</a>
</span>
<span class="link-block">
<a href="https://github.com/qiqiApink/MotionGPT"
class="external-link button is-normal is-rounded is-dark">
<span class="icon">
<i class="fab fa-github"></i>
</span>
<span>Code</span>
</a>
</span>
</div>
</div>
</div>
</div>
</div>
</div>
</section>
<section class="hero teaser">
<div class="container is-max-desktop">
<div class="hero-body">
<video poster="" id="tree" playsinline autoplay muted loop height="80%">
<source src="./static/videos/teaser.mp4"
type="video/mp4">
</video>
<h2 class="subtitle has-text-centered">
MotionGPT supports diverse control conditions for human motion generation by finetuning LLMs.
</h2>
</div>
</div>
</section>
<section class="section">
<div class="container is-max-desktop">
<!-- Abstract. -->
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3">Abstract</h2>
<div class="content has-text-justified">
<p>
            Generating realistic human motion from action descriptions has advanced rapidly,
            driven by the emerging demands of digital humans. While recent works achieve
            impressive results in generating motion directly from textual action descriptions,
            they often support only a single modality of control signal, which limits their
            application in the real digital human industry. This paper presents a <b>Motion G</b>eneral-<b>P</b>urpose genera<b>T</b>or
            (MotionGPT) that can use multimodal control signals, <i>e.g.</i>, text and single-frame poses,
            to generate consecutive human motions by treating the multimodal signals as special
            input tokens in large language models (LLMs). Specifically, we first quantize the
            multimodal control signals into discrete codes and then formulate them in a unified
            prompt instruction that asks the LLM to generate the motion answer. MotionGPT achieves
            a unified human motion generation model under multimodal control by tuning a mere 0.4%
            of the LLM parameters. To the best of our knowledge, MotionGPT is the first method to
            generate human motion from multimodal control signals, and we hope it sheds light on
            this new direction.
</p>
<img src="./static/images/teaser.png" alt="Teaser image.">
<div class="content has-text-justified">
<p>Compared with previous methods, MotionGPT has the unique ability to accept multiple control conditions and solve various motion generation tasks using a unified model.</p>
</div>
</div>
</div>
</div>
</div>
</section>
<section class="section">
<div class="container is-max-desktop">
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3">Pipeline</h2>
<img src="./static/images/pipeline.png" alt="Pipeline image." />
<div class="content has-text-justified">
<p>
Our MotionGPT (<b>Motion G</b>eneral-<b>P</b>urpose genera<b>T</b>or) has
the unique ability to accept multiple control conditions and solve various
motion generation tasks using a unified model. Given text and poses as an
input example, we organize task descriptions (Instruction) and multiple
            control conditions (Input) within a question template. MotionGPT fine-tunes
            an LLM with LoRA to generate the corresponding motion answer, which can then
be decoded into human motions using a VQ-VAE decoder.
</p>
</div>
</div>
</div>
</div>
</section>
<section class="hero is-small">
<div class="hero-body">
<div class="container">
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3">Text-to-motion Generation</h2>
<p><b><font face="verdana">the generated motion is in <font color="#f5a623">orange</font></font></b></p>
<div class="content has-text-justified">
<div id="results-carousel" class="carousel results-carousel">
<div class="column is-centered has-text-centered">
<!-- <h2 class="title is-6">a person walks forward at an angle to the right, then swings their left hand, a person walks forward at an angle to the right, then swings their left hand</h2> -->
<div style="width:500px ;height:100px;"><p><i>a person walks forward, turns and then sits on a chair</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/t2m_0.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<!-- <h2 class="title is-4">3</h2> -->
<div style="width:500px ;height:100px;"><p><i>a hunched individual slowly wobbles forward in a drunken manner</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/t2m_1.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<!-- <h2 class="title is-4">4</h2> -->
<div style="width:500px ;height:100px;"><p><i>a person walks forward at an angle to the right, then swings their left hand</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/t2m_2.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>a person stirs something with his left hand</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/t2m_3.mp4"
type="video/mp4">
</video>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</section>
<section class="hero is-small">
<div class="hero-body">
<div class="container">
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3">(Text,initial pose)-to-motion Generation</h2>
<p><b><font face="verdana">the generated motion is in <font color="#f5a623">orange</font> and we highlight the initial pose in <font color="#4a90e2">blue (remain frozen for 0.5s)</font></font></b></p>
<div class="content has-text-justified">
<div id="results-carousel" class="carousel results-carousel">
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>a person slowly walked forward and returned</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/initial_0.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>person is running from side to side</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/initial_1.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p>a person slowly walked forward while balancing</p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/initial_2.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p>a person walks forward very slowly</p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/initial_3.mp4"
type="video/mp4">
</video>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</section>
<section class="hero is-small">
<div class="hero-body">
<div class="container">
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3">(Text,last pose)-to-motion Generation</h2>
<p><b><font face="verdana">the generated motion is in <font color="#f5a623">orange</font> and we highlight the last pose in <font color="#4a90e2">blue (remain frozen for 0.5s)</font></font></b></p>
<div class="content has-text-justified">
<div id="results-carousel" class="carousel results-carousel">
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>a person with his arms bent kicks to side with his left foot</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/last_0.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>a person turns right while walking then stops</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/last_1.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>walking backwards and then stopping</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/last_2.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>swinging hands up and down</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/last_3.mp4"
type="video/mp4">
</video>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</section>
<section class="hero is-small">
<div class="hero-body">
<div class="container">
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3">(Text,key poses)-to-motion Generation</h2>
<p><b><font face="verdana">the generated motion is in <font color="#f5a623">orange</font> and we highlight key poses in <font color="#4a90e2">blue (remain frozen for 0.5s)</font></font></b></p>
<div class="content has-text-justified">
<div id="results-carousel" class="carousel results-carousel">
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>a walking person suddenly gets staggered to their left, then recovers</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/keys_0.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<!-- <h2 class="title is-4">2</h2> -->
<div style="width:500px ;height:100px;"><p><i>standing on one leg and swinging it</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/keys_1.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>the man dances around waving his arms and kicking his legs</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/keys_2.mp4"
type="video/mp4">
</video>
</div>
<div class="column is-centered has-text-centered">
<div style="width:500px ;height:100px;"><p><i>a person does multiple jumping jacks</i></p></div>
<video poster="" id="tree" playsinline autoplay muted loop height="100%">
<source src="./static/videos/keys_3.mp4"
type="video/mp4">
</video>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</section>
<section class="section" id="BibTeX">
<div class="container is-max-desktop content">
<h2 class="title">BibTeX</h2>
<pre><code>
@article{zhang2023motiongpt,
title={MotionGPT: Finetuned LLMs are General-Purpose Motion Generators},
author={Zhang, Yaqi and Huang, Di and Liu, Bin and Tang, Shixiang and Lu, Yan and Chen, Lu and Bai, Lei and Chu, Qi and Yu, Nenghai and Ouyang, Wanli},
journal={arXiv preprint arXiv:2306.10900},
year={2023}
}
</code></pre>
</div>
</section>
<footer class="footer">
<div class="container">
<div class="content has-text-centered">
<a class="icon-link"
href="./static/images/motiongpt_paper.pdf">
<i class="fas fa-file-pdf"></i>
</a>
<a class="icon-link" href="https://github.com/qiqiApink/MotionGPT" class="external-link" disabled>
<i class="fab fa-github"></i>
</a>
</div>
<div class="columns is-centered">
<div class="column is-8">
<div class="content">
<p>
This website is licensed under a <a rel="license"
href="http://creativecommons.org/licenses/by-sa/4.0/">Creative
Commons Attribution-ShareAlike 4.0 International License</a>.
</p>
<p>
Website source code based on the <a href="https://nerfies.github.io">Nerfies</a> project page.
If you want to reuse their <a href="https://github.com/nerfies/nerfies.github.io">source code</a>,
please credit them appropriately.
</p>
</div>
</div>
</div>
</div>
</footer>
</body>
</html>
================================================
FILE: lit_llama/__init__.py
================================================
from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
from lit_llama.tokenizer import Tokenizer
================================================
FILE: lit_llama/adapter.py
================================================
"""Implementation of the paper:
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
https://arxiv.org/abs/2303.16199
"""
# mypy: ignore-errors
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import lit_llama.model as llama
from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP
@dataclass
class LLaMAConfig(llama.LLaMAConfig):
adapter_prompt_length: int = 10
adapter_start_layer: int = 2
class CausalSelfAttention(nn.Module):
"""A modification of `lit_llama.model.CausalSelfAttention` that adds the attention
over the adaption prompt."""
def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
if block_idx >= config.adapter_start_layer:
# adapter embedding layer
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
# gate for adaption
self.gating_factor = torch.nn.Parameter(torch.zeros(1))
self.n_head = config.n_head
self.n_embd = config.n_embd
self.block_size = config.block_size
self.block_idx = block_idx
self.adapter_prompt_length = config.adapter_prompt_length
self.adapter_start_layer = config.adapter_start_layer
self.rope_cache = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
head_size = C // self.n_head
k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
if self.rope_cache is None:
# cache for future forward calls
self.rope_cache = build_rope_cache(
seq_len=self.block_size,
n_elem=self.n_embd // self.n_head,
dtype=x.dtype,
device=x.device,
)
q = apply_rope(q, self.rope_cache)
k = apply_rope(k, self.rope_cache)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
# att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
# att = F.softmax(att, dim=-1)
# y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
# efficient attention using Flash Attention CUDA kernels
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
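        # Adaption prompt: keys/values from the learned prefix are attended to through
        # a zero-initialized gate (gating_factor), so the adapter starts as a no-op
        # and is blended in during fine-tuning ("zero-init attention" in the paper).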
if self.block_idx >= self.adapter_start_layer:
prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
aT = prefix.size(1)
_, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)
ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)
ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
y = y + self.gating_factor * ay
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.c_proj(y)
return y
class Block(nn.Module):
"""The implementation is identical to `lit_llama.model.Block` with the exception that
we replace the attention layer where adaption is implemented."""
def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
super().__init__()
self.rms_1 = RMSNorm(config.n_embd)
self.attn = CausalSelfAttention(config, block_idx)
self.rms_2 = RMSNorm(config.n_embd)
self.mlp = MLP(config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.rms_1(x))
x = x + self.mlp(self.rms_2(x))
return x
class LLaMA(llama.LLaMA):
"""The implementation is identical to `lit_llama.model.LLaMA` with the exception that
the `Block` saves the layer index and passes it down to the attention layer."""
def __init__(self, config: LLaMAConfig) -> None:
nn.Module.__init__(self)
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
h=nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
ln_f=RMSNorm(config.n_embd),
)
)
@classmethod
def from_name(cls, name: str):
return cls(LLaMAConfig.from_name(name))
def mark_only_adapter_as_trainable(model: LLaMA) -> None:
"""Sets `requires_grad=False` for all non-adapter weights."""
for name, param in model.named_parameters():
param.requires_grad = "adapter_wte" in name or "gating_factor" in name
def adapter_state_from_state_dict(state_dict: dict) -> dict:
"""Returns the model state dict with only the adapter weights for saving."""
return {name: param for name, param in state_dict.items() if "adapter_wte" in name or "gating_factor" in name}
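
# Minimal usage sketch (illustrative, not part of the original file): freeze every
# weight except the adapter, train, then save only the adapter tensors.
#
#   model = LLaMA.from_name("7B")
#   mark_only_adapter_as_trainable(model)
#   ...  # training loop
#   torch.save(adapter_state_from_state_dict(model.state_dict()), "adapter.pth")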
================================================
FILE: lit_llama/indexed_dataset.py
================================================
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of the FairSeq source tree.
# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.
from functools import lru_cache
import os
import shutil
import struct
from itertools import accumulate
import numpy as np
import torch
def __best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
return np.uint16
else:
return np.int32
def get_available_dataset_impl():
return ['lazy', 'cached', 'mmap']
def infer_dataset_impl(path):
if IndexedDataset.exists(path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
return 'cached'
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return 'mmap'
else:
return None
else:
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
def make_builder(out_file, impl, vocab_size=None):
if impl == 'mmap':
return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
else:
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, skip_warmup=False):
if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
if impl == 'infer':
impl = infer_dataset_impl(path)
if impl == 'lazy' and IndexedDataset.exists(path):
return IndexedDataset(path)
elif impl == 'cached' and IndexedDataset.exists(path):
return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None
def dataset_exists(path, impl):
if impl == 'mmap':
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float32,
7: np.float64,
8: np.uint16
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + '.idx'
def data_file_path(prefix_path):
return prefix_path + '.bin'
def create_doc_idx(sizes):
doc_idx = [0]
for i, s in enumerate(sizes):
if s == 0:
doc_idx.append(i + 1)
return doc_idx
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset"""
_HDR_MAGIC = b'TNTIDX\x00\x00'
def __init__(self, path):
super().__init__()
self.path = path
self.data_file = None
self.read_index(path)
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
self.dtype = dtypes[code]
self._len, self.s = struct.unpack('<QQ', f.read(16))
self.doc_count = struct.unpack('<Q', f.read(8))
self.dim_offsets = read_longs(f, self._len + 1)
self.data_offsets = read_longs(f, self._len + 1)
self.sizes = read_longs(f, self.s)
self.doc_idx = read_longs(f, self.doc_count)
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
def check_index(self, i):
if i < 0 or i >= self._len:
raise IndexError('index out of range')
def __del__(self):
if self.data_file:
self.data_file.close()
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if not self.data_file:
self.read_data(self.path)
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
return a
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
size = sum(sizes)
a = np.empty(size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[start] * self.element_size)
self.data_file.readinto(a)
offsets = list(accumulate(sizes))
sents = np.split(a, offsets[:-1])
return sents
def __len__(self):
return self._len
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path):
super().__init__(path)
self.cache = None
self.cache_index = {}
@property
def supports_prefetch(self):
return True
def prefetch(self, indices):
if all(i in self.cache_index for i in indices):
return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices))
total_size = 0
for i in indices:
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
self.cache = np.empty(total_size, dtype=self.dtype)
ptx = 0
self.cache_index.clear()
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx: ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
if self.data_file:
# close and delete data file after prefetch so we can pickle
self.data_file.close()
self.data_file = None
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx: ptx + a.size])
return a
elif isinstance(idx, slice):
            # Hack just to make this work, can optimize later if necessary
sents = []
for i in range(*idx.indices(len(self))):
sents.append(self[i])
return sents
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float32: 4,
np.float64: 8
}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
self.doc_idx = [0]
def add_item(self, tensor):
        # use integer division for offsets and avoid shadowing the builtin `bytes`
        bytes_written = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes_written // self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def end_document(self):
self.doc_idx.append(len(self.sizes))
def merge_file_(self, another_file):
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
doc_offset = len(self.sizes)
begin = self.data_offsets[-1]
for data_offset in index.data_offsets[1:]:
self.data_offsets.append(begin + data_offset)
self.sizes.extend(index.sizes)
begin = self.dim_offsets[-1]
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
self.doc_idx.extend((doc_offset + index.doc_idx)[1:])
with open(data_file_path(another_file), 'rb') as f:
while True:
data = f.read(1024)
if data:
self.out_file.write(data)
else:
break
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
index.write(struct.pack('<Q', len(self.doc_idx)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
write_longs(index, self.doc_idx)
index.close()
def _warmup_mmap_file(path):
with open(path, 'rb') as stream:
while stream.read(100 * 1024 * 1024):
pass
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, 'wb')
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack('<Q', 1))
self._file.write(struct.pack('<B', code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
self._file.write(struct.pack('<Q', len(sizes)))
self._file.write(struct.pack('<Q', len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order='C'))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order='C'))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order='C'))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, 'rb') as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = struct.unpack('<Q', stream.read(8))
assert (1,) == version
dtype_code, = struct.unpack('<B', stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack('<Q', stream.read(8))[0]
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print(" reading sizes...")
self._sizes = np.frombuffer(
self._bin_buffer,
dtype=np.int32,
count=self._len,
offset=offset)
print(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes)
print(" reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state, skip_warmup=True)
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
print(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, (int, np.integer)):
ptr, size = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=size, offset=ptr)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=total_size, offset=ptr)
sents = np.split(np_array, offsets[:-1])
return sents
else:
raise TypeError("Unexpected type received for idx: {}".format(type(idx)))
def get(self, idx, offset=0, length=None):
""" Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=length, offset=ptr)
return np_array
@property
def sizes(self):
return self._index.sizes
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb')
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]
@property
def dtype(self):
return self._dtype
def add_item(self, np_array):
# np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.append(np_array.size)
def add_doc(self, np_array, sizes):
# np_array = np.array(tensor, dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.extend(sizes)
self._doc_idx.append(len(self._sizes))
def end_document(self):
self._doc_idx.append(len(self._sizes))
def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedDataset.Index(index_file_path(another_file))
assert index.dtype == self._dtype
offset = len(self._sizes)
self._sizes.extend(index.sizes)
self._doc_idx.extend((offset + index.doc_idx)[1:])
# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
shutil.copyfileobj(f, self._data_file)
def finalize(self, index_file):
self._data_file.close()
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes, self._doc_idx)
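
# Minimal usage sketch (illustrative): write token arrays with the mmap builder, then
# read them back. "out.bin"/"out.idx" are hypothetical paths.
#
#   builder = MMapIndexedDatasetBuilder("out.bin", dtype=np.int32)
#   builder.add_item(np.array([1, 2, 3], dtype=np.int32))
#   builder.end_document()
#   builder.finalize("out.idx")
#   ds = MMapIndexedDataset("out")   # path is the basename without .bin/.idx
#   print(ds[0])                     # -> array([1, 2, 3], dtype=int32)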
================================================
FILE: lit_llama/lora.py
================================================
# Derived from https://github.com/microsoft/LoRA
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Dict, List
import lit_llama.model as llama
from contextlib import contextmanager
from dataclasses import dataclass
class LoRALayer():
def __init__(
self,
r: int,
lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights
class MergedLinear(nn.Linear, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
enable_lora: List[bool] = [False],
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
merge_weights=merge_weights)
assert out_features % len(enable_lora) == 0, \
'The length of enable_lora must divide out_features'
self.enable_lora = enable_lora
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0 and any(enable_lora):
self.lora_A = nn.Parameter(
self.weight.new_zeros((r * sum(enable_lora), in_features)))
self.lora_B = nn.Parameter(
self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
) # weights for Conv1D with groups=sum(enable_lora)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
self.lora_ind = self.weight.new_zeros(
(out_features, ), dtype=torch.bool
).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def zero_pad(self, x):
result = x.new_zeros((*x.shape[:-1], self.out_features))
result = result.view(-1, self.out_features)
result[:, self.lora_ind] = x.reshape(
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
)
return result.view((*x.shape[:-1], self.out_features))
def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0 and any(self.enable_lora):
delta_w = F.conv1d(
self.lora_A.data.unsqueeze(0),
self.lora_B.data.unsqueeze(-1),
groups=sum(self.enable_lora)
).squeeze(0)
self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Linear.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0 and any(self.enable_lora):
delta_w = F.conv1d(
self.lora_A.data.unsqueeze(0),
self.lora_B.data.unsqueeze(-1),
groups=sum(self.enable_lora)
).squeeze(0)
self.weight.data += self.zero_pad(T(delta_w * self.scaling))
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
if self.merged:
return F.linear(x, T(self.weight), bias=self.bias)
else:
result = F.linear(x, T(self.weight), bias=self.bias)
if self.r > 0:
after_A = F.linear(self.lora_dropout(x), self.lora_A)
after_B = F.conv1d(
after_A.transpose(-2, -1),
self.lora_B.unsqueeze(-1),
groups=sum(self.enable_lora)
).transpose(-2, -1)
result += self.zero_pad(after_B) * self.scaling
return result
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
for n, p in model.named_parameters():
if 'lora_' not in n and 'motion_proj' not in n and 'llama_proj' not in n:
p.requires_grad = False
if bias == 'none':
return
elif bias == 'all':
for n, p in model.named_parameters():
if 'bias' in n:
p.requires_grad = True
elif bias == 'lora_only':
for m in model.modules():
if isinstance(m, LoRALayer) and \
hasattr(m, 'bias') and \
m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
my_state_dict = model.state_dict()
if bias == 'none':
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'llama_proj' in k or 'motion_proj' in k}
elif bias == 'all':
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k or 'llama_proj' in k or 'motion_proj' in k}
elif bias == 'lora_only':
to_return = {}
for k in my_state_dict:
if 'lora_' in k:
to_return[k] = my_state_dict[k]
bias_name = k.split('lora_')[0]+'bias'
if bias_name in my_state_dict:
to_return[bias_name] = my_state_dict[bias_name]
return to_return
else:
raise NotImplementedError
@dataclass
class LoRAConfig:
r: float = 0.0
alpha: float = 1.0
dropout: float = 0.0
class CausalSelfAttention(llama.CausalSelfAttention):
lora_config = None
def __init__(self, config: llama.LLaMAConfig) -> None:
# Skip the parent class __init__ altogether and replace it to avoid
# useless allocations
nn.Module.__init__(self)
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
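        # enable_lora=[True, False, True]: LoRA is applied only to the query and value
        # slices of the fused qkv projection below; the key slice stays frozen.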
self.c_attn = MergedLinear(
in_features=config.n_embd,
out_features=3 * config.n_embd,
r=self.lora_config.r,
lora_alpha=self.lora_config.alpha,
lora_dropout=self.lora_config.dropout,
enable_lora=[True, False, True],
fan_in_fan_out = False,
merge_weights=True,
bias=False)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
# regularization
self.n_head = config.n_head
self.n_embd = config.n_embd
self.block_size = config.block_size
self.rope_cache = None
@contextmanager
def lora(r, alpha, dropout, enabled: bool = True):
"""A context manager under which you can instantiate the model with LoRA."""
if not enabled:
yield
return
CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
causal_self_attention = llama.CausalSelfAttention
llama.CausalSelfAttention = CausalSelfAttention
yield
llama.CausalSelfAttention = causal_self_attention
CausalSelfAttention.lora_config = None
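
# Minimal usage sketch (illustrative): construct the model under the context manager
# so its attention layers are built with LoRA, then freeze the base weights.
#
#   with lora(r=8, alpha=16, dropout=0.05, enabled=True):
#       model = llama.LLaMA.from_name("7B")
#   mark_only_lora_as_trainable(model)
#   torch.save(lora_state_dict(model), "lora.pth")  # adapter (and proj) weights only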
================================================
FILE: lit_llama/model.py
================================================
"""Full definition of a LLaMA Language Model, all of it in this single file.
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
"""
# mypy: ignore-errors
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self
@dataclass
class LLaMAConfig:
block_size: int = 4096
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
n_embd: int = 4096
@classmethod
def from_name(cls, name: str) -> Self:
return cls(**llama_configs[name])
llama_configs = {
"7B": dict(n_layer=32, n_head=32, n_embd=4096),
"13B": dict(n_layer=40, n_head=40, n_embd=5120),
"30B": dict(n_layer=60, n_head=52, n_embd=6656),
"65B": dict(n_layer=80, n_head=64, n_embd=8192),
}
class LLaMA(nn.Module):
def __init__(self, config: LLaMAConfig) -> None:
super().__init__()
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f=RMSNorm(config.n_embd),
)
)
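        # Extra layers added on top of upstream lit-llama: project 512-d motion
        # features (the VQ-VAE code width assumed by this repo) into and out of the
        # LLM embedding space; the commented-out MLPs below are an alternative design.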
# self.llama_proj = nn.Sequential(
# nn.Linear(256, 1024),
# nn.ReLU(),
# nn.Linear(1024, config.n_embd)
# )
self.llama_proj = nn.Linear(512, config.n_embd)
# self.motion_proj = nn.Sequential(
# nn.Linear(config.n_embd, 1024),
# nn.ReLU(),
# nn.Linear(1024, 256)
# )
self.motion_proj = nn.Linear(config.n_embd, 512)
def _init_weights(self, module: nn.Module) -> None:
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
def forward(self, idx: torch.Tensor) -> torch.Tensor:
_, t = idx.size()
assert (
t <= self.config.block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
# forward the LLaMA model itself
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
logits = self.lm_head(x) # (b, t, vocab_size)
return logits
@classmethod
def from_name(cls, name: str) -> Self:
return cls(LLaMAConfig.from_name(name))
class Block(nn.Module):
def __init__(self, config: LLaMAConfig) -> None:
super().__init__()
self.rms_1 = RMSNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.rms_2 = RMSNorm(config.n_embd)
self.mlp = MLP(config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.rms_1(x))
x = x + self.mlp(self.rms_2(x))
return x
class CausalSelfAttention(nn.Module):
def __init__(self, config: LLaMAConfig) -> None:
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.block_size = config.block_size
self.rope_cache = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
head_size = C // self.n_head
k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
if self.rope_cache is None:
# cache for future forward calls
self.rope_cache = build_rope_cache(
seq_len=self.block_size,
n_elem=self.n_embd // self.n_head,
dtype=x.dtype,
device=x.device,
)
q = apply_rope(q, self.rope_cache)
k = apply_rope(k, self.rope_cache)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
# att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
# att = F.softmax(att, dim=-1)
# y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
# efficient attention using Flash Attention CUDA kernels
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, config: LLaMAConfig) -> None:
super().__init__()
hidden_dim = 4 * config.n_embd
n_hidden = int(2 * hidden_dim / 3)
N = 256
# ensure n_hidden is multiple of N
n_hidden = ((n_hidden - 1) // N) * N + N
self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
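        # SwiGLU feed-forward as in LLaMA: SiLU-gated product of two projections.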
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
x = self.c_proj(x)
return x
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization.
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
"""
def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
super().__init__()
self.scale = nn.Parameter(torch.ones(size))
self.eps = eps
self.dim = dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE: the original RMSNorm paper implementation is not equivalent
# norm_x = x.norm(2, dim=self.dim, keepdim=True)
# rms_x = norm_x * d_x ** (-1. / 2)
# x_normed = x / (rms_x + self.eps)
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)
return self.scale * x_normed
def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta)
# Compute cache. Because polar only takes float32 or float64, we need to cast
# when working with 16 bit floats (float16 or bfloat16)
dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
working_dtype = (
torch.float32 if dtype in dtypes_requiring_casting else dtype
)
complex_dtype = (
torch.complex32 if dtype in dtypes_requiring_casting else torch.complex64
)
cache = torch.polar(
torch.ones_like(idx_theta).to(working_dtype), idx_theta.to(working_dtype)
).to(complex_dtype)
return cache
def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 2)
# truncate to support variable sizes
T = x.size(1)
rope_cache = rope_cache[:T]
# cast because `view_as_complex` does not support 16 bit tensors
xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
x_out = torch.view_as_real(xc * rope_cache).flatten(3)
return x_out.transpose(1, 2).type_as(x)
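
# Minimal forward-pass sketch (illustrative; uses a tiny hypothetical config instead
# of the real "7B" weights so it runs on CPU):
#
#   cfg = LLaMAConfig(block_size=64, vocab_size=100, n_layer=2, n_head=2, n_embd=64)
#   model = LLaMA(cfg)
#   idx = torch.randint(0, cfg.vocab_size, (1, 8))
#   logits = model(idx)  # shape (1, 8, vocab_size)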
================================================
FILE: lit_llama/quantization.py
================================================
import os
from contextlib import contextmanager
import warnings
import math
import torch
# configuration for bitsandbytes before import
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
warnings.filterwarnings(
"ignore",
message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
)
warnings.filterwarnings(
"ignore",
message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
)
warnings.filterwarnings(
"ignore",
message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable."
)
try:
import bitsandbytes as bnb # noqa: E402
except ImportError:
bnb = None
if bnb is not None:
class Linear8bitLt(bnb.nn.Linear8bitLt):
"""Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
    re-quantization when loading the state dict.
This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
# We quantize the initial weight here so we don't end up filling the device
# memory with float32 weights which could lead to OOM.
self._quantize_weight(self.weight.data)
def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
# There is only one key that ends with `*.weight`, the other one is the bias
weight_key = next((name for name in local_state_dict.keys() if name.endswith("weight")), None)
if weight_key is None:
return
# Load the weight from the state dict and re-quantize it
weight = local_state_dict.pop(weight_key)
self._quantize_weight(weight)
# If there is a bias, let nn.Module load it
if local_state_dict:
super()._load_from_state_dict(local_state_dict, *args, **kwargs)
def _quantize_weight(self, weight: torch.Tensor) -> None:
# This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
B = weight.contiguous().half().cuda()
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
self.weight.data = CB
setattr(self.weight, "CB", CB)
setattr(self.weight, "SCB", SCB)
# for correctness but with terrible perf
class ColBlockQuantizedLinear(torch.nn.Module):
def __init__(self, in_features, out_features, bias: bool, *, bits, tile_cols):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.tile_cols = tile_cols if tile_cols != -1 else self.in_features
self.bits = bits
self.entries_per_byte = 8 // bits
assert self.entries_per_byte > 0 and self.entries_per_byte * self.bits == 8
assert in_features % self.entries_per_byte == 0
self.register_buffer("quant_weight", torch.empty((self.out_features, self.in_features // self.entries_per_byte), dtype=torch.uint8))
self.register_buffer("scales", torch.empty((self.out_features, (self.in_features + self.tile_cols - 1) // self.tile_cols)))
self.register_buffer("zeros", torch.empty_like(self.scales))
assert isinstance(bias, bool)
if bias:
self.register_buffer("bias", torch.empty((self.out_features,)))
else:
self.register_buffer("bias", None)
def pack_weight(self, weight):
weight = weight.to(device=self.quant_weight.device, copy=True)
for j in range(self.scales.size(1)):
weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] /= self.scales[: , j: j+1]
weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] += self.zeros[: , j: j+1]
weight = weight.clamp_(min=0, max=2 ** self.bits - 1).to(dtype=torch.uint8)
self.quant_weight.zero_()
for nr in range(self.entries_per_byte):
self.quant_weight += weight[:, nr::self.entries_per_byte] << (nr * self.bits)
def get_weight(self, dtype=torch.float):
weight = torch.empty((self.out_features, self.in_features), device=self.quant_weight.device, dtype=dtype)
mask = (1<<self.bits) - 1
for nr in range(self.entries_per_byte):
weight[:, nr::self.entries_per_byte] = ((self.quant_weight >> (nr * self.bits)) & mask).float()
self.quant_weight.to(dtype)
for j in range(self.scales.size(1)):
weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] -= self.zeros[: , j: j+1]
weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] *= self.scales[: , j: j+1]
return weight
def forward(self, inp):
weight = self.get_weight(dtype=inp.dtype)
return torch.nn.functional.linear(inp, weight, self.bias)
class GPTQQuantizer:
# The algorithm and code has been taken from https://github.com/IST-DASLab/gptq/
# E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323
# portions copyright by the authors licensed under the Apache License 2.0
# All errors are our own.
def __init__(self, linear_module, *, bits, perchannel=True, sym=False, blocksize=128, percdamp=.01, groupsize=-1, actorder=False):
assert isinstance(linear_module, torch.nn.Linear)
self.linear_module = linear_module
self.dev = self.linear_module.weight.device
self.rows = linear_module.weight.shape[0]
self.columns = linear_module.weight.shape[1]
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
self.nsamples = 0
self.bits = bits
self.maxq = 2 ** bits - 1
self.perchannel = perchannel
self.sym = sym
self.blocksize = blocksize
self.percdamp = percdamp
self.groupsize = groupsize
self.actorder = actorder
self.tile_cols = self.columns if groupsize == -1 else groupsize
self.scales = torch.zeros((self.rows, (self.columns + self.tile_cols - 1) // self.tile_cols), dtype=self.linear_module.weight.dtype, device = self.dev)
self.zeros = torch.zeros_like(self.scales)
assert not (self.actorder and self.groupsize != -1), "The permutation trick does not work for grouped quantization"
@staticmethod
def quantize_weight(x, scale, zero, maxq):
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
x_rec = scale * (q - zero)
return x_rec
def find_params_weight(self, x):
dev = x.device
shape = x.shape
if self.perchannel:
x = x.flatten(1)
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
scale = (xmax - xmin) / self.maxq
if self.sym:
zero = torch.full_like(scale, (self.maxq + 1) / 2)
else:
zero = torch.round(-xmin / scale)
if not self.perchannel:
tmp = shape[0]
scale = scale.repeat(tmp)
zero = zero.repeat(tmp)
shape = [-1] + [1] * (len(shape) - 1)
scale = scale.reshape(shape)
zero = zero.reshape(shape)
return scale, zero
def collect_input_stats(self, _1, inp, _2):
inp = inp[0].detach()
self.last_inp = inp
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
self.H *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
# inp = inp.float()
inp = math.sqrt(2 / self.nsamples) * inp.float()
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t())
def quantize(self):
W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
scale, zero = self.find_params_weight(W)
self.scales[:] = scale
self.zeros[:] = zero
H = self.H
del self.H
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
if self.actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
damp = self.percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
for i1 in range(0, self.columns, self.blocksize):
i2 = min(i1 + self.blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if self.groupsize != -1:
if (i1 + i) % self.groupsize == 0:
scale, zero = self.find_params_weight(W[:, (i1 + i):(i1 + i + self.groupsize)])
self.scales[:, (i1 + i) // self.groupsize] = scale
                        self.zeros[:, (i1 + i) // self.groupsize] = zero
q = self.quantize_weight(
w.unsqueeze(1), scale, zero, self.maxq
)
q = q.squeeze(1)
assert q.dim() == 1
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d ** 2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
if self.actorder:
invperm = torch.argsort(perm)
Q = Q[:, invperm]
weight = Q.reshape(self.linear_module.weight.shape).to(self.linear_module.weight.data.dtype)
error = torch.sum(Losses).item()
q_module = ColBlockQuantizedLinear(self.linear_module.in_features, self.linear_module.out_features, self.linear_module.bias is not None,
bits=self.bits, tile_cols=self.groupsize).to(self.dev)
q_module.scales = self.scales
q_module.zeros = self.zeros
q_module.pack_weight(weight)
q_module.bias = self.linear_module.bias
return q_module, error
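A minimal sketch of how `GPTQQuantizer` is meant to be driven, assuming `linear` is the layer to compress and `calib_batches` is a small set of calibration inputs (both names are hypothetical): a forward hook accumulates the Hessian statistics, then `quantize()` returns a drop-in `ColBlockQuantizedLinear`.

```
import torch
from lit_llama.quantization import GPTQQuantizer

linear = torch.nn.Linear(512, 512)
calib_batches = [torch.randn(4, 512) for _ in range(8)]  # illustrative data

quantizer = GPTQQuantizer(linear, bits=4)
handle = linear.register_forward_hook(quantizer.collect_input_stats)  # builds H
with torch.no_grad():
    for batch in calib_batches:
        linear(batch)
handle.remove()
q_module, error = quantizer.quantize()  # ColBlockQuantizedLinear + total sq. error
```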
================================================
FILE: lit_llama/tokenizer.py
================================================
import os
from pathlib import Path
from typing import Optional
import torch
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
class Tokenizer:
"""Tokenizer for LLaMA."""
def __init__(self, model_path: Path) -> None:
self.processor = SentencePieceProcessor(model_file=str(model_path))
self.bos_id = self.processor.bos_id()
self.eos_id = self.processor.eos_id()
self.pad_id = self.processor.pad_id()
@property
def vocab_size(self) -> int:
return self.processor.vocab_size()
def encode(
self,
string: str,
bos: bool = True,
eos: bool = False,
max_length: int = -1,
pad: bool = False,
device: Optional[torch.device] = None
) -> torch.Tensor:
tokens = self.processor.encode(string)
if bos:
tokens = [self.bos_id] + tokens
if eos:
tokens = tokens + [self.eos_id]
if max_length > 0:
tokens = tokens[:max_length]
if pad and len(tokens) < max_length:
tokens += [self.pad_id] * (max_length - len(tokens))
return torch.tensor(tokens, dtype=torch.int, device=device)
def decode(self, tokens: torch.Tensor) -> str:
return self.processor.decode(tokens.tolist())
@staticmethod
def train(input: str, destination: str, vocab_size=32000) -> None:
model_prefix = os.path.join(destination, "tokenizer")
SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
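A quick round trip with the tokenizer, assuming the converted Lit-LLaMA `tokenizer.model` is already in place (the exact path is illustrative):

```
from pathlib import Path
from lit_llama.tokenizer import Tokenizer

tokenizer = Tokenizer(Path("checkpoints/lit-llama/tokenizer.model"))
ids = tokenizer.encode("a person walks forward and waves", bos=True, eos=True)
print(ids.shape, ids.dtype)   # 1-D torch.int32 tensor: [bos, ..., eos]
print(tokenizer.decode(ids))  # round-trips to the original string
```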
================================================
FILE: lit_llama/utils.py
================================================
"""Utility functions for training and inference."""
import functools
from pathlib import Path
import pickle
import warnings
from io import BytesIO
import torch
import torch.utils._device
from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
def save_model_checkpoint(fabric, model, file_path):
"""Handles boilerplate logic for retrieving and saving the state_dict.
This will be upstreamed to Fabric soon.
"""
file_path = Path(file_path)
if isinstance(fabric.strategy, DeepSpeedStrategy):
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
fabric.save(file_path, {"model": model})
fabric.barrier()
if fabric.global_rank == 0:
# Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth"))
return
if isinstance(fabric.strategy, FSDPStrategy):
save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
state_dict = model._forward_module.state_dict()
else:
state_dict = model.state_dict()
if fabric.global_rank == 0:
torch.save(state_dict, file_path)
fabric.barrier()
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
def __init__(self, device=None, dtype=None, quantization_mode=None):
"""
Create tensors with given device and dtype and don't run initialization
(but instead use "empty tensors", i.e. uninitialized memory).
device: `torch.device` to work with
dtype: `torch.dtype` to work with
quantization_mode: optional string, quantization mode to work with, default `None`.
        Available modes: `llm.int8` bitsandbytes LLM.int8 quantization (only on GPU)
                         `gptq.int4`, `gptq.int8`: GPTQ pre-quantized models
Example::
with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
model = LLaMA.from_name('7B')
model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""
self.quantization_mode = quantization_mode
self.quantized_linear_cls = None
if self.quantization_mode == 'llm.int8':
if device.type != "cuda":
raise ValueError("Quantization is only supported on the GPU.")
from .quantization import Linear8bitLt
self.quantized_linear_cls = Linear8bitLt
elif self.quantization_mode == 'gptq.int4':
from .quantization import ColBlockQuantizedLinear
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
elif self.quantization_mode == 'gptq.int8':
from .quantization import ColBlockQuantizedLinear
self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
elif self.quantization_mode is not None:
raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
self.device = device
self.dtype = dtype
def __enter__(self):
        if self.quantized_linear_cls is not None:
self.torch_linear_cls = torch.nn.Linear
torch.nn.Linear = self.quantized_linear_cls
return super().__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
        if self.quantized_linear_cls is not None:
torch.nn.Linear = self.torch_linear_cls
return super().__exit__(exc_type, exc_val, exc_tb)
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if getattr(func, "__module__", None) == "torch.nn.init":
if "tensor" in kwargs:
return kwargs["tensor"]
else:
return args[0]
if (
self.device is not None
and func in torch.utils._device._device_constructors()
and kwargs.get("device") is None
):
kwargs["device"] = self.device
if (
self.dtype is not None
and func in torch.utils._device._device_constructors()
and kwargs.get("dtype") is None
):
kwargs["dtype"] = self.dtype
return func(*args, **kwargs)
# this is taken from torchhacks https://github.com/lernapparat/torchhacks
class NotYetLoadedTensor:
def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
self.metatensor = metatensor
self.archiveinfo = archiveinfo
self.storageinfo = storageinfo
self.rebuild_args = rebuild_args
@classmethod
def rebuild(
cls,
storage,
storage_offset,
size,
stride,
requires_grad,
backward_hooks,
metadata=None,
archiveinfo=None,
):
rebuild_args = (
storage_offset,
size,
stride,
requires_grad,
backward_hooks,
metadata,
)
metatensor = torch._utils._rebuild_tensor_v2(
storage,
storage_offset,
size,
stride,
requires_grad,
backward_hooks,
metadata,
)
storageinfo = storage.archiveinfo
return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
def _load_tensor(self):
name, storage_cls, fn, device, size = self.storageinfo
dtype = self.metatensor.dtype
uts = (
self.archiveinfo.zipfile.get_storage_from_record(
f"data/{fn}",
size * torch._utils._element_size(dtype),
torch.UntypedStorage,
)
._typed_storage()
._untyped_storage
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
storage = torch.storage.TypedStorage(
wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True
)
tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
return tensor
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
loaded_args = [
(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args
]
res = func(*loaded_args, **kwargs)
# gc.collect would be costly here, maybe do it optionally
return res
def __getattr__(self, name):
# properties
## TODO: device, is_...??
## TODO: mH, mT, H, T, data, imag, real
## name ???
if name in {
"dtype",
"grad",
"grad_fn",
"layout",
"names",
"ndim",
"output_nr",
"requires_grad",
"retains_grad",
"shape",
"volatile",
}:
return getattr(self.metatensor, name)
if name in {"size"}:
return getattr(self.metatensor, name)
# materializing with contiguous is needed for quantization
if name in {"contiguous"}:
return getattr(self._load_tensor(), name)
raise AttributeError(f"{type(self)} does not have {name}")
def __repr__(self):
return f"NotYetLoadedTensor({repr(self.metatensor)})"
class LazyLoadingUnpickler(pickle.Unpickler):
def __init__(self, file, zipfile):
super().__init__(file)
self.zipfile = zipfile
def find_class(self, module, name):
if module == "torch._utils" and name == "_rebuild_tensor_v2":
res = super().find_class(module, name)
return functools.partial(NotYetLoadedTensor.rebuild, archiveinfo=self)
return super().find_class(module, name)
def persistent_load(self, pid):
name, cls, fn, device, size = pid
with warnings.catch_warnings():
warnings.simplefilter("ignore")
s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
s.archiveinfo = pid
return s
def lazy_load(fn):
zf = torch._C.PyTorchFileReader(str(fn))
with BytesIO(zf.get_record("data.pkl")) as pkl:
mup = LazyLoadingUnpickler(pkl, zf)
sd = mup.load()
return sd
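The pieces above compose as follows: `EmptyInitOnDevice` skips weight initialization so construction is cheap, and `lazy_load` defers reading each tensor's storage until `load_state_dict` copies it in. A minimal sketch, assuming a CUDA device and the checkpoint path from the docstring above:

```
import torch
from lit_llama.model import LLaMA
from lit_llama.utils import EmptyInitOnDevice, lazy_load

with EmptyInitOnDevice(device=torch.device("cuda"), dtype=torch.bfloat16):
    model = LLaMA.from_name("7B")       # allocated but uninitialized weights
checkpoint = lazy_load("llama-lit/7B/lit-llama.pth")
model.load_state_dict(checkpoint)       # tensors materialize as they are copied
```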
================================================
FILE: models/encdec.py
================================================
import torch.nn as nn
from models.resnet import Resnet1D
class Encoder(nn.Module):
def __init__(self,
input_emb_width = 3,
output_emb_width = 512,
down_t = 3,
stride_t = 2,
width = 512,
depth = 3,
dilation_growth_rate = 3,
activation='relu',
norm=None):
super().__init__()
blocks = []
filter_t, pad_t = stride_t * 2, stride_t // 2
blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1))
blocks.append(nn.ReLU())
for i in range(down_t):
input_dim = width
block = nn.Sequential(
nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t),
Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm),
)
blocks.append(block)
blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
self.model = nn.Sequential(*blocks)
def forward(self, x):
return self.model(x)
class Decoder(nn.Module):
def __init__(self,
input_emb_width = 3,
output_emb_width = 512,
down_t = 3,
stride_t = 2,
width = 512,
depth = 3,
dilation_growth_rate = 3,
activation='relu',
norm=None):
super().__init__()
blocks = []
filter_t, pad_t = stride_t * 2, stride_t // 2
blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
blocks.append(nn.ReLU())
for i in range(down_t):
out_dim = width
block = nn.Sequential(
Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv1d(width, out_dim, 3, 1, 1)
)
blocks.append(block)
blocks.append(nn.Conv1d(width, width, 3, 1, 1))
blocks.append(nn.ReLU())
blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1))
self.model = nn.Sequential(*blocks)
def forward(self, x):
return self.model(x)
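With the downsampling settings used elsewhere in this repo (`down_t=2`, `stride_t=2`), each encoder block halves the temporal axis, so T frames become T/4 latent steps, and the decoder's `Upsample` blocks invert that. A quick shape check (263 is the HumanML3D pose feature size):

```
import torch
from models.encdec import Encoder, Decoder

enc = Encoder(input_emb_width=263, output_emb_width=512, down_t=2, stride_t=2)
dec = Decoder(input_emb_width=263, output_emb_width=512, down_t=2, stride_t=2)
x = torch.randn(1, 263, 64)     # (batch, pose features, frames)
z = enc(x)                      # 64 -> 32 -> 16 latent steps
x_rec = dec(z)
print(z.shape, x_rec.shape)     # (1, 512, 16) and (1, 263, 64)
```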
================================================
FILE: models/evaluator_wrapper.py
================================================
import torch
from os.path import join as pjoin
import numpy as np
from models.modules import MovementConvEncoder, TextEncoderBiGRUCo, MotionEncoderBiGRUCo
from utils.word_vectorizer import POS_enumerator
def build_models(opt):
movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
pos_size=opt.dim_pos_ohot,
hidden_size=opt.dim_text_hidden,
output_size=opt.dim_coemb_hidden,
device=opt.device)
motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
hidden_size=opt.dim_motion_hidden,
output_size=opt.dim_coemb_hidden,
device=opt.device)
checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
map_location=opt.device)
movement_enc.load_state_dict(checkpoint['movement_encoder'])
text_enc.load_state_dict(checkpoint['text_encoder'])
motion_enc.load_state_dict(checkpoint['motion_encoder'])
print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
return text_enc, motion_enc, movement_enc
class EvaluatorModelWrapper(object):
def __init__(self, opt):
if opt.dataset_name == 't2m':
opt.dim_pose = 263
elif opt.dataset_name == 'kit':
opt.dim_pose = 251
else:
raise KeyError('Dataset not Recognized!!!')
opt.dim_word = 300
opt.max_motion_length = 196
opt.dim_pos_ohot = len(POS_enumerator)
opt.dim_motion_hidden = 1024
opt.max_text_len = 20
opt.dim_text_hidden = 512
opt.dim_coemb_hidden = 512
# print(opt)
self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
self.opt = opt
self.device = opt.device
self.text_encoder.to(opt.device)
self.motion_encoder.to(opt.device)
self.movement_encoder.to(opt.device)
self.text_encoder.eval()
self.motion_encoder.eval()
self.movement_encoder.eval()
    # Please note that the results do not follow the order of inputs
def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
with torch.no_grad():
word_embs = word_embs.detach().to(self.device).float()
pos_ohot = pos_ohot.detach().to(self.device).float()
motions = motions.detach().to(self.device).float()
'''Movement Encoding'''
movements = self.movement_encoder(motions[..., :-4]).detach()
m_lens = m_lens // self.opt.unit_length
motion_embedding = self.motion_encoder(movements, m_lens)
'''Text Encoding'''
text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
return text_embedding, motion_embedding
    # Please note that the results do not follow the order of inputs
def get_motion_embeddings(self, motions, m_lens):
with torch.no_grad():
motions = motions.detach().to(self.device).float()
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
motions = motions[align_idx]
m_lens = m_lens[align_idx]
'''Movement Encoding'''
movements = self.movement_encoder(motions[..., :-4]).detach()
m_lens = m_lens // self.opt.unit_length
motion_embedding = self.motion_encoder(movements, m_lens)
return motion_embedding
================================================
FILE: models/modules.py
================================================
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
def init_weight(m):
if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
nn.init.xavier_normal_(m.weight)
# m.bias.data.fill_(0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class MovementConvEncoder(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MovementConvEncoder, self).__init__()
self.main = nn.Sequential(
nn.Conv1d(input_size, hidden_size, 4, 2, 1),
nn.Dropout(0.2, inplace=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(hidden_size, output_size, 4, 2, 1),
nn.Dropout(0.2, inplace=True),
nn.LeakyReLU(0.2, inplace=True),
)
self.out_net = nn.Linear(output_size, output_size)
self.main.apply(init_weight)
self.out_net.apply(init_weight)
def forward(self, inputs):
inputs = inputs.permute(0, 2, 1)
outputs = self.main(inputs).permute(0, 2, 1)
# print(outputs.shape)
return self.out_net(outputs)
class TextEncoderBiGRUCo(nn.Module):
def __init__(self, word_size, pos_size, hidden_size, output_size, device):
super(TextEncoderBiGRUCo, self).__init__()
self.device = device
self.pos_emb = nn.Linear(pos_size, word_size)
self.input_emb = nn.Linear(word_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
self.output_net = nn.Sequential(
nn.Linear(hidden_size * 2, hidden_size),
nn.LayerNorm(hidden_size),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hidden_size, output_size)
)
self.input_emb.apply(init_weight)
self.pos_emb.apply(init_weight)
self.output_net.apply(init_weight)
self.hidden_size = hidden_size
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
# input(batch_size, seq_len, dim)
def forward(self, word_embs, pos_onehot, cap_lens):
num_samples = word_embs.shape[0]
pos_embs = self.pos_emb(pos_onehot)
inputs = word_embs + pos_embs
input_embs = self.input_emb(inputs)
hidden = self.hidden.repeat(1, num_samples, 1)
cap_lens = cap_lens.data.tolist()
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
gru_seq, gru_last = self.gru(emb, hidden)
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
return self.output_net(gru_last)
class MotionEncoderBiGRUCo(nn.Module):
def __init__(self, input_size, hidden_size, output_size, device):
super(MotionEncoderBiGRUCo, self).__init__()
self.device = device
self.input_emb = nn.Linear(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
self.output_net = nn.Sequential(
nn.Linear(hidden_size*2, hidden_size),
nn.LayerNorm(hidden_size),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hidden_size, output_size)
)
self.input_emb.apply(init_weight)
self.output_net.apply(init_weight)
self.hidden_size = hidden_size
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
# input(batch_size, seq_len, dim)
def forward(self, inputs, m_lens):
num_samples = inputs.shape[0]
input_embs = self.input_emb(inputs)
hidden = self.hidden.repeat(1, num_samples, 1)
cap_lens = m_lens.data.tolist()
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True, enforce_sorted=False)
gru_seq, gru_last = self.gru(emb, hidden)
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
return self.output_net(gru_last)
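Note the asymmetry: `TextEncoderBiGRUCo` packs with the default `enforce_sorted=True`, so captions must arrive sorted by decreasing length, while `MotionEncoderBiGRUCo` passes `enforce_sorted=False` and accepts any order. A small sketch with illustrative sizes:

```
import torch
from models.modules import MotionEncoderBiGRUCo

enc = MotionEncoderBiGRUCo(input_size=512, hidden_size=1024,
                           output_size=512, device="cpu")
movements = torch.randn(4, 49, 512)      # (batch, downsampled length, dim)
m_lens = torch.tensor([49, 40, 32, 12])  # true lengths, in any order
print(enc(movements, m_lens).shape)      # torch.Size([4, 512])
```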
================================================
FILE: models/quantize_cnn.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class QuantizeEMAReset(nn.Module):
def __init__(self, nb_code, code_dim, args):
super().__init__()
self.nb_code = nb_code
self.code_dim = code_dim
self.mu = args.mu
self.reset_codebook()
def reset_codebook(self):
self.init = False
self.code_sum = None
self.code_count = None
self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda())
def _tile(self, x):
nb_code_x, code_dim = x.shape
if nb_code_x < self.nb_code:
n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
std = 0.01 / np.sqrt(code_dim)
out = x.repeat(n_repeats, 1)
out = out + torch.randn_like(out) * std
else :
out = x
return out
def init_codebook(self, x):
out = self._tile(x)
self.codebook = out[:self.nb_code]
self.code_sum = self.codebook.clone()
self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
self.init = True
@torch.no_grad()
def compute_perplexity(self, code_idx) :
# Calculate new centres
code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
code_count = code_onehot.sum(dim=-1) # nb_code
prob = code_count / torch.sum(code_count)
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
return perplexity
@torch.no_grad()
def update_codebook(self, x, code_idx):
code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
code_sum = torch.matmul(code_onehot, x) # nb_code, w
code_count = code_onehot.sum(dim=-1) # nb_code
out = self._tile(x)
code_rand = out[:self.nb_code]
# Update centres
self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
self.codebook = usage * code_update + (1 - usage) * code_rand
prob = code_count / torch.sum(code_count)
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
return perplexity
def preprocess(self, x):
# NCT -> NTC -> [NT, C]
x = x.permute(0, 2, 1).contiguous()
x = x.view(-1, x.shape[-1])
return x
def quantize(self, x):
# Calculate latent code x_l
k_w = self.codebook.t()
distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
keepdim=True) # (N * L, b)
_, code_idx = torch.min(distance, dim=-1)
return code_idx
def dequantize(self, code_idx):
x = F.embedding(code_idx, self.codebook)
return x
def forward(self, x):
N, width, T = x.shape
# Preprocess
x = self.preprocess(x)
# Init codebook if not inited
if self.training and not self.init:
self.init_codebook(x)
# quantize and dequantize through bottleneck
code_idx = self.quantize(x)
x_d = self.dequantize(code_idx)
# Update embeddings
if self.training:
perplexity = self.update_codebook(x, code_idx)
else :
perplexity = self.compute_perplexity(code_idx)
# Loss
commit_loss = F.mse_loss(x, x_d.detach())
# Passthrough
x_d = x + (x_d - x).detach()
# Postprocess
x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
return x_d, commit_loss, perplexity
class Quantizer(nn.Module):
def __init__(self, n_e, e_dim, beta):
super(Quantizer, self).__init__()
self.e_dim = e_dim
self.n_e = n_e
self.beta = beta
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
def forward(self, z):
N, width, T = z.shape
z = self.preprocess(z)
assert z.shape[-1] == self.e_dim
z_flattened = z.contiguous().view(-1, self.e_dim)
# B x V
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.matmul(z_flattened, self.embedding.weight.t())
# B x 1
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q - z.detach())**2) + self.beta * \
torch.mean((z_q.detach() - z)**2)
# preserve gradients
z_q = z + (z_q - z).detach()
z_q = z_q.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean + 1e-10)))
return z_q, loss, perplexity
def quantize(self, z):
assert z.shape[-1] == self.e_dim
# B x V
d = torch.sum(z ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
torch.matmul(z, self.embedding.weight.t())
# B x 1
min_encoding_indices = torch.argmin(d, dim=1)
return min_encoding_indices
def dequantize(self, indices):
index_flattened = indices.view(-1)
z_q = self.embedding(index_flattened)
z_q = z_q.view(indices.shape + (self.e_dim, )).contiguous()
return z_q
def preprocess(self, x):
# NCT -> NTC -> [NT, C]
x = x.permute(0, 2, 1).contiguous()
x = x.view(-1, x.shape[-1])
return x
class QuantizeReset(nn.Module):
def __init__(self, nb_code, code_dim, args):
super().__init__()
self.nb_code = nb_code
self.code_dim = code_dim
self.reset_codebook()
self.codebook = nn.Parameter(torch.randn(nb_code, code_dim))
def reset_codebook(self):
self.init = False
self.code_count = None
def _tile(self, x):
nb_code_x, code_dim = x.shape
if nb_code_x < self.nb_code:
n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
std = 0.01 / np.sqrt(code_dim)
out = x.repeat(n_repeats, 1)
out = out + torch.randn_like(out) * std
else :
out = x
return out
def init_codebook(self, x):
out = self._tile(x)
self.codebook = nn.Parameter(out[:self.nb_code])
self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
self.init = True
@torch.no_grad()
def compute_perplexity(self, code_idx) :
# Calculate new centres
code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
code_count = code_onehot.sum(dim=-1) # nb_code
prob = code_count / torch.sum(code_count)
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
return perplexity
def update_codebook(self, x, code_idx):
code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
code_count = code_onehot.sum(dim=-1) # nb_code
out = self._tile(x)
code_rand = out[:self.nb_code]
# Update centres
self.code_count = code_count # nb_code
usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
self.codebook.data = usage * self.codebook.data + (1 - usage) * code_rand
prob = code_count / torch.sum(code_count)
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
return perplexity
def preprocess(self, x):
# NCT -> NTC -> [NT, C]
x = x.permute(0, 2, 1).contiguous()
x = x.view(-1, x.shape[-1])
return x
def quantize(self, x):
# Calculate latent code x_l
k_w = self.codebook.t()
distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
keepdim=True) # (N * L, b)
_, code_idx = torch.min(distance, dim=-1)
return code_idx
def dequantize(self, code_idx):
x = F.embedding(code_idx, self.codebook)
return x
def forward(self, x):
N, width, T = x.shape
# Preprocess
x = self.preprocess(x)
# Init codebook if not inited
if self.training and not self.init:
self.init_codebook(x)
# quantize and dequantize through bottleneck
code_idx = self.quantize(x)
x_d = self.dequantize(code_idx)
# Update embeddings
if self.training:
perplexity = self.update_codebook(x, code_idx)
else :
perplexity = self.compute_perplexity(code_idx)
# Loss
commit_loss = F.mse_loss(x, x_d.detach())
# Passthrough
x_d = x + (x_d - x).detach()
# Postprocess
x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
return x_d, commit_loss, perplexity
class QuantizeEMA(nn.Module):
def __init__(self, nb_code, code_dim, args):
super().__init__()
self.nb_code = nb_code
self.code_dim = code_dim
self.mu = 0.99
self.reset_codebook()
def reset_codebook(self):
self.init = False
self.code_sum = None
self.code_count = None
self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda())
def _tile(self, x):
nb_code_x, code_dim = x.shape
if nb_code_x < self.nb_code:
n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
std = 0.01 / np.sqrt(code_dim)
out = x.repeat(n_repeats, 1)
out = out + torch.randn_like(out) * std
else :
out = x
return out
def init_codebook(self, x):
out = self._tile(x)
self.codebook = out[:self.nb_code]
self.code_sum = self.codebook.clone()
self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
self.init = True
@torch.no_grad()
def compute_perplexity(self, code_idx) :
# Calculate new centres
code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
code_count = code_onehot.sum(dim=-1) # nb_code
prob = code_count / torch.sum(code_count)
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
return perplexity
@torch.no_grad()
def update_codebook(self, x, code_idx):
code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
code_sum = torch.matmul(code_onehot, x) # nb_code, w
code_count = code_onehot.sum(dim=-1) # nb_code
# Update centres
self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
self.codebook = code_update
prob = code_count / torch.sum(code_count)
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
return perplexity
def preprocess(self, x):
# NCT -> NTC -> [NT, C]
x = x.permute(0, 2, 1).contiguous()
x = x.view(-1, x.shape[-1])
return x
def quantize(self, x):
# Calculate latent code x_l
k_w = self.codebook.t()
distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
keepdim=True) # (N * L, b)
_, code_idx = torch.min(distance, dim=-1)
return code_idx
def dequantize(self, code_idx):
x = F.embedding(code_idx, self.codebook)
return x
def forward(self, x):
N, width, T = x.shape
# Preprocess
x = self.preprocess(x)
# Init codebook if not inited
if self.training and not self.init:
self.init_codebook(x)
# quantize and dequantize through bottleneck
code_idx = self.quantize(x)
x_d = self.dequantize(code_idx)
# Update embeddings
if self.training:
perplexity = self.update_codebook(x, code_idx)
else :
perplexity = self.compute_perplexity(code_idx)
# Loss
commit_loss = F.mse_loss(x, x_d.detach())
# Passthrough
x_d = x + (x_d - x).detach()
# Postprocess
x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() #(N, DIM, T)
return x_d, commit_loss, perplexity
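All four quantizers share the same nearest-code search: rather than materializing every difference vector, `quantize` expands the squared distance as ||x - c||^2 = ||x||^2 - 2 x·c + ||c||^2, which costs a single matmul against the codebook. A quick equivalence check with illustrative shapes:

```
import torch

x = torch.randn(7, 512)              # flattened (N*T, code_dim) features
codebook = torch.randn(1024, 512)    # (nb_code, code_dim)
k_w = codebook.t()
dist = x.pow(2).sum(-1, keepdim=True) - 2 * x @ k_w + k_w.pow(2).sum(0, keepdim=True)
ref = torch.cdist(x, codebook).pow(2)
print(torch.allclose(dist, ref, atol=1e-3))  # True, up to float error
```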
================================================
FILE: models/resnet.py
================================================
import torch.nn as nn
import torch
class nonlinearity(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
# swish
return x * torch.sigmoid(x)
class ResConv1DBlock(nn.Module):
def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=None):
super().__init__()
padding = dilation
self.norm = norm
if norm == "LN":
self.norm1 = nn.LayerNorm(n_in)
self.norm2 = nn.LayerNorm(n_in)
elif norm == "GN":
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
elif norm == "BN":
self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
else:
self.norm1 = nn.Identity()
self.norm2 = nn.Identity()
if activation == "relu":
self.activation1 = nn.ReLU()
self.activation2 = nn.ReLU()
elif activation == "silu":
self.activation1 = nonlinearity()
self.activation2 = nonlinearity()
elif activation == "gelu":
self.activation1 = nn.GELU()
self.activation2 = nn.GELU()
self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0,)
def forward(self, x):
x_orig = x
if self.norm == "LN":
x = self.norm1(x.transpose(-2, -1))
x = self.activation1(x.transpose(-2, -1))
else:
x = self.norm1(x)
x = self.activation1(x)
x = self.conv1(x)
if self.norm == "LN":
x = self.norm2(x.transpose(-2, -1))
x = self.activation2(x.transpose(-2, -1))
else:
x = self.norm2(x)
x = self.activation2(x)
x = self.conv2(x)
x = x + x_orig
return x
class Resnet1D(nn.Module):
def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
super().__init__()
blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm) for depth in range(n_depth)]
if reverse_dilation:
blocks = blocks[::-1]
self.model = nn.Sequential(*blocks)
def forward(self, x):
return self.model(x)
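Because each `ResConv1DBlock` pads by exactly its dilation with a kernel of 3, sequence length is preserved while the receptive field grows geometrically: `dilation_growth_rate=3` with `n_depth=3` yields dilations 1, 3, 9 (applied in reverse order when `reverse_dilation=True`, as the Decoder does). A shape check:

```
import torch
from models.resnet import Resnet1D

net = Resnet1D(n_in=512, n_depth=3, dilation_growth_rate=3)
x = torch.randn(2, 512, 16)
print(net(x).shape)  # torch.Size([2, 512, 16]); padding == dilation keeps T fixed
```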
================================================
FILE: models/rotation2xyz.py
================================================
# This code is based on https://github.com/Mathux/ACTOR.git
import torch
import utils.rotation_conversions as geometry
from models.smpl import SMPL, JOINTSTYPE_ROOT
# from .get_model import JOINTSTYPES
JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"]
class Rotation2xyz:
def __init__(self, device, dataset='amass'):
self.device = device
self.dataset = dataset
self.smpl_model = SMPL().eval().to(device)
def __call__(self, x, mask, pose_rep, translation, glob,
jointstype, vertstrans, betas=None, beta=0,
glob_rot=None, get_rotations_back=False, **kwargs):
if pose_rep == "xyz":
return x
if mask is None:
mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device)
if not glob and glob_rot is None:
raise TypeError("You must specify global rotation if glob is False")
if jointstype not in JOINTSTYPES:
raise NotImplementedError("This jointstype is not implemented.")
if translation:
x_translations = x[:, -1, :3]
x_rotations = x[:, :-1]
else:
x_rotations = x
x_rotations = x_rotations.permute(0, 3, 1, 2)
nsamples, time, njoints, feats = x_rotations.shape
# Compute rotations (convert only masked sequences output)
if pose_rep == "rotvec":
rotations = geometry.axis_angle_to_matrix(x_rotations[mask])
elif pose_rep == "rotmat":
rotations = x_rotations[mask].view(-1, njoints, 3, 3)
elif pose_rep == "rotquat":
rotations = geometry.quaternion_to_matrix(x_rotations[mask])
elif pose_rep == "rot6d":
rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])
else:
raise NotImplementedError("No geometry for this one.")
if not glob:
global_orient = torch.tensor(glob_rot, device=x.device)
global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3)
global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
else:
global_orient = rotations[:, 0]
rotations = rotations[:, 1:]
if betas is None:
betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas],
dtype=rotations.dtype, device=rotations.device)
betas[:, 1] = beta
# import ipdb; ipdb.set_trace()
out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas)
# get the desirable joints
joints = out[jointstype]
x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype)
x_xyz[~mask] = 0
x_xyz[mask] = joints
x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous()
# the first translation root at the origin on the prediction
if jointstype != "vertices":
rootindex = JOINTSTYPE_ROOT[jointstype]
x_xyz = x_xyz - x_xyz[:, [rootindex], :, :]
if translation and vertstrans:
# the first translation root at the origin
x_translations = x_translations - x_translations[:, :, [0]]
# add the translation to all the joints
x_xyz = x_xyz + x_translations[:, None, :, :]
if get_rotations_back:
return x_xyz, rotations, global_orient
else:
return x_xyz
================================================
FILE: models/smpl.py
================================================
# This code is based on https://github.com/Mathux/ACTOR.git
import numpy as np
import torch
import contextlib
from smplx import SMPLLayer as _SMPLLayer
from smplx.lbs import vertices2joints
# action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38]
# change 0 and 8
action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38]
from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA
JOINTSTYPE_ROOT = {"a2m": 0, # action2motion
"smpl": 0,
"a2mpl": 0, # set(smpl, a2m)
"vibe": 8} # 0 is the 8 position: OP MidHip below
JOINT_MAP = {
'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
'Spine (H36M)': 51, 'Jaw (H36M)': 52,
'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
}
JOINT_NAMES = [
'OP Nose', 'OP Neck', 'OP RShoulder',
'OP RElbow', 'OP RWrist', 'OP LShoulder',
'OP LElbow', 'OP LWrist', 'OP MidHip',
'OP RHip', 'OP RKnee', 'OP RAnkle',
'OP LHip', 'OP LKnee', 'OP LAnkle',
'OP REye', 'OP LEye', 'OP REar',
'OP LEar', 'OP LBigToe', 'OP LSmallToe',
'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel',
'Right Ankle', 'Right Knee', 'Right Hip',
'Left Hip', 'Left Knee', 'Left Ankle',
'Right Wrist', 'Right Elbow', 'Right Shoulder',
'Left Shoulder', 'Left Elbow', 'Left Wrist',
'Neck (LSP)', 'Top of Head (LSP)',
'Pelvis (MPII)', 'Thorax (MPII)',
'Spine (H36M)', 'Jaw (H36M)',
'Head (H36M)', 'Nose', 'Left Eye',
'Right Eye', 'Left Ear', 'Right Ear'
]
# adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints
class SMPL(_SMPLLayer):
""" Extension of the official SMPL implementation to support more joints """
def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):
kwargs["model_path"] = model_path
# remove the verbosity for the 10-shapes beta parameters
with contextlib.redirect_stdout(None):
super(SMPL, self).__init__(**kwargs)
J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)
self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES])
a2m_indexes = vibe_indexes[action2motion_joints]
smpl_indexes = np.arange(24)
a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes])
self.maps = {"vibe": vibe_indexes,
"a2m": a2m_indexes,
"smpl": smpl_indexes,
"a2mpl": a2mpl_indexes}
def forward(self, *args, **kwargs):
smpl_output = super(SMPL, self).forward(*args, **kwargs)
extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
output = {"vertices": smpl_output.vertices}
for joinstype, indexes in self.maps.items():
output[joinstype] = all_joints[:, indexes]
return output
================================================
FILE: models/vqvae.py
================================================
# This code is based on https://github.com/Mael-zys/T2M-GPT.git
import torch.nn as nn
from models.encdec import Encoder, Decoder
from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
class VQVAE_251(nn.Module):
def __init__(self,
args,
nb_code=1024,
code_dim=512,
output_emb_width=512,
down_t=3,
stride_t=2,
width=512,
depth=3,
dilation_growth_rate=3,
activation='relu',
norm=None):
super().__init__()
self.code_dim = code_dim
self.num_code = nb_code
self.quant = args.quantizer
self.encoder = Encoder(251 if args.dataname == 'kit' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
self.decoder = Decoder(251 if args.dataname == 'kit' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
if args.quantizer == "ema_reset":
self.quantizer = QuantizeEMAReset(nb_code, code_dim, args)
elif args.quantizer == "orig":
self.quantizer = Quantizer(nb_code, code_dim, 1.0)
elif args.quantizer == "ema":
self.quantizer = QuantizeEMA(nb_code, code_dim, args)
elif args.quantizer == "reset":
self.quantizer = QuantizeReset(nb_code, code_dim, args)
def preprocess(self, x):
# (bs, T, Jx3) -> (bs, Jx3, T)
x = x.permute(0,2,1).float()
return x
def postprocess(self, x):
# (bs, Jx3, T) -> (bs, T, Jx3)
x = x.permute(0,2,1)
return x
def encode(self, x):
N, T, _ = x.shape
x_in = self.preprocess(x)
x_encoder = self.encoder(x_in)
x_encoder = self.postprocess(x_encoder)
x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1]) # (NT, C)
code_idx = self.quantizer.quantize(x_encoder)
code_idx = code_idx.view(N, -1)
return code_idx
def forward(self, x):
x_in = self.preprocess(x)
# Encode
x_encoder = self.encoder(x_in)
## quantization
x_quantized, loss, perplexity = self.quantizer(x_encoder)
## decoder
x_decoder = self.decoder(x_quantized)
x_out = self.postprocess(x_decoder)
return x_out, loss, perplexity
def forward_decoder(self, x):
x_d = self.quantizer.dequantize(x)
x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()
# decoder
x_decoder = self.decoder(x_d)
x_out = self.postprocess(x_decoder)
return x_out
class HumanVQVAE(nn.Module):
def __init__(self,
args,
nb_code=512,
code_dim=512,
output_emb_width=512,
down_t=3,
stride_t=2,
width=512,
depth=3,
dilation_growth_rate=3,
activation='relu',
norm=None):
super().__init__()
self.nb_joints = 21 if args.dataname == 'kit' else 22
self.vqvae = VQVAE_251(args, nb_code, code_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
def encode(self, x):
b, t, c = x.size()
quants = self.vqvae.encode(x) # (N, T)
return quants
def forward(self, x):
x_out, loss, perplexity = self.vqvae(x)
return x_out, loss, perplexity
def forward_decoder(self, x):
x_out = self.vqvae.forward_decoder(x)
return x_out
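A minimal CPU-runnable sketch of the token round trip (motion -> code indices -> motion). The `'orig'` quantizer is chosen here only because the EMA variants above allocate their codebook with `.cuda()` and therefore need a GPU; the weights are untrained, so this checks shapes, not quality:

```
import torch
from argparse import Namespace
from models.vqvae import HumanVQVAE

args = Namespace(dataname='t2m', quantizer='orig')
net = HumanVQVAE(args, nb_code=512, code_dim=512, down_t=2, stride_t=2).eval()
motion = torch.randn(1, 64, 263)               # (batch, frames, HumanML3D features)
with torch.no_grad():
    idx = net.encode(motion)                   # (1, 16): one token per 4 frames
    recon = net.forward_decoder(idx.view(-1))  # back to (1, 64, 263)
print(idx.shape, recon.shape)
```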
================================================
FILE: options/get_eval_option.py
================================================
from argparse import Namespace
import re
from os.path import join as pjoin
def is_float(numStr):
flag = False
numStr = str(numStr).strip().lstrip('-').lstrip('+')
try:
reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$')
res = reg.match(str(numStr))
if res:
flag = True
except Exception as ex:
print("is_float() - error: " + str(ex))
return flag
def is_number(numStr):
flag = False
numStr = str(numStr).strip().lstrip('-').lstrip('+')
if str(numStr).isdigit():
flag = True
return flag
def get_opt(opt_path, device):
opt = Namespace()
opt_dict = vars(opt)
skip = ('-------------- End ----------------',
'------------ Options -------------',
'\n')
print('Reading', opt_path)
with open(opt_path) as f:
for line in f:
if line.strip() not in skip:
# print(line.strip())
key, value = line.strip().split(': ')
if value in ('True', 'False'):
opt_dict[key] = (value == 'True')
# print(key, value)
elif is_float(value):
opt_dict[key] = float(value)
elif is_number(value):
opt_dict[key] = int(value)
else:
opt_dict[key] = str(value)
# print(opt)
opt_dict['which_epoch'] = 'finest'
opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
opt.model_dir = pjoin(opt.save_root, 'model')
opt.meta_dir = pjoin(opt.save_root, 'meta')
if opt.dataset_name == 't2m':
opt.data_root = './dataset/HumanML3D/'
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
opt.text_dir = pjoin(opt.data_root, 'texts')
opt.joints_num = 22
opt.dim_pose = 263
opt.max_motion_length = 196
opt.max_motion_frame = 196
opt.max_motion_token = 55
elif opt.dataset_name == 'kit':
opt.data_root = './dataset/KIT-ML/'
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
opt.text_dir = pjoin(opt.data_root, 'texts')
opt.joints_num = 21
opt.dim_pose = 251
opt.max_motion_length = 196
opt.max_motion_frame = 196
opt.max_motion_token = 55
else:
raise KeyError('Dataset not recognized')
opt.dim_word = 300
opt.num_classes = 200 // opt.unit_length
opt.is_train = False
opt.is_continue = False
opt.device = device
return opt
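`get_opt` parses the plain-text options file saved next to each evaluator checkpoint: one `key: value` pair per line, with the two banner lines skipped and booleans, floats, and ints detected by the helpers above. An illustrative fragment (values are hypothetical):

```
------------ Options -------------
dataset_name: t2m
checkpoints_dir: ./checkpoints
name: text_mot_match
unit_length: 4
-------------- End ----------------
```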
================================================
FILE: options/option.py
================================================
import argparse
def get_args_parser():
parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for Amass',
add_help=True,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
## dataloader
parser.add_argument('--prompt', type=str, default="Generate a sequence of motion tokens matching the following human motion description.", help='task description')
    parser.add_argument('--input', type=str, help='generation conditions')
    parser.add_argument('--dataname', type=str, default='t2m', help='dataset name')
parser.add_argument('--pretrained_llama', type=str, default="13B")
parser.add_argument('--out_dir', type=str, default='./out/', help='output directory')
parser.add_argument('--vqvae_pth', type=str, default='./checkpoints/pretrained_vqvae/t2m.pth', help='path to the pretrained vqvae pth')
parser.add_argument('--resume_pth', type=str, help='path to saved finetuned model')
    parser.add_argument('--lora_path', type=str, help='path to finetuned model for evaluation')
parser.add_argument('--data_dir', type=str, default='./data/', help='dataset directory')
## lora
parser.add_argument('--lora_r', type=int, default=64)
parser.add_argument('--lora_alpha', type=int, default=16)
parser.add_argument('--lora_dropout', type=float, default=0.05)
## llama
parser.add_argument('--block_size', type=int, default=512)
## train
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--micro_batch_size', type=int, default=4, help='micro batch size')
parser.add_argument('--learning_rate', type=float, default=3e-3, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.01, help='weight decay')
parser.add_argument('--warmup_steps', type=int, default=100, help='warmup steps')
parser.add_argument('--eval_interval', type=int, default=100, help='evaluation frequency')
parser.add_argument('--save_interval', type=int, default=100, help='model save frequency')
    parser.add_argument('--eval_iters', type=int, default=100, help='number of evaluation iterations')
parser.add_argument('--log_interval', type=int, default=1, help='log frequency')
## vqvae
parser.add_argument("--code_dim", type=int, default=512, help="embedding dimension")
parser.add_argument("--nb_code", type=int, default=512, help="nb of embedding")
parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
parser.add_argument("--down_t", type=int, default=2, help="downsampling rate")
parser.add_argument("--stride_t", type=int, default=2, help="stride size")
parser.add_argument("--width", type=int, default=512, help="width of the network")
parser.add_argument("--depth", type=int, default=3, help="depth of the network")
parser.add_argument("--dilation_growth_rate", type=int, default=3, help="dilation growth rate")
parser.add_argument("--output_emb_width", type=int, default=512, help="output embedding width")
    parser.add_argument('--vq_act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='activation function')
parser.add_argument('--seed', default=123, type=int, help='seed for initializing vqvae training.')
parser.add_argument('--window_size', type=int, default=64, help='training motion length')
## quantizer
parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="eps for optimal transport")
parser.add_argument('--quantbeta', type=float, default=1.0, help='dataset directory')
## visualization
parser.add_argument("--render", action='store_true', help='render smpl')
return parser.parse_args()
================================================
FILE: options/option_vqvae.py
================================================
import argparse
def get_args_parser():
parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for Amass',
add_help=True,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
## dataloader
    parser.add_argument('--dataname', type=str, default='t2m', help='dataset name')
parser.add_argument('--out_dir', type=str, default='./out/', help='output directory')
parser.add_argument('--resume_pth', type=str, help='path to saved vqvae model')
parser.add_argument('--window_size', type=int, default=64, help='training motion length')
## train
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--learning_rate', type=float, default=2e-4, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')
parser.add_argument('--warmup_steps', type=int, default=1000, help='number of total iterations for warmup')
parser.add_argument('--total_iter', default=300000, type=int, help='number of total iterations to run')
parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate')
parser.add_argument('--lr_scheduler', default=[200000], nargs="+", type=int, help="learning rate schedule (iterations)")
parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay")
parser.add_argument("--commit", type=float, default=0.02, help="hyper-parameter for the commitment loss")
parser.add_argument('--loss_vel', type=float, default=0.5, help='hyper-parameter for the velocity loss')
parser.add_argument('--recons_loss', type=str, default='l1_smooth', help='reconstruction loss')
parser.add_argument('--print_iter', default=200, type=int, help='print frequency')
parser.add_argument('--eval_iter', default=1000, type=int, help='evaluation frequency')
parser.add_argument('--seed', default=123, type=int, help='seed for initializing training.')
## model
parser.add_argument("--code_dim", type=int, default=512, help="embedding dimension")
parser.add_argument("--nb_code", type=int, default=512, help="nb of embedding")
parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
parser.add_argument("--down_t", type=int, default=2, help="downsampling rate")
parser.add_argument("--stride_t", type=int, default=2, help="stride size")
parser.add_argument("--width", type=int, default=512, help="width of the network")
parser.add_argument("--depth", type=int, default=3, help="depth of the network")
parser.add_argument("--dilation_growth_rate", type=int, default=3, help="dilation growth rate")
parser.add_argument("--output_emb_width", type=int, default=512, help="output embedding width")
    parser.add_argument('--vq_act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='activation function')
    parser.add_argument('--vq_norm', type=str, default=None, help='normalization type')
## quantizer
parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="eps for optimal transport")
parser.add_argument('--beta', type=float, default=1.0, help='commitment loss in standard VQ')
return parser.parse_args()
================================================
FILE: prepare/download_evaluators.sh
================================================
mkdir -p checkpoints
cd checkpoints
echo "The evaluators will be stored in the './checkpoints' folder"
echo "Downloading"
gdown "https://drive.google.com/uc?id=1jD08gNAU2zVKDAMVyRbxzzv2uA9ssqMk"
gdown "https://drive.google.com/uc?id=1caLMTO5EMZoaCY2U7yEgZp3dG1seyNKF"
echo "Extracting"
unzip t2m.zip
unzip kit.zip
echo "Cleaning"
rm t2m.zip
rm kit.zip
echo "Downloading done!"
================================================
FILE: prepare/download_glove.sh
================================================
echo "The glove will be stored in the './' folder"
echo "Down
================================================
SYMBOL INDEX (491 symbols across 50 files)
FILE: dataloader/eval_loader.py
function collate_fn (line 13) | def collate_fn(batch):
class Text2MotionDataset (line 19) | class Text2MotionDataset(data.Dataset):
method __init__ (line 20) | def __init__(self, dataset_name, split, w_vectorizer, feat_bias = 5, m...
method reset_max_len (line 127) | def reset_max_len(self, length):
method inv_transform (line 133) | def inv_transform(self, data):
method forward_transform (line 136) | def forward_transform(self, data):
method __len__ (line 139) | def __len__(self):
method __getitem__ (line 142) | def __getitem__(self, item):
function DATALoader (line 195) | def DATALoader(dataset_name, split,
function cycle (line 208) | def cycle(iterable):
FILE: dataloader/tokenizer_loader.py
class VQMotionDataset (line 10) | class VQMotionDataset(data.Dataset):
method __init__ (line 11) | def __init__(self, dataset_name, feat_bias = 5, window_size = 64, unit...
method inv_transform (line 77) | def inv_transform(self, data):
method __len__ (line 80) | def __len__(self):
method __getitem__ (line 83) | def __getitem__(self, item):
function DATALoader (line 98) | def DATALoader(dataset_name,
function cycle (line 110) | def cycle(iterable):
FILE: dataloader/vqvae_loader.py
class VQMotionDataset (line 10) | class VQMotionDataset(data.Dataset):
method __init__ (line 11) | def __init__(self, dataset_name, window_size = 64, unit_length = 4):
method inv_transform (line 62) | def inv_transform(self, data):
method compute_sampling_prob (line 65) | def compute_sampling_prob(self) :
method __len__ (line 71) | def __len__(self):
method __getitem__ (line 74) | def __getitem__(self, item):
function DATALoader (line 85) | def DATALoader(dataset_name,
function cycle (line 98) | def cycle(iterable):
FILE: eval.py
function main (line 32) | def main(
FILE: eval_vqvae.py
function main (line 24) | def main():
FILE: finetune_motion.py
function main (line 27) | def main():
function train (line 104) | def train(
function validate (line 160) | def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndar...
function loss_fn (line 174) | def loss_fn(logits, targets):
function get_batch (line 182) | def get_batch(fabric: L.Fabric, data: list):
function load_datasets (line 206) | def load_datasets():
FILE: generate.py
function generate (line 15) | def generate(
function main (line 73) | def main(
FILE: generate_batch.py
function generate (line 15) | def generate(
function main (line 73) | def main(
FILE: generate_motion.py
function main (line 27) | def main(
FILE: lit_llama/adapter.py
class LLaMAConfig (line 18) | class LLaMAConfig(llama.LLaMAConfig):
class CausalSelfAttention (line 23) | class CausalSelfAttention(nn.Module):
method __init__ (line 27) | def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
method forward (line 50) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class Block (line 102) | class Block(nn.Module):
method __init__ (line 106) | def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
method forward (line 113) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class LLaMA (line 119) | class LLaMA(llama.LLaMA):
method __init__ (line 123) | def __init__(self, config: LLaMAConfig) -> None:
method from_name (line 139) | def from_name(cls, name: str):
function mark_only_adapter_as_trainable (line 143) | def mark_only_adapter_as_trainable(model: LLaMA) -> None:
function adapter_state_from_state_dict (line 149) | def adapter_state_from_state_dict(state_dict: dict) -> dict:
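The two adapter helpers implement the usual parameter-efficient recipe: freeze everything except the adapter tensors, and checkpoint only those tensors so saved weights stay small. A sketch, assuming adapter parameters are identifiable by a substring in their names:
```
# Sketch of the adapter freezing/checkpointing helpers; matching parameters
# by the substring "adapter" is an assumption about this repo's naming.
import torch.nn as nn

def mark_only_adapter_as_trainable(model: nn.Module) -> None:
    for name, param in model.named_parameters():
        param.requires_grad = "adapter" in name

def adapter_state_from_state_dict(state_dict: dict) -> dict:
    # keep only the small adapter tensors when saving a checkpoint
    return {k: v for k, v in state_dict.items() if "adapter" in k}
```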
FILE: lit_llama/indexed_dataset.py
function __best_fitting_dtype (line 24) | def __best_fitting_dtype(vocab_size=None):
function get_available_dataset_impl (line 31) | def get_available_dataset_impl():
function infer_dataset_impl (line 35) | def infer_dataset_impl(path):
function make_builder (line 51) | def make_builder(out_file, impl, vocab_size=None):
function make_dataset (line 58) | def make_dataset(path, impl, skip_warmup=False):
function dataset_exists (line 75) | def dataset_exists(path, impl):
function read_longs (line 82) | def read_longs(f, n):
function write_longs (line 88) | def write_longs(f, a):
function code (line 104) | def code(dtype):
function index_file_path (line 111) | def index_file_path(prefix_path):
function data_file_path (line 115) | def data_file_path(prefix_path):
function create_doc_idx (line 119) | def create_doc_idx(sizes):
class IndexedDataset (line 127) | class IndexedDataset(torch.utils.data.Dataset):
method __init__ (line 131) | def __init__(self, path):
method read_index (line 137) | def read_index(self, path):
method read_data (line 155) | def read_data(self, path):
method check_index (line 158) | def check_index(self, i):
method __del__ (line 162) | def __del__(self):
method __getitem__ (line 167) | def __getitem__(self, idx):
method __len__ (line 191) | def __len__(self):
method num_tokens (line 194) | def num_tokens(self, index):
method size (line 197) | def size(self, index):
method exists (line 201) | def exists(path):
method supports_prefetch (line 207) | def supports_prefetch(self):
class IndexedCachedDataset (line 211) | class IndexedCachedDataset(IndexedDataset):
method __init__ (line 213) | def __init__(self, path):
method supports_prefetch (line 219) | def supports_prefetch(self):
method prefetch (line 222) | def prefetch(self, indices):
method __getitem__ (line 247) | def __getitem__(self, idx):
class IndexedDatasetBuilder (line 264) | class IndexedDatasetBuilder(object):
method __init__ (line 275) | def __init__(self, out_file, dtype=np.int32):
method add_item (line 284) | def add_item(self, tensor):
method end_document (line 291) | def end_document(self):
method merge_file_ (line 294) | def merge_file_(self, another_file):
method finalize (line 319) | def finalize(self, index_file):
function _warmup_mmap_file (line 334) | def _warmup_mmap_file(path):
class MMapIndexedDataset (line 340) | class MMapIndexedDataset(torch.utils.data.Dataset):
class Index (line 341) | class Index(object):
method writer (line 345) | def writer(cls, path, dtype):
method __init__ (line 390) | def __init__(self, path, skip_warmup=False):
method __del__ (line 427) | def __del__(self):
method dtype (line 432) | def dtype(self):
method sizes (line 436) | def sizes(self):
method doc_idx (line 440) | def doc_idx(self):
method __getitem__ (line 444) | def __getitem__(self, i):
method __len__ (line 447) | def __len__(self):
method __init__ (line 450) | def __init__(self, path, skip_warmup=False):
method __getstate__ (line 459) | def __getstate__(self):
method __setstate__ (line 462) | def __setstate__(self, state):
method _do_init (line 465) | def _do_init(self, path, skip_warmup):
method __del__ (line 477) | def __del__(self):
method __len__ (line 482) | def __len__(self):
method __getitem__ (line 486) | def __getitem__(self, idx):
method get (line 507) | def get(self, idx, offset=0, length=None):
method sizes (line 522) | def sizes(self):
method doc_idx (line 526) | def doc_idx(self):
method get_doc_idx (line 529) | def get_doc_idx(self):
method set_doc_idx (line 532) | def set_doc_idx(self, doc_idx_):
method supports_prefetch (line 536) | def supports_prefetch(self):
method exists (line 540) | def exists(path):
class MMapIndexedDatasetBuilder (line 546) | class MMapIndexedDatasetBuilder(object):
method __init__ (line 547) | def __init__(self, out_file, dtype=np.int64):
method dtype (line 554) | def dtype(self):
method add_item (line 557) | def add_item(self, np_array):
method add_doc (line 562) | def add_doc(self, np_array, sizes):
method end_document (line 568) | def end_document(self):
method merge_file_ (line 571) | def merge_file_(self, another_file):
method finalize (line 584) | def finalize(self, index_file):
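This Megatron-style indexed dataset stores a token stream as a `.bin`/`.idx` pair and serves documents by index through a memory map. A sketch of the write/read round trip with the factory functions above; the exact argument conventions are assumptions based on Megatron-LM usage:
```
# Sketch of writing and reading a binary token dataset via the factories
# above; argument conventions are assumptions based on Megatron-LM usage.
import numpy as np
from lit_llama import indexed_dataset

builder = indexed_dataset.make_builder(
    indexed_dataset.data_file_path("train"),   # "train.bin"
    impl="mmap", vocab_size=32000)
for tokens in ([1, 5, 9, 2], [1, 7, 2]):
    builder.add_item(np.array(tokens))         # one sequence
    builder.end_document()                     # mark a document boundary
builder.finalize(indexed_dataset.index_file_path("train"))  # "train.idx"

ds = indexed_dataset.make_dataset("train", impl="mmap")
print(len(ds), ds[0])                          # memory-mapped random access
```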
FILE: lit_llama/lora.py
class LoRALayer (line 18) | class LoRALayer():
method __init__ (line 19) | def __init__(
class MergedLinear (line 38) | class MergedLinear(nn.Linear, LoRALayer):
method __init__ (line 40) | def __init__(
method reset_parameters (line 79) | def reset_parameters(self):
method zero_pad (line 86) | def zero_pad(self, x):
method train (line 94) | def train(self, mode: bool = True):
method eval (line 109) | def eval(self):
method forward (line 124) | def forward(self, x: torch.Tensor):
function mark_only_lora_as_trainable (line 142) | def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') ->...
function lora_state_dict (line 162) | def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, t...
class LoRAConfig (line 182) | class LoRAConfig:
class CausalSelfAttention (line 188) | class CausalSelfAttention(llama.CausalSelfAttention):
method __init__ (line 191) | def __init__(self, config: llama.LLaMAConfig) -> None:
function lora (line 218) | def lora(r, alpha, dropout, enabled: bool = True):
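A typical LoRA setup with this module, following the lit-llama convention it derives from (the hyperparameter values here are examples, not the repo's defaults):
```
# Sketch of LoRA finetuning setup; r/alpha/dropout values are examples.
from lit_llama.lora import lora, mark_only_lora_as_trainable, lora_state_dict
from lit_llama.model import LLaMA

with lora(r=8, alpha=16, dropout=0.05, enabled=True):
    # inside the context, attention projections are built as MergedLinear
    # layers that carry low-rank A/B factors next to the frozen base weight
    model = LLaMA.from_name("7B")

mark_only_lora_as_trainable(model)   # freeze all but the LoRA factors

# ... training ...

checkpoint = lora_state_dict(model)  # save only the small LoRA tensors
```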
FILE: lit_llama/model.py
class LLaMAConfig (line 16) | class LLaMAConfig:
method from_name (line 24) | def from_name(cls, name: str) -> Self:
class LLaMA (line 36) | class LLaMA(nn.Module):
method __init__ (line 37) | def __init__(self, config: LLaMAConfig) -> None:
method _init_weights (line 64) | def _init_weights(self, module: nn.Module) -> None:
method forward (line 70) | def forward(self, idx: torch.Tensor) -> torch.Tensor:
method from_name (line 88) | def from_name(cls, name: str) -> Self:
class Block (line 92) | class Block(nn.Module):
method __init__ (line 93) | def __init__(self, config: LLaMAConfig) -> None:
method forward (line 100) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class CausalSelfAttention (line 106) | class CausalSelfAttention(nn.Module):
method __init__ (line 107) | def __init__(self, config: LLaMAConfig) -> None:
method forward (line 121) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class MLP (line 161) | class MLP(nn.Module):
method __init__ (line 162) | def __init__(self, config: LLaMAConfig) -> None:
method forward (line 174) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class RMSNorm (line 180) | class RMSNorm(nn.Module):
method __init__ (line 187) | def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
method forward (line 193) | def forward(self, x: torch.Tensor) -> torch.Tensor:
function build_rope_cache (line 203) | def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, devi...
function apply_rope (line 234) | def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
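`RMSNorm` normalizes by the root mean square of the activations, skipping LayerNorm's mean-centering and bias; a sketch consistent with the signature above:
```
# Sketch of the RMSNorm used throughout the LLaMA blocks: scale by the
# root-mean-square of the activations, with no mean subtraction.
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        return self.scale * (x * torch.rsqrt(norm_x + self.eps))
```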
FILE: lit_llama/quantization.py
class Linear8bitLt (line 29) | class Linear8bitLt(bnb.nn.Linear8bitLt):
method __init__ (line 36) | def __init__(self, *args, **kwargs):
method _load_from_state_dict (line 42) | def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
method _quantize_weight (line 56) | def _quantize_weight(self, weight: torch.Tensor) -> None:
class ColBlockQuantizedLinear (line 68) | class ColBlockQuantizedLinear(torch.nn.Module):
method __init__ (line 69) | def __init__(self, in_features, out_features, bias: bool, *, bits, til...
method pack_weight (line 87) | def pack_weight(self, weight):
method get_weight (line 97) | def get_weight(self, dtype=torch.float):
method forward (line 108) | def forward(self, inp):
class GPTQQuantizer (line 115) | class GPTQQuantizer:
method __init__ (line 121) | def __init__(self, linear_module, *, bits, perchannel=True, sym=False,...
method quantize_weight (line 144) | def quantize_weight(x, scale, zero, maxq):
method find_params_weight (line 149) | def find_params_weight(self, x):
method collect_input_stats (line 187) | def collect_input_stats(self, _1, inp, _2):
method quantize (line 203) | def quantize(self):
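At the core of `GPTQQuantizer` is an affine round-to-grid step: weights are snapped to one of `maxq + 1` integer levels, then mapped back to floats. A sketch of `quantize_weight` as commonly implemented in GPTQ code, with an illustrative 4-bit example; the scale/zero derivation shown is not the repo's calibration procedure:
```
# Sketch of GPTQ's affine weight quantizer; the scale/zero derivation in
# the example is illustrative, not the repo's calibration procedure.
import torch

def quantize_weight(x: torch.Tensor, scale: torch.Tensor,
                    zero: torch.Tensor, maxq: int) -> torch.Tensor:
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)   # dequantized weights on the integer grid

# example: 4-bit per-tensor quantization of a random weight matrix
w = torch.randn(8, 8)
maxq = 2 ** 4 - 1
scale = (w.max() - w.min()) / maxq
zero = torch.round(-w.min() / scale)
w_q = quantize_weight(w, scale, zero, maxq)
```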
FILE: lit_llama/tokenizer.py
class Tokenizer (line 9) | class Tokenizer:
method __init__ (line 12) | def __init__(self, model_path: Path) -> None:
method vocab_size (line 19) | def vocab_size(self) -> int:
method encode (line 22) | def encode(
method decode (line 43) | def decode(self, tokens: torch.Tensor) -> str:
method train (line 47) | def train(input: str, destination: str, vocab_size=32000) -> None:
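Round-tripping text through the SentencePiece wrapper looks like the sketch below; the `bos`/`eos` keyword names follow lit-llama and are assumptions here:
```
# Sketch of tokenizer usage; keyword names are assumptions from lit-llama.
from pathlib import Path
from lit_llama.tokenizer import Tokenizer

tokenizer = Tokenizer(Path("checkpoints/lit-llama/tokenizer.model"))
ids = tokenizer.encode("a person walks forward", bos=True, eos=False)
print(ids.shape, tokenizer.decode(ids))
```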
FILE: lit_llama/utils.py
function save_model_checkpoint (line 17) | def save_model_checkpoint(fabric, model, file_path):
class EmptyInitOnDevice (line 46) | class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
method __init__ (line 47) | def __init__(self, device=None, dtype=None, quantization_mode=None):
method __enter__ (line 81) | def __enter__(self):
method __exit__ (line 87) | def __exit__(self, exc_type, exc_val, exc_tb):
method __torch_function__ (line 92) | def __torch_function__(self, func, types, args=(), kwargs=None):
class NotYetLoadedTensor (line 117) | class NotYetLoadedTensor:
method __init__ (line 118) | def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
method rebuild (line 125) | def rebuild(
method _load_tensor (line 156) | def _load_tensor(self):
method __torch_function__ (line 178) | def __torch_function__(cls, func, types, args=(), kwargs=None):
method __getattr__ (line 188) | def __getattr__(self, name):
method __repr__ (line 215) | def __repr__(self):
class LazyLoadingUnpickler (line 219) | class LazyLoadingUnpickler(pickle.Unpickler):
method __init__ (line 220) | def __init__(self, file, zipfile):
method find_class (line 224) | def find_class(self, module, name):
method persistent_load (line 230) | def persistent_load(self, pid):
function lazy_load (line 239) | def lazy_load(fn):
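`lazy_load` pairs with `EmptyInitOnDevice` so a multi-gigabyte checkpoint never needs two full copies in memory: the model is constructed without running weight initialization, and each `NotYetLoadedTensor` reads its storage from the checkpoint archive only when `load_state_dict` touches it. A usage sketch (path and device are examples):
```
# Sketch of lazy checkpoint loading; path and device are examples.
import torch
from lit_llama.model import LLaMA
from lit_llama.utils import EmptyInitOnDevice, lazy_load

with EmptyInitOnDevice(device="cuda", dtype=torch.bfloat16):
    model = LLaMA.from_name("7B")   # skip (soon overwritten) weight init

checkpoint = lazy_load("checkpoints/lit-llama/7B/lit-llama.pth")
model.load_state_dict(checkpoint)   # tensors materialize on first access
```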
FILE: models/encdec.py
class Encoder (line 4) | class Encoder(nn.Module):
method __init__ (line 5) | def __init__(self,
method forward (line 32) | def forward(self, x):
class Decoder (line 35) | class Decoder(nn.Module):
method __init__ (line 36) | def __init__(self,
method forward (line 65) | def forward(self, x):
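`Encoder` and `Decoder` are mirror-image temporal convolution stacks: each stride-2 stage halves (or, on the decoder side, doubles) the frame rate, which is why one motion token covers several frames. A simplified encoder sketch; the widths, depths, and the plain-conv stand-in for `Resnet1D` are assumptions:
```
# Simplified sketch of the temporal conv encoder; widths/depths and the
# plain-conv stand-in for Resnet1D are assumptions.
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_dim=263, width=512, down_t=2, depth=3):
        super().__init__()
        layers = [nn.Conv1d(input_dim, width, 3, 1, 1), nn.ReLU()]
        for _ in range(down_t):
            layers.append(nn.Conv1d(width, width, 4, 2, 1))  # halve frame rate
            for _ in range(depth):  # stand-in for a Resnet1D stack
                layers += [nn.Conv1d(width, width, 3, 1, 1), nn.ReLU()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):           # x: (batch, input_dim, frames)
        return self.model(x)        # (batch, width, frames // 2**down_t)
```
With `down_t=2` the temporal rate drops by a factor of 4, matching the `unit_length = 4` default visible in the dataloader signatures.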
FILE: models/evaluator_wrapper.py
function build_models (line 8) | def build_models(opt):
class EvaluatorModelWrapper (line 30) | class EvaluatorModelWrapper(object):
method __init__ (line 32) | def __init__(self, opt):
method get_co_embeddings (line 64) | def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_...
method get_motion_embeddings (line 80) | def get_motion_embeddings(self, motions, m_lens):
FILE: models/modules.py
function init_weight (line 5) | def init_weight(m):
class MovementConvEncoder (line 13) | class MovementConvEncoder(nn.Module):
method __init__ (line 14) | def __init__(self, input_size, hidden_size, output_size):
method forward (line 28) | def forward(self, inputs):
class TextEncoderBiGRUCo (line 36) | class TextEncoderBiGRUCo(nn.Module):
method __init__ (line 37) | def __init__(self, word_size, pos_size, hidden_size, output_size, devi...
method forward (line 58) | def forward(self, word_embs, pos_onehot, cap_lens):
class MotionEncoderBiGRUCo (line 76) | class MotionEncoderBiGRUCo(nn.Module):
method __init__ (line 77) | def __init__(self, input_size, hidden_size, output_size, device):
method forward (line 96) | def forward(self, inputs, m_lens):
FILE: models/quantize_cnn.py
class QuantizeEMAReset (line 6) | class QuantizeEMAReset(nn.Module):
method __init__ (line 7) | def __init__(self, nb_code, code_dim, args):
method reset_codebook (line 14) | def reset_codebook(self):
method _tile (line 20) | def _tile(self, x):
method init_codebook (line 31) | def init_codebook(self, x):
method compute_perplexity (line 39) | def compute_perplexity(self, code_idx) :
method update_codebook (line 50) | def update_codebook(self, x, code_idx):
method preprocess (line 75) | def preprocess(self, x):
method quantize (line 81) | def quantize(self, x):
method dequantize (line 89) | def dequantize(self, code_idx):
method forward (line 94) | def forward(self, x):
class Quantizer (line 127) | class Quantizer(nn.Module):
method __init__ (line 128) | def __init__(self, n_e, e_dim, beta):
method forward (line 138) | def forward(self, z):
method quantize (line 166) | def quantize(self, z):
method dequantize (line 178) | def dequantize(self, indices):
method preprocess (line 185) | def preprocess(self, x):
class QuantizeReset (line 193) | class QuantizeReset(nn.Module):
method __init__ (line 194) | def __init__(self, nb_code, code_dim, args):
method reset_codebook (line 201) | def reset_codebook(self):
method _tile (line 205) | def _tile(self, x):
method init_codebook (line 216) | def init_codebook(self, x):
method compute_perplexity (line 223) | def compute_perplexity(self, code_idx) :
method update_codebook (line 233) | def update_codebook(self, x, code_idx):
method preprocess (line 254) | def preprocess(self, x):
method quantize (line 260) | def quantize(self, x):
method dequantize (line 268) | def dequantize(self, code_idx):
method forward (line 273) | def forward(self, x):
class QuantizeEMA (line 301) | class QuantizeEMA(nn.Module):
method __init__ (line 302) | def __init__(self, nb_code, code_dim, args):
method reset_codebook (line 309) | def reset_codebook(self):
method _tile (line 315) | def _tile(self, x):
method init_codebook (line 326) | def init_codebook(self, x):
method compute_perplexity (line 334) | def compute_perplexity(self, code_idx) :
method update_codebook (line 345) | def update_codebook(self, x, code_idx):
method preprocess (line 365) | def preprocess(self, x):
method quantize (line 371) | def quantize(self, x):
method dequantize (line 379) | def dequantize(self, code_idx):
method forward (line 384) | def forward(self, x):
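All four quantizer classes share one core: nearest-codebook lookup, a commitment loss, and a straight-through estimator so gradients flow through the non-differentiable `argmin`; they differ only in how the codebook is updated (EMA, dead-code reset, or plain gradients). A sketch of that shared core:
```
# Sketch of the shared vector-quantization core; the per-class codebook
# update strategies (EMA, reset) are deliberately omitted.
import torch
import torch.nn.functional as F

def vq_forward(x: torch.Tensor, codebook: torch.Tensor, beta: float = 0.25):
    # x: (N, code_dim) flattened features; codebook: (nb_code, code_dim)
    code_idx = torch.cdist(x, codebook).argmin(dim=-1)  # nearest code
    x_q = codebook[code_idx]                            # dequantize

    # commitment: pull encoder outputs and codes toward each other
    loss = beta * F.mse_loss(x_q.detach(), x) + F.mse_loss(x_q, x.detach())

    # straight-through: forward uses x_q, backward passes gradients to x
    x_q = x + (x_q - x).detach()
    return x_q, code_idx, loss
```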
FILE: models/resnet.py
class nonlinearity (line 4) | class nonlinearity(nn.Module):
method __init__ (line 5) | def __init__(self):
method forward (line 8) | def forward(self, x):
class ResConv1DBlock (line 12) | class ResConv1DBlock(nn.Module):
method __init__ (line 13) | def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=...
method forward (line 49) | def forward(self, x):
class Resnet1D (line 71) | class Resnet1D(nn.Module):
method __init__ (line 72) | def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dila...
method forward (line 81) | def forward(self, x):
FILE: models/rotation2xyz.py
class Rotation2xyz (line 11) | class Rotation2xyz:
method __init__ (line 12) | def __init__(self, device, dataset='amass'):
method __call__ (line 17) | def __call__(self, x, mask, pose_rep, translation, glob,
FILE: models/smpl.py
class SMPL (line 64) | class SMPL(_SMPLLayer):
method __init__ (line 67) | def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):
method forward (line 86) | def forward(self, *args, **kwargs):
FILE: models/vqvae.py
class VQVAE_251 (line 7) | class VQVAE_251(nn.Module):
method __init__ (line 8) | def __init__(self,
method preprocess (line 37) | def preprocess(self, x):
method postprocess (line 43) | def postprocess(self, x):
method encode (line 49) | def encode(self, x):
method forward (line 60) | def forward(self, x):
method forward_decoder (line 75) | def forward_decoder(self, x):
class HumanVQVAE (line 86) | class HumanVQVAE(nn.Module):
method __init__ (line 87) | def __init__(self,
method encode (line 105) | def encode(self, x):
method forward (line 110) | def forward(self, x):
method forward_decoder (line 116) | def forward_decoder(self, x):
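`HumanVQVAE` is the bridge between continuous motion features and the LLM's discrete vocabulary: `encode` maps a motion sequence to codebook indices (the "motion tokens" MotionGPT reads and writes), and `forward_decoder` maps generated indices back to motion features. A round-trip sketch; the `(batch, frames, 263)` HumanML3D feature shape is an assumption:
```
# Sketch of the motion <-> token round trip; shapes are assumptions.
import torch

def motion_roundtrip(net, motion: torch.Tensor) -> torch.Tensor:
    """net: a loaded HumanVQVAE; motion: (batch, frames, feat) features."""
    tokens = net.encode(motion)          # discrete codebook indices
    return net.forward_decoder(tokens)   # back to continuous motion features

# hypothetical usage, with a net built from options/option_vqvae.py and
# weights fetched by prepare/download_vqvae.sh:
# recon = motion_roundtrip(net, torch.randn(1, 64, 263))
```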
FILE: options/get_eval_option.py
function is_float (line 6) | def is_float(numStr):
function is_number (line 19) | def is_number(numStr):
function get_opt (line 27) | def get_opt(opt_path, device):
FILE: options/option.py
function get_args_parser (line 3) | def get_args_parser():
FILE: options/option_vqvae.py
function get_args_parser (line 3) | def get_args_parser():
FILE: scripts/convert_checkpoint.py
function convert_state_dict (line 20) | def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch...
function meta_weights_for_nano_model (line 66) | def meta_weights_for_nano_model(
FILE: scripts/convert_hf_checkpoint.py
function convert_hf_checkpoint (line 18) | def convert_hf_checkpoint(
FILE: scripts/download.py
function download_original (line 11) | def download_original(wd: str) -> None:
function download_from_hub (line 22) | def download_from_hub(repo_id: Optional[str] = None, local_dir: str = "c...
FILE: scripts/generate_dataset.py
function prepare (line 15) | def prepare(split):
function main (line 47) | def main():
FILE: scripts/prepare_motion.py
function prepare (line 20) | def prepare(
function prepare_sample (line 50) | def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int,...
function tokenize (line 79) | def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=Tru...
function generate_prompt (line 83) | def generate_prompt(example):
function main (line 100) | def main():
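`prepare_motion.py` notes it derives from alpaca-lora, so `generate_prompt` presumably emits the Alpaca instruction template. A sketch of that template in the shape of the function above; the exact wording used for motion tasks is an assumption:
```
# Sketch of an Alpaca-style generate_prompt; the wording is the alpaca-lora
# template and may differ from the prompts this repo actually uses.
def generate_prompt(example: dict) -> str:
    if example.get("input"):
        return ("Below is an instruction that describes a task, paired with "
                "an input that provides further context. Write a response "
                "that appropriately completes the request.\n\n"
                f"### Instruction:\n{example['instruction']}\n\n"
                f"### Input:\n{example['input']}\n\n### Response:")
    return ("Below is an instruction that describes a task. Write a response "
            "that appropriately completes the request.\n\n"
            f"### Instruction:\n{example['instruction']}\n\n### Response:")
```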
FILE: static/js/bulma-carousel.js
function __webpack_require__ (line 16) | function __webpack_require__(moduleId) {
function detectSupportsPassive (line 189) | function detectSupportsPassive() {
function defineProperties (line 273) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _toConsumableArray (line 275) | function _toConsumableArray(arr) { if (Array.isArray(arr)) { for (var i ...
function _classCallCheck (line 277) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function EventEmitter (line 280) | function EventEmitter() {
function defineProperties (line 324) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 326) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Coordinate (line 329) | function Coordinate() {
function defineProperties (line 473) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _defineProperty (line 475) | function _defineProperty(obj, key, value) { if (key in obj) { Object.def...
function _classCallCheck (line 477) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function _possibleConstructorReturn (line 479) | function _possibleConstructorReturn(self, call) { if (!self) { throw new...
function _inherits (line 481) | function _inherits(subClass, superClass) { if (typeof superClass !== "fu...
function bulmaCarousel (line 504) | function bulmaCarousel(selector) {
function _toConsumableArray (line 983) | function _toConsumableArray(arr) { if (Array.isArray(arr)) { for (var i ...
function defineProperties (line 1019) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 1021) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function _possibleConstructorReturn (line 1023) | function _possibleConstructorReturn(self, call) { if (!self) { throw new...
function _inherits (line 1025) | function _inherits(subClass, superClass) { if (typeof superClass !== "fu...
function Autoplay (line 1042) | function Autoplay(slider) {
function defineProperties (line 1209) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 1211) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Breakpoints (line 1216) | function Breakpoints(slider) {
function defineProperties (line 1347) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _toConsumableArray (line 1349) | function _toConsumableArray(arr) { if (Array.isArray(arr)) { for (var i ...
function _classCallCheck (line 1351) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Infinite (line 1354) | function Infinite(slider) {
function defineProperties (line 1427) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 1429) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Loop (line 1434) | function Loop(slider) {
function defineProperties (line 1494) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 1496) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Navigation (line 1502) | function Navigation(slider) {
function defineProperties (line 1644) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 1646) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Pagination (line 1653) | function Pagination(slider) {
function defineProperties (line 1804) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 1806) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Swipe (line 1812) | function Swipe(slider) {
function defineProperties (line 1946) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 1948) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Transitioner (line 1954) | function Transitioner(slider) {
function defineProperties (line 2047) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 2049) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Fade (line 2054) | function Fade(transitioner, slider) {
function defineProperties (line 2182) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 2184) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function Translate (line 2190) | function Translate(transitioner, slider) {
FILE: static/js/bulma-slider.js
function __webpack_require__ (line 16) | function __webpack_require__(moduleId) {
function defineProperties (line 86) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 90) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function _possibleConstructorReturn (line 92) | function _possibleConstructorReturn(self, call) { if (!self) { throw new...
function _inherits (line 94) | function _inherits(subClass, superClass) { if (typeof superClass !== "fu...
function bulmaSlider (line 105) | function bulmaSlider(selector) {
function defineProperties (line 281) | function defineProperties(target, props) { for (var i = 0; i < props.len...
function _classCallCheck (line 283) | function _classCallCheck(instance, Constructor) { if (!(instance instanc...
function EventEmitter (line 286) | function EventEmitter() {
FILE: train_vqvae.py
function update_lr_warm_up (line 21) | def update_lr_warm_up(optimizer, nb_iter, warmup_step, lr):
function main (line 33) | def main():
FILE: utils/evaluate.py
function truncate_output_to_eos (line 15) | def truncate_output_to_eos(output, eos_id):
function pad_left (line 24) | def pad_left(x, max_len, pad_id):
function plot (line 31) | def plot(tokens, net, dataname):
function vqvae_evaluation (line 44) | def vqvae_evaluation(out_dir, val_loader, net, logger, writer, eval_wrap...
function evaluation (line 159) | def evaluation(val_loader, net, model, logger, tokenizer, eval_wrapper, ...
function euclidean_distance_matrix (line 261) | def euclidean_distance_matrix(matrix1, matrix2):
function calculate_top_k (line 278) | def calculate_top_k(mat, top_k):
function calculate_R_precision (line 293) | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
function calculate_diversity (line 304) | def calculate_diversity(activation, diversity_times):
function calculate_frechet_distance (line 315) | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
function calculate_activation_statistics (line 352) | def calculate_activation_statistics(activations):
function calculate_frechet_feature_distance (line 358) | def calculate_frechet_feature_distance(feature_list1, feature_list2):
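`calculate_activation_statistics` reduces a feature set to its mean and covariance, and `calculate_frechet_distance` then evaluates the Frechet (FID-style) distance `d^2 = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2(S1 S2)^{1/2})`. A sketch of the standard implementation this family of evaluation codebases inherits, including the eps jitter used when the matrix square root is numerically singular:
```
# Sketch of the standard Frechet distance computation (FID-style).
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # jitter the diagonals when the product is near-singular
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset),
                                  disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real   # discard tiny imaginary parts
    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2)
            - 2 * np.trace(covmean))
```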
FILE: utils/losses.py
class ReConsLoss (line 4) | class ReConsLoss(nn.Module):
method __init__ (line 5) | def __init__(self, recons_loss, nb_joints):
method forward (line 22) | def forward(self, motion_pred, motion_gt) :
method forward_vel (line 26) | def forward_vel(self, motion_pred, motion_gt) :
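`ReConsLoss` selects a base criterion by name and applies it to the motion features, with a separate velocity term. A sketch following T2M-GPT, which this repo builds on; the channel range supervised by `forward_vel` is an assumption about the HumanML3D feature layout:
```
# Sketch of ReConsLoss following T2M-GPT; the velocity channel range is
# an assumption about the HumanML3D feature layout.
import torch.nn as nn

class ReConsLoss(nn.Module):
    def __init__(self, recons_loss: str, nb_joints: int):
        super().__init__()
        self.loss = {"l1": nn.L1Loss(),
                     "l2": nn.MSELoss(),
                     "l1_smooth": nn.SmoothL1Loss()}[recons_loss]
        self.nb_joints = nb_joints

    def forward(self, motion_pred, motion_gt):
        return self.loss(motion_pred, motion_gt)

    def forward_vel(self, motion_pred, motion_gt):
        # supervise only the per-joint velocity-like channels
        start, end = 4, 4 + (self.nb_joints - 1) * 3
        return self.loss(motion_pred[..., start:end],
                         motion_gt[..., start:end])
```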
FILE: utils/motion_process.py
function recover_root_rot_pos (line 4) | def recover_root_rot_pos(data):
function recover_from_rot (line 26) | def recover_from_rot(data, joints_num, skeleton):
function recover_from_ric (line 43) | def recover_from_ric(data, joints_num):
FILE: utils/quaternion.py
function qinv (line 16) | def qinv(q):
function qinv_np (line 23) | def qinv_np(q):
function qnormalize (line 28) | def qnormalize(q):
function qmul (line 33) | def qmul(q, r):
function qrot (line 54) | def qrot(q, v):
function qeuler (line 76) | def qeuler(q, order, epsilon=0, deg=True):
function qmul_np (line 128) | def qmul_np(q, r):
function qrot_np (line 134) | def qrot_np(q, v):
function qeuler_np (line 140) | def qeuler_np(q, order, epsilon=0, use_gpu=False):
function qfix (line 149) | def qfix(q):
function euler2quat (line 169) | def euler2quat(e, order, deg=True):
function expmap_to_quaternion (line 214) | def expmap_to_quaternion(e):
function euler_to_quaternion (line 233) | def euler_to_quaternion(e, order):
function quaternion_to_matrix (line 274) | def quaternion_to_matrix(quaternions):
function quaternion_to_matrix_np (line 303) | def quaternion_to_matrix_np(quaternions):
function quaternion_to_cont6d_np (line 308) | def quaternion_to_cont6d_np(quaternions):
function quaternion_to_cont6d (line 314) | def quaternion_to_cont6d(quaternions):
function cont6d_to_matrix (line 320) | def cont6d_to_matrix(cont6d):
function cont6d_to_matrix_np (line 339) | def cont6d_to_matrix_np(cont6d):
function qpow (line 344) | def qpow(q0, t, dtype=torch.float):
function qslerp (line 369) | def qslerp(q0, q1, t):
function qbetween (line 387) | def qbetween(v0, v1):
function qbetween_np (line 400) | def qbetween_np(v0, v1):
function lerp (line 412) | def lerp(p0, p1, t):
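`qrot` rotates vectors by unit quaternions without building rotation matrices, via the identity `v' = v + 2w(u x v) + 2u x (u x v)` for `q = (w, u)`. A sketch of that computation; the `(w, x, y, z)` component order and matching batch shapes are assumptions consistent with the rest of this module:
```
# Sketch of rotating vectors by unit quaternions, the identity behind qrot.
import torch

def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: (..., 4) unit quaternions as (w, x, y, z); v: (..., 3) vectors
    # with the same leading dimensions as q
    qvec = q[..., 1:]
    uv = torch.cross(qvec, v, dim=-1)
    uuv = torch.cross(qvec, uv, dim=-1)
    return v + 2 * (q[..., :1] * uv + uuv)
```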
FILE: utils/rotation_conversions.py
function quaternion_to_matrix (line 32) | def quaternion_to_matrix(quaternions):
function _copysign (line 61) | def _copysign(a, b):
function _sqrt_positive_part (line 77) | def _sqrt_positive_part(x):
function matrix_to_quaternion (line 88) | def matrix_to_quaternion(matrix):
function _axis_angle_rotation (line 111) | def _axis_angle_rotation(axis: str, angle):
function euler_angles_to_matrix (line 137) | def euler_angles_to_matrix(euler_angles, convention: str):
function _angle_from_tan (line 160) | def _angle_from_tan(
function _index_from_letter (line 191) | def _index_from_letter(letter: str):
function matrix_to_euler_angles (line 200) | def matrix_to_euler_angles(matrix, convention: str):
function random_quaternions (line 240) | def random_quaternions(
function random_rotations (line 262) | def random_rotations(
function random_rotation (line 283) | def random_rotation(
function standardize_quaternion (line 300) | def standardize_quaternion(quaternions):
function quaternion_raw_multiply (line 313) | def quaternion_raw_multiply(a, b):
function quaternion_multiply (line 332) | def quaternion_multiply(a, b):
function quaternion_invert (line 347) | def quaternion_invert(quaternion):
function quaternion_apply (line 361) | def quaternion_apply(quaternion, point):
function axis_angle_to_matrix (line 382) | def axis_angle_to_matrix(axis_angle):
function matrix_to_axis_angle (line 396) | def matrix_to_axis_angle(matrix):
function axis_angle_to_quaternion (line 410) | def axis_angle_to_quaternion(axis_angle):
function quaternion_to_axis_angle (line 440) | def quaternion_to_axis_angle(quaternions):
function rotation_6d_to_matrix (line 469) | def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
function matrix_to_rotation_6d (line 491) | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
function canonicalize_smplh (line 506) | def canonicalize_smplh(poses, trans = None):
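The cont6d / rotation-6D representation (Zhou et al., "On the Continuity of Rotation Representations in Neural Networks") stores the first two rows of a rotation matrix; decoding Gram-Schmidts them and takes a cross product for the third. A sketch of `rotation_6d_to_matrix` as implemented in PyTorch3D, from which this file derives:
```
# Sketch of 6D-rotation decoding: Gram-Schmidt the two stored vectors,
# cross product for the third row.
import torch
import torch.nn.functional as F

def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)
```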
FILE: utils/skeleton.py
class Skeleton (line 4) | class Skeleton(object):
method __init__ (line 5) | def __init__(self, offset, kinematic_tree, device):
method njoints (line 17) | def njoints(self):
method offset (line 20) | def offset(self):
method set_offset (line 23) | def set_offset(self, offsets):
method kinematic_tree (line 26) | def kinematic_tree(self):
method parents (line 29) | def parents(self):
method get_offsets_joints_batch (line 33) | def get_offsets_joints_batch(self, joints):
method get_offsets_joints (line 43) | def get_offsets_joints(self, joints):
method inverse_kinematics_np (line 55) | def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward...
method forward_kinematics (line 104) | def forward_kinematics(self, quat_params, root_pos, skel_joints=None, ...
method forward_kinematics_np (line 126) | def forward_kinematics_np(self, quat_params, root_pos, skel_joints=Non...
method forward_kinematics_cont6d_np (line 149) | def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_j...
method forward_kinematics_cont6d (line 173) | def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_join...
FILE: utils/utils_model.py
function getCi (line 8) | def getCi(accLog):
function get_logger (line 16) | def get_logger(out_dir):
function initial_optim (line 33) | def initial_optim(decay_option, lr, weight_decay, net, optimizer) :
function get_motion_with_trans (line 55) | def get_motion_with_trans(motion, velocity) :
FILE: utils/word_vectorizer.py
class WordVectorizer (line 46) | class WordVectorizer(object):
method __init__ (line 47) | def __init__(self, meta_root, prefix):
method _get_pos_ohot (line 53) | def _get_pos_ohot(self, pos):
method __len__ (line 61) | def __len__(self):
method __getitem__ (line 64) | def __getitem__(self, item):
class WordVectorizerV2 (line 83) | class WordVectorizerV2(WordVectorizer):
method __init__ (line 84) | def __init__(self, meta_root, prefix):
method __getitem__ (line 88) | def __getitem__(self, item):
method itos (line 96) | def itos(self, idx):
FILE: visualization/plot_3d_global.py
function plot_3d_motion (line 12) | def plot_3d_motion(args, figsize=(10, 10), fps=120, radius=4):
function draw_to_batch (line 114) | def draw_to_batch(smpl_joints_batch, title_batch=None, outname=None) :
FILE: visualization/render.py
class WeakPerspectiveCamera (line 20) | class WeakPerspectiveCamera(pyrender.Camera):
method __init__ (line 21) | def __init__(self,
method get_projection_matrix (line 35) | def get_projection_matrix(self, width=None, height=None):
function render (line 44) | def render(motions, name, outdir='render_vis', device_id=0):
FILE: visualize/joints2smpl/src/customloss.py
function gmof (line 6) | def gmof(x, sigma):
function angle_prior (line 15) | def angle_prior(pose):
function perspective_projection (line 24) | def perspective_projection(points, rotation, translation,
function body_fitting_loss (line 55) | def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_c...
function camera_fitting_loss (line 91) | def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_cen...
function body_fitting_loss_3d (line 128) | def body_fitting_loss_3d(body_pose, preserve_pose,
function camera_fitting_loss_3d (line 192) | def camera_fitting_loss_3d(model_joints, camera_t, camera_t_est,
FILE: visualize/joints2smpl/src/prior.py
function create_prior (line 35) | def create_prior(prior_type, **kwargs):
class SMPLifyAnglePrior (line 52) | class SMPLifyAnglePrior(nn.Module):
method __init__ (line 53) | def __init__(self, dtype=torch.float32, **kwargs):
method forward (line 72) | def forward(self, pose, with_global_pose=False):
class L2Prior (line 91) | class L2Prior(nn.Module):
method __init__ (line 92) | def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs):
method forward (line 95) | def forward(self, module_input, *args):
class MaxMixturePrior (line 99) | class MaxMixturePrior(nn.Module):
method __init__ (line 101) | def __init__(self, prior_folder='prior',
method get_mean (line 175) | def get_mean(self):
method merged_log_likelihood (line 180) | def merged_log_likelihood(self, pose, betas):
method log_likelihood (line 197) | def log_likelihood(self, pose, betas, *args, **kwargs):
method forward (line 226) | def forward(self, pose, betas):
FILE: visualize/joints2smpl/src/smplify.py
function guess_init_3d (line 19) | def guess_init_3d(model_joints,
class SMPLify3D (line 44) | class SMPLify3D():
method __init__ (line 47) | def __init__(self,
method __call__ (line 95) | def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0...
FILE: visualize/simplify_loc2rot.py
class joints2smpl (line 13) | class joints2smpl:
method __init__ (line 15) | def __init__(self, num_frames, device_id, cuda=True):
method npy2smpl (line 45) | def npy2smpl(self, npy_path):
method joint2smpl (line 63) | def joint2smpl(self, input_joints, init_params=None):
FILE: visualize/vis_utils.py
class npy2obj (line 8) | class npy2obj:
method __init__ (line 9) | def __init__(self, npy_path, sample_idx, rep_idx, device=0, cuda=True):
method get_vertices (line 43) | def get_vertices(self, sample_i, frame_i):
method get_trimesh (line 46) | def get_trimesh(self, sample_i, frame_i):
method save_obj (line 50) | def save_obj(self, save_path, frame_i):
method save_npy (line 56) | def save_npy(self, save_path):
Condensed preview — 72 files, each showing path, character count, and a content snippet (full structured content: 3,239K chars).
[
{
"path": "README.md",
"chars": 6894,
"preview": "# MotionGPT: Finetuned LLMs are General-Purpose Motion Generators\n\n[ Facebook, Inc. and i"
},
{
"path": "lit_llama/lora.py",
"chars": 8596,
"preview": "# Derived from https://github.com/microsoft/LoRA\n# --------------------------------------------------------------------"
},
{
"path": "lit_llama/model.py",
"chars": 9017,
"preview": "\"\"\"Full definition of a LLaMA Language Model, all of it in this single file.\n\nBased on the nanoGPT implementation: https"
},
{
"path": "lit_llama/quantization.py",
"chars": 11122,
"preview": "import os\nfrom contextlib import contextmanager\nimport warnings\nimport math\n\nimport torch\n\n# configuration for bitsandby"
},
{
"path": "lit_llama/tokenizer.py",
"chars": 1555,
"preview": "import os\nfrom pathlib import Path\nfrom typing import Optional\n\nimport torch\nfrom sentencepiece import SentencePieceProc"
},
{
"path": "lit_llama/utils.py",
"chars": 8645,
"preview": "\"\"\"Utility functions for training and inference.\"\"\"\n\nimport functools\nfrom pathlib import Path\nimport pickle\nimport warn"
},
{
"path": "models/encdec.py",
"chars": 2308,
"preview": "import torch.nn as nn\nfrom models.resnet import Resnet1D\n\nclass Encoder(nn.Module):\n def __init__(self,\n "
},
{
"path": "models/evaluator_wrapper.py",
"chars": 3756,
"preview": "\nimport torch\nfrom os.path import join as pjoin\nimport numpy as np\nfrom models.modules import MovementConvEncoder, TextE"
},
{
"path": "models/modules.py",
"chars": 3988,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.nn.utils.rnn import pack_padded_sequence\n\ndef init_weight(m):\n if isins"
},
{
"path": "models/quantize_cnn.py",
"chars": 13993,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass QuantizeEMAReset(nn.Module)"
},
{
"path": "models/resnet.py",
"chars": 2653,
"preview": "import torch.nn as nn\nimport torch\n\nclass nonlinearity(nn.Module):\n def __init__(self):\n super().__init__()\n\n "
},
{
"path": "models/rotation2xyz.py",
"chars": 3476,
"preview": "# This code is based on https://github.com/Mathux/ACTOR.git\nimport torch\nimport utils.rotation_conversions as geometry\n\n"
},
{
"path": "models/smpl.py",
"chars": 3858,
"preview": "# This code is based on https://github.com/Mathux/ACTOR.git\nimport numpy as np\nimport torch\n\nimport contextlib\n\nfrom smp"
},
{
"path": "models/vqvae.py",
"chars": 3825,
"preview": "# This code is based on https://github.com/Mael-zys/T2M-GPT.git\nimport torch.nn as nn\nfrom models.encdec import Encoder,"
},
{
"path": "options/get_eval_option.py",
"chars": 2533,
"preview": "from argparse import Namespace\nimport re\nfrom os.path import join as pjoin\n\n\ndef is_float(numStr):\n flag = False\n "
},
{
"path": "options/option.py",
"chars": 3874,
"preview": "import argparse\n\ndef get_args_parser():\n parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder "
},
{
"path": "options/option_vqvae.py",
"chars": 3399,
"preview": "import argparse\n\ndef get_args_parser():\n parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder "
},
{
"path": "prepare/download_evaluators.sh",
"chars": 381,
"preview": "mkdir -p checkpoints\ncd checkpoints\n\necho \"The evaluators will be stored in the './checkpoints' folder\"\necho \"Downloadin"
},
{
"path": "prepare/download_glove.sh",
"chars": 234,
"preview": "echo \"The glove will be stored in the './' folder\"\necho \"Downloading\"\ngdown \"https://drive.google.com/uc?id=1QMoWoaYIRA-"
},
{
"path": "prepare/download_lora.sh",
"chars": 275,
"preview": "mkdir -p checkpoints/pretrained_lora\ncd checkpoints/pretrained_lora\n\necho \"The pretrained model will be stored in the '."
},
{
"path": "prepare/download_smpl.sh",
"chars": 252,
"preview": "echo \"The body_models will be stored in the './' folder\"\necho \"Downloading\"\ngdown \"https://drive.google.com/uc?id=1uyGhO"
},
{
"path": "prepare/download_vqvae.sh",
"chars": 315,
"preview": "mkdir -p checkpoints\ncd checkpoints\n\necho \"The pretrained_vqvae will be stored in the './checkpoints' folder\"\necho \"Down"
},
{
"path": "scripts/convert_checkpoint.py",
"chars": 4999,
"preview": "import gc\nimport shutil\nfrom pathlib import Path\nfrom typing import Dict\n\nimport torch\nfrom tqdm import tqdm\n\n\"\"\"\nSample"
},
{
"path": "scripts/convert_hf_checkpoint.py",
"chars": 4652,
"preview": "import gc\nimport json\nimport shutil\nimport sys\nfrom pathlib import Path\n\nimport torch\n\n# support running without install"
},
{
"path": "scripts/download.py",
"chars": 1262,
"preview": "import os\nfrom typing import Optional\nfrom urllib.request import urlretrieve\n\nfiles = {\n \"original_model.py\": \"https:"
},
{
"path": "scripts/generate_dataset.py",
"chars": 2447,
"preview": "import numpy as np\nimport json, random\nfrom random import sample\nfrom tqdm import tqdm\nimport os\nimport sys\nfrom pathlib"
},
{
"path": "scripts/prepare_data.py",
"chars": 1469,
"preview": "import os\nimport sys\nfrom pathlib import Path\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nimpo"
},
{
"path": "scripts/prepare_motion.py",
"chars": 4056,
"preview": "\"\"\"Implementation derived from https://github.com/tloen/alpaca-lora\"\"\"\nimport os\nimport sys\nfrom pathlib import Path\n\n# "
},
{
"path": "sitemap.xml",
"chars": 715,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\r\n<urlset\r\n xmlns=\"http://www.sitemaps.org/schemas/sitemap/0.9\"\r\n xmlns:"
},
{
"path": "static/css/bulma.css.map.txt",
"chars": 96528,
"preview": "{\"version\":3,\"sources\":[\"../bulma.sass\",\"../sass/utilities/_all.sass\",\"../sass/utilities/animations.sass\",\"bulma.css\",\"."
},
{
"path": "static/css/index.css",
"chars": 2118,
"preview": "body {\n font-family: 'Noto Sans', sans-serif;\n}\n\n\n.footer .icon-link {\n font-size: 25px;\n color: #000;\n}\n\n.link-b"
},
{
"path": "static/js/bulma-carousel.js",
"chars": 82823,
"preview": "(function webpackUniversalModuleDefinition(root, factory) {\n\tif(typeof exports === 'object' && typeof module === 'object"
},
{
"path": "static/js/bulma-slider.js",
"chars": 16404,
"preview": "(function webpackUniversalModuleDefinition(root, factory) {\n\tif(typeof exports === 'object' && typeof module === 'object"
},
{
"path": "static/js/index.js",
"chars": 1485,
"preview": "$(document).ready(function() {\n // Check for click events on the navbar burger icon\n $(\".navbar-burger\").click(fun"
},
{
"path": "train_vqvae.py",
"chars": 6262,
"preview": "# This code is based on https://github.com/Mael-zys/T2M-GPT.git\nimport os\nimport json\n\nimport torch\nimport torch.optim a"
},
{
"path": "utils/config.py",
"chars": 446,
"preview": "import os\n\nSMPL_DATA_PATH = \"./body_models/smpl\"\n\nSMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, \"kintree_table.pkl\")\n"
},
{
"path": "utils/evaluate.py",
"chars": 14648,
"preview": "import os\nimport re\nimport numpy as np\nimport torch\nfrom scipy import linalg\nfrom tqdm import tqdm\n\nfrom generate_batch "
},
{
"path": "utils/losses.py",
"chars": 1045,
"preview": "import torch\nimport torch.nn as nn\n\nclass ReConsLoss(nn.Module):\n def __init__(self, recons_loss, nb_joints):\n "
},
{
"path": "utils/motion_process.py",
"chars": 2007,
"preview": "import torch\nfrom utils.quaternion import quaternion_to_cont6d, qrot, qinv\n\ndef recover_root_rot_pos(data):\n rot_vel "
},
{
"path": "utils/paramUtil.py",
"chars": 1857,
"preview": "import numpy as np\n\n# Define a kinematic tree for the skeletal struture\nkit_kinematic_chain = [[0, 11, 12, 13, 14, 15], "
},
{
"path": "utils/quaternion.py",
"chars": 12861,
"preview": "# Copyright (c) 2018-present, Facebook, Inc.\n# All rights reserved.\n#\n# This source code is licensed under the license f"
},
{
"path": "utils/rotation_conversions.py",
"chars": 18964,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.\n# Check PYTORCH3D_LICENCE before use\n\nimport fun"
},
{
"path": "utils/skeleton.py",
"chars": 8703,
"preview": "from utils.quaternion import *\nimport scipy.ndimage.filters as filters\n\nclass Skeleton(object):\n def __init__(self, o"
},
{
"path": "utils/utils_model.py",
"chars": 2108,
"preview": "import numpy as np \nimport torch\nimport torch.optim as optim\nimport logging\nimport os \nimport sys \n\ndef getCi(accLog):\n\n"
},
{
"path": "utils/word_vectorizer.py",
"chars": 3310,
"preview": "import numpy as np\nimport pickle\nfrom os.path import join as pjoin\n\nPOS_enumerator = {\n 'VERB': 0,\n 'NOUN': 1,\n "
},
{
"path": "visualization/plot_3d_global.py",
"chars": 5398,
"preview": "import torch \nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom glob import glob\nimport io\nimport matplotlib\nfrom "
},
{
"path": "visualization/render.py",
"chars": 5785,
"preview": "from models.rotation2xyz import Rotation2xyz\nimport numpy as np\nfrom trimesh import Trimesh\nimport os\nos.environ['PYOPEN"
},
{
"path": "visualize/joints2smpl/smpl_models/gmm_08.pkl",
"chars": 839127,
"preview": "(dp1\nS'covars'\np2\ncnumpy.core.multiarray\n_reconstruct\np3\n(cnumpy\nndarray\np4\n(I0\ntS'b'\ntRp5\n(I1\n(I8\nI69\nI69\ntcnumpy\ndtype"
},
{
"path": "visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl",
"chars": 1323168,
"preview": "(dp0\nS'segm'\np1\ncnumpy.core.multiarray\n_reconstruct\np2\n(cnumpy\nndarray\np3\n(I0\ntp4\nS'b'\np5\ntp6\nRp7\n(I1\n(I20908\ntp8\ncnumpy"
},
{
"path": "visualize/joints2smpl/src/config.py",
"chars": 1295,
"preview": "import numpy as np\n\n# Map joints Name to SMPL joints idx\nJOINT_MAP = {\n'MidHip': 0,\n'LHip': 1, 'LKnee': 4, 'LAnkle': 7, "
},
{
"path": "visualize/joints2smpl/src/customloss.py",
"chars": 8935,
"preview": "import torch\nimport torch.nn.functional as F\nfrom visualize.joints2smpl.src import config\n\n# Guassian\ndef gmof(x, sigma)"
},
{
"path": "visualize/joints2smpl/src/prior.py",
"chars": 8684,
"preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
},
{
"path": "visualize/joints2smpl/src/smplify.py",
"chars": 12641,
"preview": "import torch\nimport os, sys\nimport pickle\nimport smplx\nimport numpy as np\n\nsys.path.append(os.path.dirname(__file__))\nfr"
},
{
"path": "visualize/render_mesh.py",
"chars": 1487,
"preview": "import argparse\nimport os\nfrom visualize import vis_utils\nimport shutil\nfrom tqdm import tqdm\n\nif __name__ == '__main__'"
},
{
"path": "visualize/simplify_loc2rot.py",
"chars": 5547,
"preview": "import numpy as np\nimport os\nimport torch\nfrom visualize.joints2smpl.src import config\nimport smplx\nimport h5py\nfrom vis"
},
{
"path": "visualize/vis_utils.py",
"chars": 3204,
"preview": "from models.rotation2xyz import Rotation2xyz\nimport numpy as np\nfrom trimesh import Trimesh\nimport os\nimport torch\nfrom "
}
]
// ... and 2 more files