"""Toy demo: train a tiny encoder-decoder Transformer to reverse digit sequences.

Run as a script to train on random sequences of tokens 1..9 and print the
greedy (autoregressive) prediction for ``[1, 2, 3, 4, 5]``.
"""
import torch
import torch.nn as nn

VOCAB_SIZE = 10  # token ids 0..9; training data only uses 1..9
BOS = 0  # id 0 never appears in the data, so it is free to mark the decoder start
SEQ_LEN = 5
D_MODEL = 16


class SeqReverser(nn.Module):
    """Token embedding + learned positions + ``nn.Transformer`` + output projection."""

    def __init__(self, vocab_size=VOCAB_SIZE, d_model=D_MODEL, nhead=2, max_len=SEQ_LEN):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        # nn.Transformer adds no positional information on its own; without it
        # the model cannot tell token order in the source, so "reverse" is
        # unlearnable.
        self.pos = nn.Embedding(max_len, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=1,
            num_decoder_layers=1,
            batch_first=True,
        )
        self.fc = nn.Linear(d_model, vocab_size)

    def _embed(self, tokens):
        """Return token + position embeddings for a ``(batch, seq)`` id tensor."""
        positions = torch.arange(tokens.size(1), device=tokens.device)
        return self.emb(tokens) + self.pos(positions)

    def forward(self, src, tgt_in):
        """Return logits of shape ``(batch, tgt_len, vocab_size)``.

        Args:
            src: ``(batch, src_len)`` token ids, ``src_len <= max_len``.
            tgt_in: ``(batch, tgt_len)`` decoder input ids starting with BOS,
                ``tgt_len <= max_len``.
        """
        # Causal mask: decoder position i may attend only to positions <= i,
        # so the model cannot read the token it is being asked to predict.
        mask = self.transformer.generate_square_subsequent_mask(tgt_in.size(1))
        mask = mask.to(src.device)
        out = self.transformer(self._embed(src), self._embed(tgt_in), tgt_mask=mask)
        return self.fc(out)


def shift_right(y, bos=BOS):
    """Build the teacher-forcing decoder input: prepend *bos*, drop the last token.

    Position i of the result holds target token i-1, so predicting target
    token i never requires seeing it.
    """
    start = torch.full((y.size(0), 1), bos, dtype=y.dtype, device=y.device)
    return torch.cat([start, y[:, :-1]], dim=1)


def fit(model, x, y, steps=200, lr=0.01):
    """Train *model* to map ``x`` to ``y`` with teacher forcing.

    Returns:
        The loss of the last step as a float (``nan`` when ``steps == 0``).
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    tgt_in = shift_right(y)
    model.train()
    last = float("nan")
    for _ in range(steps):
        logits = model(x, tgt_in)
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        last = loss.item()
    return last


@torch.no_grad()
def greedy_decode(model, src, bos=BOS):
    """Autoregressively predict an output as long as *src*, starting from *bos*.

    Only the source is given to the model; the decoder sees only its own
    earlier predictions. Dropout is disabled while decoding, and the model's
    previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    out = torch.full((src.size(0), 1), bos, dtype=src.dtype, device=src.device)
    for _ in range(src.size(1)):
        logits = model(src, out)
        next_tok = logits[:, -1].argmax(-1, keepdim=True).to(out.dtype)
        out = torch.cat([out, next_tok], dim=1)
    model.train(was_training)
    return out[:, 1:]  # drop the BOS token


def main(seed=0):
    """Train on random sequences and print the prediction for ``[1, 2, 3, 4, 5]``."""
    torch.manual_seed(seed)
    x = torch.randint(1, VOCAB_SIZE, (100, SEQ_LEN))
    y = torch.flip(x, [1])
    model = SeqReverser()
    fit(model, x, y)

    t = torch.tensor([[1, 2, 3, 4, 5]])
    pred = greedy_decode(model, t)
    print("Input :", t.tolist())
    print("Target:", torch.flip(t, [1]).tolist())
    print("Pred :", pred.tolist())


if __name__ == "__main__":
    main()