|
|
@ -41,7 +41,8 @@ class Text2SemanticDataModule(LightningDataModule):
|
|
|
|
# pad_val=self.config['data']['pad_val'])
|
|
|
|
# pad_val=self.config['data']['pad_val'])
|
|
|
|
|
|
|
|
|
|
|
|
def train_dataloader(self):
|
|
|
|
def train_dataloader(self):
|
|
|
|
batch_size = max(min(self.config["train"]["batch_size"],len(self._train_dataset)//4),1)#防止不保存
|
|
|
|
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
|
|
|
|
|
|
|
|
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
|
|
|
|
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
|
|
|
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
|
|
|
return DataLoader(
|
|
|
|
return DataLoader(
|
|
|
|
self._train_dataset,
|
|
|
|
self._train_dataset,
|
|
|
|