Data
base_data_module
Base DataModule class.
- class unKR.data.base_data_module.BaseDataModule(*args: Any, **kwargs: Any)[source]
Bases:
LightningDataModule
Base DataModule. Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
- prepare_data()[source]
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don’t set state self.x = y).
- setup(stage=None)[source]
Split into train, val, test, and set dims. Should assign torch Dataset objects to self.data_train, self.data_val, and optionally self.data_test.
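To make the contract above concrete, here is a minimal, hedged sketch of a Lightning-style data module that follows the same pattern; the dataset, splits, and sizes are illustrative assumptions, not unKR's actual BaseDataModule implementation.

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset, random_split

    class ToyDataModule(pl.LightningDataModule):
        def __init__(self, batch_size=32):
            super().__init__()
            self.batch_size = batch_size

        def prepare_data(self):
            # download / one-time disk work only; do not assign state here
            pass

        def setup(self, stage=None):
            # assign torch Dataset objects to the train / val / test attributes
            full = TensorDataset(torch.randn(1000, 8), torch.randint(0, 2, (1000,)))
            self.data_train, self.data_val, self.data_test = random_split(full, [800, 100, 100])

        def train_dataloader(self):
            return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.data_val, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.data_test, batch_size=self.batch_size)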
- test_dataloader()[source]
Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying testing samples.
Example:
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
Note
In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
- train_dataloader()[source]
Implement one or more PyTorch DataLoaders for training.
- Returns:
A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see this page.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
- val_dataloader()[source]
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
DataPreprocess
- class unKR.data.DataPreprocess.GMUCBaseSampler(args)[source]
Bases:
GMUCData
Traditional GMUC random sampling mode.
- generate_false(query_triples, candidates)[source]
Generate false triples.
- Parameters:
query_triples – Query set.
candidates – All entities that belong to the relationship.
- Returns:
False triples (confidence = 0).
false_left: The head entities of the false triples.
false_right: The tail entities of the false triples.
- Return type:
false_pairs
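A rough sketch of what such false-triple generation can look like is shown below; the (h, r, t) tuple layout, the e1rel_e2 filter, and the h + r key format are illustrative assumptions rather than the exact unKR code.

    import random

    def generate_false(query_triples, candidates, e1rel_e2):
        # Hedged sketch: corrupt the tail of each query triple with a random
        # candidate that is not a known true tail; the false confidence is 0.
        false_pairs, false_left, false_right = [], [], []
        for h, r, t in query_triples:
            true_tails = set(e1rel_e2.get(h + r, []))      # assumed key format
            pool = [e for e in candidates if e not in true_tails and e != t]
            if not pool:
                continue
            t_false = random.choice(pool)
            false_pairs.append((h, r, t_false, 0.0))        # confidence = 0
            false_left.append(h)
            false_right.append(t_false)
        return false_pairs, false_left, false_right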
- class unKR.data.DataPreprocess.GMUCData(args)[source]
Bases:
object
Data processing for few-shot GMUC & GMUC+ data.
- args
Some pre-set parameters, such as dataset path, etc.
- ent2id
Encoding the entity in triples, type: dict.
- rel2id
Encoding the relation in triples, type: dict.
- symbol2id
Encoding the entity and relation in triples, type: dict.
- e1rel_e2
Record the tail corresponding to the same head and relation, type: defaultdict(class:list).
- rele1_e2
Record the tail corresponding to the same relation and head, type: defaultdict(class:dict).
- train_tasks
Record the triples for training, type: dict.
- dev_tasks
Record the triples for validation, type: dict.
- test_tasks
Record the triples for testing, type: dict.
- task_pool
A task is a relation, type: list.
- path_graph
Background triples, type: list.
- type2ents
All entities of the same type, type: defaultdict(class:set).
- known_rels
The triples of path_graph and train_tasks, type: defaultdict(class:list).
- num_tasks
The number of tasks, type: int.
- rel2candidates
Record the entities corresponding to the same relation, type: list.
- ent2ic
Calculate IIC for every entity, type: dict.
- rel_uc1
Calculate UC1 for every relation, type: dict.
- rel_uc2
Calculate UC2 for every relation, type: dict.
- connections
Neighbor information for each entity, type: numpy array.
- e1_rele2
Record the relation and tail corresponding to the same head, type: defaultdict(class:list).
- e1_degrees
Record the number of neighbors per entity, type: defaultdict(class:int).
- build_graph(max_=50)[source]
Build the graph according to path_graph.
- Update:
self.connections: The set of connections.
self.e1_rele2: The set of e1_rele2.
self.e1_degrees: The set of e1_degrees.
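As a hedged illustration of the structures updated above, the sketch below builds a fixed-width neighbor table from path_graph; the padding id and the max_ truncation rule are assumptions, not necessarily unKR's exact behaviour.

    from collections import defaultdict
    import numpy as np

    def build_graph(path_graph, num_ent, pad_id, max_=50):
        # Hedged sketch: keep up to max_ (relation, tail) neighbors per head
        # entity in a dense array; unused slots are filled with a padding id.
        connections = np.full((num_ent, max_, 2), pad_id, dtype=np.int64)
        e1_rele2 = defaultdict(list)
        e1_degrees = defaultdict(int)
        for h, r, t in path_graph:
            e1_rele2[h].append((r, t))
        for ent, neighbors in e1_rele2.items():
            kept = neighbors[:max_]
            e1_degrees[ent] = len(kept)
            for i, (r, t) in enumerate(kept):
                connections[ent, i] = (r, t)
        return connections, e1_rele2, e1_degrees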
- get_e1rel_e2()[source]
Get the set of e1rel_e2 from the whole dataset.
- Update:
self.e1rel_e2: The set of e1rel_e2.
- get_ontology()[source]
Get the IIC of the entity and UC of the relation.
\[IIC(c) = 1 - \frac{\log(hypo(c)+1)}{\log(n)}\]
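Read as a worked example, and assuming hypo(c) is the number of hyponyms of concept c and n the total number of concepts (the usual intrinsic information content definition; the docstring itself does not spell this out), the formula behaves as follows:

    import math

    def iic(num_hyponyms, num_concepts):
        # IIC(c) = 1 - log(hypo(c) + 1) / log(n)
        return 1.0 - math.log(num_hyponyms + 1) / math.log(num_concepts)

    print(iic(0, 1000))    # 1.0  -> a leaf concept is maximally informative
    print(iic(999, 1000))  # 0.0  -> the root concept carries no information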
- get_rel2candidates()[source]
Get the set of rel2candidates. A candidate is an entity under a relationship.
- Update:
self.known_rels: The set of known_rels.
self.type2ents: The set of type2ents.
self.rel2candidates: Obtain 1000 entities of the same type as a candidate set.
- class unKR.data.DataPreprocess.UKGData(args)[source]
Bases:
object
Data preprocessing for UKG data.
- args
Some pre-set parameters, such as dataset path, etc.
- ent2id
Encoding the entity in triples, type: dict.
- rel2id
Encoding the relation in triples, type: dict.
- id2ent
Decoding the entity in triples, type: dict.
- id2rel
Decoding the relation in triples, type: dict.
- train_triples
Record the triples for training, type: list.
- valid_triples
Record the triples for validation, type: list.
- test_triples
Record the triples for testing, type: list.
- PSL_triples
Record the triples for softlogic, type: list. (will be used in UKGE_PSL)
- pseudo_triples
Record the triples for pseudo, type: list. (will be used in UPGAT)
- all_true_triples
Record all triples including train, valid and test, type: list.
- hr2t_train
Record the tail corresponding to the same head and relation, type: defaultdict(class:set).
- rt2h_train
Record the head corresponding to the same tail and relation, type: defaultdict(class:set).
- h2rt_train
Record the tail, relation corresponding to the same head, type: defaultdict(class:set).
- t2rh_train
Record the head, relation corresponding to the same tail, type: defaultdict(class:set).
- hr2t_total
Record the tail corresponding to the same head and relation in the whole dataset (train + val + test), type: defaultdict(class:set).
- rt2h_total
Record the head corresponding to the same tail and relation in the whole dataset (train + val + test), type: defaultdict(class:set).
- RatioOfPSL
Record the ratio of the number of PSL samples to the number of training samples. (will be used in UKGE_PSL)
- pseudo_dataiter
Record the data in dataloader. (will be used in UPGAT)
- static count_frequency(triples, start=4)[source]
Get frequency of a partial triple like (head, relation) or (relation, tail).
The frequency will be used for subsampling like word2vec.
- Parameters:
triples – Sampled triples.
start – Initial count number.
- Returns:
Record the number of (head, relation).
- Return type:
count
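A hedged sketch of this kind of frequency counting (in the style popularised by word2vec subsampling) is given below; the plain (h, r, t) tuple input is an assumption, since unKR triples may also carry a confidence score.

    def count_frequency(triples, start=4):
        # Hedged sketch: count occurrences of (head, relation) and (relation, tail),
        # starting from `start` so rare patterns still get a non-zero weight.
        count = {}
        for head, relation, tail in triples:
            for key in ((head, relation), (relation, tail)):
                if key not in count:
                    count[key] = start
                else:
                    count[key] += 1
        return count

    freqs = count_frequency([(0, 1, 2), (0, 1, 3), (4, 1, 2)])
    print(freqs[(0, 1)])  # 5 -> initial count of 4, incremented once for the repeat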
- get_h2rt_t2hr_from_train()[source]
Get the sets of h2rt and t2hr from the train dataset; the data type is numpy.
- Update:
self.h2rt_train: The set of h2rt. self.t2rh_train: The set of t2hr.
- get_hr2t_rt2h_from_total()[source]
Get the sets of hr2t and rt2h from the whole dataset (train + val + test); the data type is numpy.
- Update:
self.hr2t_total: The set of hr2t in whole dataset. self.rt2h_total: The set of rt2h in whole dataset.
- get_hr2t_rt2h_from_train()[source]
Get the sets of hr2t and rt2h from the train dataset; the data type is numpy.
- Update:
self.hr2t_train: The set of hr2t. self.rt2h_train: The set of rt2h.
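The hedged sketch below shows how such lookup tables are typically assembled from plain (h, r, t) training triples; converting the sets to numpy arrays at the end mirrors the "data type is numpy" remark above, but the exact unKR code may differ.

    from collections import defaultdict
    import numpy as np

    def build_hr2t_rt2h(train_triples):
        hr2t = defaultdict(set)
        rt2h = defaultdict(set)
        for h, r, t in train_triples:
            hr2t[(h, r)].add(t)   # all tails seen for (head, relation)
            rt2h[(r, t)].add(h)   # all heads seen for (relation, tail)
        # freeze the sets as numpy arrays for fast filtered negative sampling
        hr2t = {k: np.array(sorted(v)) for k, v in hr2t.items()}
        rt2h = {k: np.array(sorted(v)) for k, v in rt2h.items()}
        return hr2t, rt2h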
- class unKR.data.DataPreprocess.UKGEBaseSampler(args)[source]
Bases:
UKGData
Data processing for UKG data; the sampling method is consistent with NeuralKG (https://github.com/zjukg/NeuralKG).
- corrupt_head(t, r, num_max=1)[source]
Negative sampling of head entities.
- Parameters:
t – Tail entity in triple.
r – Relation in triple.
num_max – The maximum number of negative samples generated.
- Returns:
The negative samples of the head entity, with positive head entities filtered out.
- Return type:
neg
- corrupt_tail(h, r, num_max=1)[source]
Negative sampling of tail entities.
- Parameters:
h – Head entity in triple.
r – Relation in triple.
num_max – The maximum number of negative samples generated.
- Returns:
The negative samples of the tail entity, with positive tail entities filtered out.
- Return type:
neg
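A hedged sketch of this filtered negative sampling is shown below; the rt2h_train lookup and the entity-id range are taken from the attributes documented earlier, while the concrete sampling strategy is an assumption.

    import numpy as np

    def corrupt_head(t, r, rt2h_train, num_ent, num_max=1):
        # Hedged sketch: draw random candidate entity ids, then drop any id that
        # is already a true head for (r, t) so only negatives remain.
        candidates = np.random.randint(low=0, high=num_ent, size=num_max)
        positives = np.asarray(sorted(rt2h_train.get((r, t), set())), dtype=np.int64)
        keep = np.isin(candidates, positives, invert=True)
        return candidates[keep]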
Sampler
- class unKR.data.Sampler.GMUCSampler(args)[source]
Bases:
GMUCBaseSampler
GMUC sampling. Processes task-based data.
- class unKR.data.Sampler.GMUCTestSampler(sampler)[source]
Bases:
object
Sampling triples and recording positive triples for testing.
- sampler
The sampler used during training.
- num_ent
The number of entities.
- rel2candidates
Record the entities corresponding to the same relation, type: list.
- symbol2id
Encoding the entity and relation in triples, type: dict.
- ent2id
Encoding the entity in triples, type: dict.
- class unKR.data.Sampler.UKGEPSLSampler(args)[source]
Bases:
UKGEBaseSampler
Random negative sampling: filter out positive samples and randomly select some samples as negatives. UKGEPSLSampler is used for UKGE_PSL.
- cross_sampling_flag
The flag of cross sampling head and tail negative samples.
- class unKR.data.Sampler.UKGETestSampler(sampler)[source]
Bases:
object
Sampling triples and recording positive triples for testing. We offer two test sampling methods: one uses the same method as NeuralKG to process the test set, and the other uses only high-confidence samples from the test set for testing.
- sampler
The sampler used during training.
- hr2t_all
Record the tail corresponding to the same head and relation.
- rt2h_all
Record the head corresponding to the same tail and relation.
- hr2t_all_high_score
Record the tail corresponding to the same head and relation (only for high score samples in test).
- rt2h_all_high_score
Record the head corresponding to the same tail and relation (only for high score samples in test).
- num_ent
The count of entities.
- get_hr2t_rt2h_from_all()[source]
Get the sets of hr2t and rt2h from all datasets (train, valid, and test); the data type is tensor.
- Update:
self.hr2t_all: The set of hr2t.
self.rt2h_all: The set of rt2h.
self.hr2t_all_high_score: The set of hr2t (only for high score samples in test).
self.rt2h_all_high_score: The set of rt2h (only for high score samples in test).
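A hedged sketch of the high-confidence variant is given below: only test triples whose confidence reaches a threshold contribute to the filtered lookups. The threshold value and the (h, r, t, score) tuple layout are assumptions.

    from collections import defaultdict

    def build_high_score_filters(test_triples, threshold=0.7):
        hr2t_high = defaultdict(set)
        rt2h_high = defaultdict(set)
        for h, r, t, score in test_triples:
            if score >= threshold:        # keep only high-confidence triples
                hr2t_high[(h, r)].add(t)
                rt2h_high[(r, t)].add(h)
        return hr2t_high, rt2h_high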
- class unKR.data.Sampler.UKGEUniSampler(args)[source]
Bases:
UKGEBaseSampler
Random negative sampling: filter out positive samples and randomly select some samples as negatives.
- cross_sampling_flag
The flag of cross sampling head and tail negative samples.
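The flag is typically used to alternate between head and tail corruption from one batch to the next; a hedged sketch of that pattern (not the actual unKR implementation) is:

    class CrossSampler:
        def __init__(self):
            self.cross_sampling_flag = 0

        def next_mode(self):
            # flip the flag on every call so consecutive batches corrupt
            # heads and tails alternately
            self.cross_sampling_flag = (self.cross_sampling_flag + 1) % 2
            return "head-batch" if self.cross_sampling_flag == 0 else "tail-batch"

    sampler = CrossSampler()
    print(sampler.next_mode(), sampler.next_mode())  # tail-batch head-batch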
- unKR.data.Sampler.normal(loc=0.0, scale=1.0, size=None)
Draw random samples from a normal (Gaussian) distribution.
The probability density function of the normal distribution, first derived by De Moivre and 200 years later by both Gauss and Laplace independently [2], is often called the bell curve because of its characteristic shape (see the example below).
The normal distribution occurs often in nature. For example, it describes the commonly occurring distribution of samples influenced by a large number of tiny, random disturbances, each with its own unique distribution [2].
Note
New code should use the normal method of a numpy.random.Generator instance instead; please see the NumPy random quick start.
- Parameters:
loc (float or array_like of floats) – Mean (“centre”) of the distribution.
scale (float or array_like of floats) – Standard deviation (spread or “width”) of the distribution. Must be non-negative.
size (int or tuple of ints, optional) – Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. If size is None (default), a single value is returned if loc and scale are both scalars. Otherwise, np.broadcast(loc, scale).size samples are drawn.
- Returns:
out – Drawn samples from the parameterized normal distribution.
- Return type:
ndarray or scalar
See also
scipy.stats.norm
probability density function, distribution or cumulative density function, etc.
random.Generator.normal
which should be used for new code.
Notes
The probability density for the Gaussian distribution is
\[p(x) = \frac{1}{\sqrt{ 2 \pi \sigma^2 }} e^{ - \frac{ (x - \mu)^2 } {2 \sigma^2} },\]
where \(\mu\) is the mean and \(\sigma\) the standard deviation. The square of the standard deviation, \(\sigma^2\), is called the variance.
The function has its peak at the mean, and its “spread” increases with the standard deviation (the function reaches 0.607 times its maximum at \(x + \sigma\) and \(x - \sigma\) [2]). This implies that normal is more likely to return samples lying close to the mean, rather than those far away.
Examples
Draw samples from the distribution:
>>> mu, sigma = 0, 0.1  # mean and standard deviation
>>> s = np.random.normal(mu, sigma, 1000)
Verify the mean and the variance:
>>> abs(mu - np.mean(s))
0.0  # may vary
>>> abs(sigma - np.std(s, ddof=1))
0.1  # may vary
Display the histogram of the samples, along with the probability density function:
>>> import matplotlib.pyplot as plt
>>> count, bins, ignored = plt.hist(s, 30, density=True)
>>> plt.plot(bins, 1/(sigma * np.sqrt(2 * np.pi)) *
...                np.exp( - (bins - mu)**2 / (2 * sigma**2) ),
...          linewidth=2, color='r')
>>> plt.show()
Two-by-four array of samples from the normal distribution with mean 3 and standard deviation 2.5:
>>> np.random.normal(3, 2.5, size=(2, 4))
array([[-4.49401501,  4.00950034, -1.81814867,  7.29718677],   # random
       [ 0.39924804,  4.68456316,  4.99394529,  4.84057254]])  # random
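Following the note above, new code should prefer the Generator API; an equivalent of the legacy calls using numpy.random.default_rng looks like:

>>> rng = np.random.default_rng(seed=42)
>>> s = rng.normal(loc=0.0, scale=0.1, size=1000)
>>> rng.normal(3, 2.5, size=(2, 4)).shape
(2, 4)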
KGDataModule
Base DataModule class.
- class unKR.data.KGDataModule.GMUCDataModule(*args: Any, **kwargs: Any)[source]
Bases:
BaseDataModule
GMUC Base DataModule. Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
- get_data_config()[source]
Return important settings of the dataset, which will be passed to instantiate models.
- prepare_data()[source]
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don’t set state self.x = y).
- setup(stage=None)[source]
Split into train, val, test, and set dims. Should assign torch Dataset objects to self.data_train, self.data_val, and optionally self.data_test.
- test_dataloader()[source]
Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying testing samples.
Example:
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
Note
In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
- train_dataloader()[source]
Implement one or more PyTorch DataLoaders for training.
- Returns:
A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see this page.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
- val_dataloader()[source]
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
- class unKR.data.KGDataModule.KGDataModule(*args: Any, **kwargs: Any)[source]
Bases:
BaseDataModule
Base DataModule. Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
- get_data_config()[source]
Return important settings of the dataset, which will be passed to instantiate models.
- get_train_bs()[source]
Get batch size for training.
If num_batches isn’t zero, the batch size is obtained by dividing the size of data_train by num_batches. If the user doesn’t provide a batch size and num_batches is 0, a ValueError is raised.
- Returns:
The batch size for training.
- Return type:
self.args.train_bs
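A hedged sketch of the rule described above (with illustrative argument names rather than unKR's actual attributes) is:

    def get_train_bs(num_train_samples, train_bs=None, num_batches=0):
        # derive the batch size from the desired number of batches if given,
        # otherwise fall back to an explicitly provided batch size
        if num_batches != 0:
            return max(1, num_train_samples // num_batches)
        if train_bs is None:
            raise ValueError("Either train_bs or num_batches must be provided.")
        return train_bs

    print(get_train_bs(10000, num_batches=100))  # 100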
- prepare_data()[source]
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don’t set state self.x = y).
- setup(stage=None)[source]
Split into train, val, test, and set dims. Should assign torch Dataset objects to self.data_train, self.data_val, and optionally self.data_test.
- test_dataloader()[source]
Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying testing samples.
Example:
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
Note
In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
- train_dataloader()[source]
Implement one or more PyTorch DataLoaders for training.
- Returns:
A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see this page.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
- val_dataloader()[source]
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.