LucaGroup commited on
Commit
5103f5f
·
verified ·
1 Parent(s): dbafc7b

Update weights and modeling code to latest version

Browse files
Files changed (1) hide show
  1. modeling_lucavirus.py +7 -2
modeling_lucavirus.py CHANGED
@@ -1141,8 +1141,11 @@ class LucaVirusForMaskedLM(LucaVirusPreTrainedModel):
1141
 
1142
  class LucaVirusForSequenceClassification(LucaVirusPreTrainedModel):
1143
  def __init__(self, config):
 
 
1144
  super().__init__(config)
1145
- self.num_labels = config.classifier_num_labels
 
1146
  self.task_level = config.task_level
1147
  self.task_type = config.task_type
1148
  assert self.task_level == "seq_level"
@@ -1247,8 +1250,10 @@ class LucaVirusForSequenceClassification(LucaVirusPreTrainedModel):
1247
 
1248
  class LucaVirusForTokenClassification(LucaVirusPreTrainedModel):
1249
  def __init__(self, config):
 
 
1250
  super().__init__(config)
1251
- self.num_labels = config.classifier_num_labels
1252
  self.task_level = config.task_level
1253
  self.task_type = config.task_type
1254
  assert self.task_level == "token_level"
 
1141
 
1142
  class LucaVirusForSequenceClassification(LucaVirusPreTrainedModel):
1143
  def __init__(self, config):
1144
+ if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
1145
+ config.num_labels = config.classifier_num_labels
1146
  super().__init__(config)
1147
+ self.num_labels = config.num_labels
1148
+ self.num_labels = config.num_labels
1149
  self.task_level = config.task_level
1150
  self.task_type = config.task_type
1151
  assert self.task_level == "seq_level"
 
1250
 
1251
  class LucaVirusForTokenClassification(LucaVirusPreTrainedModel):
1252
  def __init__(self, config):
1253
+ if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
1254
+ config.num_labels = config.classifier_num_labels
1255
  super().__init__(config)
1256
+ self.num_labels = config.num_labels
1257
  self.task_level = config.task_level
1258
  self.task_type = config.task_type
1259
  assert self.task_level == "token_level"