Converting a Transformer to ONNX

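To convert a transformer model to ONNX, the most convenient route is to call PyTorch's native export interface (torch.onnx.export) directly instead of relying on Hugging Face's tooling. The code is as follows: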

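import torch

# `model` (the PyTorch transformer to export) and `out_name` (the target
# .onnx file path) are assumed to be defined already.

# Example input: one sequence of 7 token ids, the last one padding (0),
# with attention_mask set to 0 at the padded position.
dummy_input = {
    "input_ids": torch.tensor([[101, 2769, 1372, 2682, 2127, 102, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 0]]),
}
# Axes listed here are exported as dynamic: axis 0 (batch size) and
# axis 1 (sequence length) for both inputs, and axis 0 (batch size) for
# the output. If the output also has a sequence axis (e.g. per-token
# hidden states), add 1 to its list as well.
dynamic_axes = {
    "input_ids": [0, 1],
    "attention_mask": [0, 1],
    "output": [0],
}
torch.onnx.export(
    model,                                 # model being run
    tuple(dummy_input.values()),           # positional inputs, in the order forward() expects
    out_name,                              # where to save the exported model
    export_params=True,                    # store the trained parameter weights inside the model file
    opset_version=11,                      # the ONNX opset version to export the model to
    do_constant_folding=True,              # whether to execute constant folding for optimization
    input_names=list(dummy_input.keys()),  # the model's input names
    output_names=["output"],               # the model's output names
    dynamic_axes=dynamic_axes,
)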
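Because `input_ids` and `attention_mask` are exported with dynamic batch and sequence axes, the ONNX model can be fed right-padded batches of any size at inference time. The following is a minimal sketch of building such inputs in plain Python, assuming the padding convention seen in `dummy_input` (pad id 0, mask value 0 on padded positions; the ids look like a BERT-style vocabulary where [CLS]=101 and [SEP]=102). `pad_batch` is a hypothetical helper, not part of PyTorch or ONNX Runtime.

```python
from typing import List, Optional, Tuple


def pad_batch(
    batch: List[List[int]], pad_id: int = 0, max_len: Optional[int] = None
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token-id sequences to a common length (hypothetical helper).

    Pads to `max_len` if given, otherwise to the longest sequence in the
    batch. Returns (input_ids, attention_mask), where the mask is 1 for
    real tokens and 0 for padding, the same layout as dummy_input.
    Assumes pad id 0, as in dummy_input.
    """
    target = max_len if max_len is not None else max(len(seq) for seq in batch)
    input_ids: List[List[int]] = []
    attention_mask: List[List[int]] = []
    for seq in batch:
        if len(seq) > target:
            raise ValueError(f"sequence of length {len(seq)} exceeds max_len={target}")
        n_pad = target - len(seq)
        input_ids.append(list(seq) + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask


if __name__ == "__main__":
    ids, mask = pad_batch([[101, 2769, 1372, 2682, 2127, 102], [101, 2769, 102]])
    print(ids)   # [[101, 2769, 1372, 2682, 2127, 102], [101, 2769, 102, 0, 0, 0]]
    print(mask)  # [[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]
```

Calling `pad_batch([[101, 2769, 1372, 2682, 2127, 102]], max_len=7)` reproduces exactly the `dummy_input` used for export. To run the exported model with ONNX Runtime, these lists would typically be wrapped as int64 NumPy arrays (integer `torch.tensor` inputs default to int64, so that is the dtype the exported inputs expect) and passed to `InferenceSession.run`, keyed by the names given in `input_names`.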