"""Convert a TensorFlow GAU-alpha checkpoint into a PyTorch state dict.

Every variable found in the TensorFlow checkpoint is loaded, renamed with
the string substitutions below, optionally transposed, and stored in an
ordered dict that is finally written to ``pytorch_model.bin`` in the
current working directory.
"""
import os
from collections import OrderedDict

import tensorflow as tf
import torch

# Checkpoint prefix, resolved against the current working directory.
tf_checkpoint_path = "chinese_GAU-alpha-char_L-24_H-768-tf/bert_model.ckpt"
tf_path = os.path.abspath(tf_checkpoint_path)

# (name, shape) pairs for every variable stored in the checkpoint.
init_vars = tf.train.list_variables(tf_path)

pytorch_state_dict = OrderedDict()
for name, _shape in init_vars:
    array = tf.train.load_variable(tf_path, name)

    # Rewrite the TensorFlow variable name into a dotted key. The
    # substitutions are applied in order: "GAU_alpha" is lower-cased first,
    # then any "bert" prefix is mapped onto the same "gau_alpha" root.
    new_name = (
        name.replace("GAU_alpha", "gau_alpha")
        .replace("bert", "gau_alpha")
        .replace("/", ".")
        .replace("layer_", "layer.")
        .replace("kernel", "weight")
        .replace("gamma", "weight")
    )

    # NOTE(review): presumably embedding tables live under a ".weight"
    # attribute in the target PyTorch model -- confirm against that model.
    if "embeddings" in new_name:
        new_name = new_name + ".weight"

    # NOTE(review): presumably TensorFlow dense kernels are stored as
    # (in, out) while the PyTorch layers expect (out, in) -- confirm.
    if "_dense" in new_name:
        array = array.T

    pytorch_state_dict[new_name] = torch.from_numpy(array)

torch.save(pytorch_state_dict, "pytorch_model.bin")