Update weights and modeling code to latest version
Browse files- modeling_lucavirus.py +7 -2
modeling_lucavirus.py
CHANGED
|
@@ -1141,8 +1141,11 @@ class LucaVirusForMaskedLM(LucaVirusPreTrainedModel):
|
|
| 1141 |
|
| 1142 |
class LucaVirusForSequenceClassification(LucaVirusPreTrainedModel):
|
| 1143 |
def __init__(self, config):
|
|
|
|
|
|
|
| 1144 |
super().__init__(config)
|
| 1145 |
-
self.num_labels = config.
|
|
|
|
| 1146 |
self.task_level = config.task_level
|
| 1147 |
self.task_type = config.task_type
|
| 1148 |
assert self.task_level == "seq_level"
|
|
@@ -1247,8 +1250,10 @@ class LucaVirusForSequenceClassification(LucaVirusPreTrainedModel):
|
|
| 1247 |
|
| 1248 |
class LucaVirusForTokenClassification(LucaVirusPreTrainedModel):
|
| 1249 |
def __init__(self, config):
|
|
|
|
|
|
|
| 1250 |
super().__init__(config)
|
| 1251 |
-
self.num_labels = config.
|
| 1252 |
self.task_level = config.task_level
|
| 1253 |
self.task_type = config.task_type
|
| 1254 |
assert self.task_level == "token_level"
|
|
|
|
| 1141 |
|
| 1142 |
class LucaVirusForSequenceClassification(LucaVirusPreTrainedModel):
|
| 1143 |
def __init__(self, config):
|
| 1144 |
+
if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
|
| 1145 |
+
config.num_labels = config.classifier_num_labels
|
| 1146 |
super().__init__(config)
|
| 1147 |
+
self.num_labels = config.num_labels
|
| 1148 |
+
self.num_labels = config.num_labels
|
| 1149 |
self.task_level = config.task_level
|
| 1150 |
self.task_type = config.task_type
|
| 1151 |
assert self.task_level == "seq_level"
|
|
|
|
| 1250 |
|
| 1251 |
class LucaVirusForTokenClassification(LucaVirusPreTrainedModel):
|
| 1252 |
def __init__(self, config):
|
| 1253 |
+
if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
|
| 1254 |
+
config.num_labels = config.classifier_num_labels
|
| 1255 |
super().__init__(config)
|
| 1256 |
+
self.num_labels = config.num_labels
|
| 1257 |
self.task_level = config.task_level
|
| 1258 |
self.task_type = config.task_type
|
| 1259 |
assert self.task_level == "token_level"
|