pymlex commited on
Commit
d13015d
·
verified ·
1 Parent(s): db32d32

Upload spanish_cefr_bertin.ipynb

Browse files
Files changed (1) hide show
  1. spanish_cefr_bertin.ipynb +206 -0
spanish_cefr_bertin.ipynb CHANGED
@@ -1401,6 +1401,212 @@
1401
  "trainer.push_to_hub(commit_message=\"Spanish CEFR fine-tuning\")\n",
1402
  "print(repo_id)"
1403
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1404
  }
1405
  ],
1406
  "metadata": {
 
1401
  "trainer.push_to_hub(commit_message=\"Spanish CEFR fine-tuning\")\n",
1402
  "print(repo_id)"
1403
  ]
1404
+ },
1405
+ {
1406
+ "cell_type": "markdown",
1407
+ "source": [
1408
+ "## Inference"
1409
+ ],
1410
+ "metadata": {
1411
+ "id": "BDGpYPNLZXba"
1412
+ },
1413
+ "id": "BDGpYPNLZXba"
1414
+ },
1415
+ {
1416
+ "cell_type": "code",
1417
+ "source": [
1418
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
1419
+ "import torch\n",
1420
+ "\n",
1421
+ "model_id = \"pymlex/roberta-spanish-cefr\"\n",
1422
+ "\n",
1423
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
1424
+ "model = AutoModelForSequenceClassification.from_pretrained(model_id)\n",
1425
+ "model.eval()\n",
1426
+ "\n",
1427
+ "def predict_cefr(text, top_k=3):\n",
1428
+ " inputs = tokenizer(\n",
1429
+ " text,\n",
1430
+ " return_tensors=\"pt\",\n",
1431
+ " truncation=True,\n",
1432
+ " max_length=512,\n",
1433
+ " )\n",
1434
+ " with torch.no_grad():\n",
1435
+ " logits = model(**inputs).logits\n",
1436
+ " probs = torch.softmax(logits, dim=-1)[0]\n",
1437
+ "\n",
1438
+ " k = min(top_k, probs.numel())\n",
1439
+ " values, indices = torch.topk(probs, k=k)\n",
1440
+ "\n",
1441
+ " return [\n",
1442
+ " {\n",
1443
+ " \"label\": model.config.id2label[i.item()],\n",
1444
+ " \"score\": float(v.item()),\n",
1445
+ " }\n",
1446
+ " for i, v in zip(indices, values)\n",
1447
+ " ]\n",
1448
+ "\n",
1449
+ "text = \"Estimados se\u00f1ores, les escribo para solicitar informaci\u00f3n sobre el curso.\"\n",
1450
+ "print(predict_cefr(text, top_k=3))"
1451
+ ],
1452
+ "metadata": {
1453
+ "colab": {
1454
+ "base_uri": "https://localhost:8080/",
1455
+ "height": 318,
1456
+ "referenced_widgets": [
1457
+ "81146c5153fc4c599ffa4210f49315f9",
1458
+ "30b7c08c82e642e68c324fa37d2273e1",
1459
+ "c6df9d43f7f944e3a731aa794561928e",
1460
+ "3eb95943e51849309a17ddca78155da8",
1461
+ "0b0cd91951cc432f8cb3a3e5399914d8",
1462
+ "fcb1dad83c224c068787027e8bcfd398",
1463
+ "d9297a6d9f174cc99c0c407f5196d2ec",
1464
+ "81fb07e5833548f695df1d4421253990",
1465
+ "706c0a5d14e34471b28e2b98dcc9126f",
1466
+ "cadf5dd05b614d5b962ea953ac9da959",
1467
+ "c68665f850b74a2bb710d3de0ba92087",
1468
+ "3f00dc74734a498ca37b21b5d514d576",
1469
+ "d82869ad687a43eabb113ce699cbf369",
1470
+ "814d90db9746490a825f092523c3a4b9",
1471
+ "c9b0001778274587a251eddd7ecc342f",
1472
+ "939f139ca1054dffa12b1d0a46e1aee4",
1473
+ "4845ee644faf400b94c89ff9d036efd0",
1474
+ "84aa6f541feb4a848b1ee256a793f56a",
1475
+ "4a2b76e52f4046bc84da2fcafdba6bf1",
1476
+ "41c4b4dbf9684298b13e25ad627713bf",
1477
+ "44eff00a96c44e34bf5c8864b11c7dbf",
1478
+ "8e8c0db7bc054c1f8aaccd331cbe742f",
1479
+ "380e59a0af0044ee9cc6cb352270f2b5",
1480
+ "ccd972537eff4c48ad7a69741a905d91",
1481
+ "4b59346d2dfe488a974d1d5c8986be68",
1482
+ "9a54680eb4b6453fb6fe09812ce3438d",
1483
+ "e9ca9c53899441e393db33551a35e354",
1484
+ "d7e32c7d91b94b2997b1c2e8ad5ee9a1",
1485
+ "19809cbf199d423483b1608d87b9ac3d",
1486
+ "a717856a82334c2c993cdada3f904de8",
1487
+ "82e8ba74e4ce4ffc8d6e559c12733dc0",
1488
+ "7973daaba8914086b9bc0d67181a8034",
1489
+ "4f6da9c2fc924d83b41092b7bb5643a3",
1490
+ "6f2f6435bc904bb294955df6db6f2d5b",
1491
+ "b23cb1da8901414c8bee3a65bd352c41",
1492
+ "3d33aaff6edc44ea9d5a4db026832f2b",
1493
+ "edc3cb71239e4f7488eaa361ccf34243",
1494
+ "9f39495c04434441b3137a9b36f038a9",
1495
+ "c02c08a573c34adcbbc3d9fffea1b0e6",
1496
+ "8df14bad8d9c4a409df3a57b6776cc50",
1497
+ "5f87bb2c926941fba9dee6fad84dd0d1",
1498
+ "771cdc1db79048ef8cff34bfa7783b0f",
1499
+ "4f06b9da51ac4881a8626fac9f6525a8",
1500
+ "00186b8e015041efbd04588b93d3ea1c",
1501
+ "cd786a95965542309c2e06a30b1ecbc8",
1502
+ "1fbf3b722a624c45857c6eb03cae065b",
1503
+ "2790d0484be342a790239edd60eaa937",
1504
+ "63b8697c268d4d38846bbcba1e9a7b0e",
1505
+ "efd3f3f2b4cb4712a4f7125d45d5c721",
1506
+ "fb49075f8b1d4ea3913ee44b3391ff5b",
1507
+ "2aff4da8c86a4b9c8836579e00501ea6",
1508
+ "583cc33276924a55bf059969fc19b2cb",
1509
+ "e6969631cc954f79b2b8e5cd84546323",
1510
+ "3c9ad6fc90624e1483150e1df7ec71fd",
1511
+ "0b49180ffb1d4e41b6c2e3990cfd6507"
1512
+ ]
1513
+ },
1514
+ "id": "Xkw2MLNdZqct",
1515
+ "outputId": "2cf27c73-5ec5-438e-b23c-6ea35188900a"
1516
+ },
1517
+ "id": "Xkw2MLNdZqct",
1518
+ "execution_count": 1,
1519
+ "outputs": [
1520
+ {
1521
+ "output_type": "stream",
1522
+ "name": "stderr",
1523
+ "text": [
1524
+ "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
1525
+ "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
1526
+ "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
1527
+ "You will be able to reuse this secret in all of your notebooks.\n",
1528
+ "Please note that authentication is recommended but still optional to access public models or datasets.\n",
1529
+ " warnings.warn(\n"
1530
+ ]
1531
+ },
1532
+ {
1533
+ "output_type": "display_data",
1534
+ "data": {
1535
+ "text/plain": [
1536
+ "config.json: 0.00B [00:00, ?B/s]"
1537
+ ],
1538
+ "application/vnd.jupyter.widget-view+json": {
1539
+ "version_major": 2,
1540
+ "version_minor": 0,
1541
+ "model_id": "81146c5153fc4c599ffa4210f49315f9"
1542
+ }
1543
+ },
1544
+ "metadata": {}
1545
+ },
1546
+ {
1547
+ "output_type": "display_data",
1548
+ "data": {
1549
+ "text/plain": [
1550
+ "tokenizer_config.json: 0%| | 0.00/377 [00:00<?, ?B/s]"
1551
+ ],
1552
+ "application/vnd.jupyter.widget-view+json": {
1553
+ "version_major": 2,
1554
+ "version_minor": 0,
1555
+ "model_id": "3f00dc74734a498ca37b21b5d514d576"
1556
+ }
1557
+ },
1558
+ "metadata": {}
1559
+ },
1560
+ {
1561
+ "output_type": "display_data",
1562
+ "data": {
1563
+ "text/plain": [
1564
+ "tokenizer.json: 0.00B [00:00, ?B/s]"
1565
+ ],
1566
+ "application/vnd.jupyter.widget-view+json": {
1567
+ "version_major": 2,
1568
+ "version_minor": 0,
1569
+ "model_id": "380e59a0af0044ee9cc6cb352270f2b5"
1570
+ }
1571
+ },
1572
+ "metadata": {}
1573
+ },
1574
+ {
1575
+ "output_type": "display_data",
1576
+ "data": {
1577
+ "text/plain": [
1578
+ "model.safetensors: 0%| | 0.00/499M [00:00<?, ?B/s]"
1579
+ ],
1580
+ "application/vnd.jupyter.widget-view+json": {
1581
+ "version_major": 2,
1582
+ "version_minor": 0,
1583
+ "model_id": "6f2f6435bc904bb294955df6db6f2d5b"
1584
+ }
1585
+ },
1586
+ "metadata": {}
1587
+ },
1588
+ {
1589
+ "output_type": "display_data",
1590
+ "data": {
1591
+ "text/plain": [
1592
+ "Loading weights: 0%| | 0/201 [00:00<?, ?it/s]"
1593
+ ],
1594
+ "application/vnd.jupyter.widget-view+json": {
1595
+ "version_major": 2,
1596
+ "version_minor": 0,
1597
+ "model_id": "cd786a95965542309c2e06a30b1ecbc8"
1598
+ }
1599
+ },
1600
+ "metadata": {}
1601
+ },
1602
+ {
1603
+ "output_type": "stream",
1604
+ "name": "stdout",
1605
+ "text": [
1606
+ "[{'label': 'A1', 'score': 0.22886891663074493}, {'label': 'B1', 'score': 0.19498008489608765}, {'label': 'A2', 'score': 0.19106613099575043}]\n"
1607
+ ]
1608
+ }
1609
+ ]
1610
  }
1611
  ],
1612
  "metadata": {