Data

base_data_module

Base DataModule class.

class unKR.data.base_data_module.BaseDataModule(*args: Any, **kwargs: Any)[source]

Bases: LightningDataModule

Base DataModule. Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html

static add_to_argparse(parser)[source]
get_config()[source]
prepare_data()[source]

Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don’t set state self.x = y).

setup(stage=None)[source]

Split into train, val, test, and set dims. Should assign torch Dataset objects to self.data_train, self.data_val, and optionally self.data_test.
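For orientation, a minimal sketch of a hypothetical subclass (ToyDataModule and its toy tensors are illustrative only, not part of unKR) showing what setup() is expected to assign:

import torch
from torch.utils.data import TensorDataset
from unKR.data.base_data_module import BaseDataModule

class ToyDataModule(BaseDataModule):
    def setup(self, stage=None):
        # assign torch Dataset objects to the attributes the dataloaders read from
        x = torch.randn(100, 8)
        y = torch.randint(0, 2, (100,))
        self.data_train = TensorDataset(x[:60], y[:60])
        self.data_val = TensorDataset(x[60:80], y[60:80])
        self.data_test = TensorDataset(x[80:], y[80:])  # optional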

test_dataloader()[source]

Implement one or multiple PyTorch DataLoaders for testing.

The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

For data processing use the following pattern:
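A minimal sketch of that pattern (following the general Lightning convention, with hypothetical download_dataset and load_split helpers; not unKR's exact code):

def prepare_data(self):
    # runs once, on a single process: safe place to download or write to disk
    download_dataset(self.args.data_path)      # hypothetical helper

def setup(self, stage=None):
    # runs on every process: load from disk and assign state (self.x = y is fine here)
    self.data_train = load_split(self.args.data_path, "train")   # hypothetical helper
    self.data_val = load_split(self.args.data_path, "valid")
    self.data_test = load_split(self.args.data_path, "test")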

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example:

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.

train_dataloader()[source]

Implement one or more PyTorch DataLoaders for training.

Returns:

A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see this page.

The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

For data processing use the following pattern:

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example:

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloader, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
val_dataloader()[source]

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.

class unKR.data.base_data_module.Config[source]

Bases: dict

DataPreprocess

class unKR.data.DataPreprocess.GMUCBaseSampler(args)[source]

Bases: GMUCData

Traditional GMUC random sampling mode.

generate_false(query_triples, candidates)[source]

Generate false triples.

Parameters:
  • query_triples – Query set.

  • candidates – All entities that belong to the relationship.

Returns:

False triples (confidence = 0). false_left: The head entities of false triples. false_right: The tail entities of false triples.

Return type:

false_pairs

get_test()[source]
get_train()[source]
get_valid()[source]
class unKR.data.DataPreprocess.GMUCData(args)[source]

Bases: object

Data processing for few-shot GMUC & GMUC+ data.

args

Some pre-set parameters, such as dataset path, etc.

ent2id

Encoding the entity in triples, type: dict.

rel2id

Encoding the relation in triples, type: dict.

symbol2id

Encoding the entity and relation in triples, type: dict.

e1rel_e2

Record the tail corresponding to the same head and relation, type: defaultdict(class:list).

rele1_e2

Record the tail corresponding to the same relation and head, type: defaultdict(class:dict).

train_tasks

Record the triples for training, type: dict.

dev_tasks

Record the triples for validation, type: dict.

test_tasks

Record the triples for testing, type: dict.

task_pool

A task is a relation, type: list.

path_graph

Background triples, type: list.

type2ents

All entities of the same type, type: defaultdict(class:set).

known_rels

The triples of path_graph and train_tasks, type: defaultdict(class:list).

num_tasks

The number of tasks, type: int.

rel2candidates

Record the entities corresponding to the same relation, type: list.

ent2ic

Calculate IIC for every entity, type: dict.

rel_uc1

Calculate UC_r1 for every relation, type: dict.

rel_uc2

Calculate UC_r2 for every relation, type: dict.

connections

Neighbor information for each entity, type: numpy array.

e1_rele2

Record the relation and tail corresponding to the same head, type: defaultdict(class:list).

e1_degrees

Record the number of neighbors per entity, type: defaultdict(class:int).

build_graph(max_=50)[source]

Build the graph according to path_graph.

Update:

self.connections: The set of connections. self.e1_rele2: The set of e1_rele2. self.e1_degrees: The set of e1_degrees.
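For intuition, a hedged sketch of what build_graph(max_=50) computes (simplified; the (h, r, t, confidence) triple format and exact field handling are assumptions, not the library's verbatim code):

import numpy as np
from collections import defaultdict

def build_graph_sketch(path_graph, symbol2id, ent2id, num_ent, max_=50):
    # collect each head's (relation, tail) neighbors
    e1_rele2 = defaultdict(list)
    for h, r, t, conf in path_graph:
        e1_rele2[h].append((symbol2id[r], symbol2id[t]))
    # pack up to max_ neighbors per entity into a fixed-size array
    connections = np.zeros((num_ent, max_, 2), dtype=np.int64)
    e1_degrees = defaultdict(int)
    for ent, neighbors in e1_rele2.items():
        kept = neighbors[:max_]
        e1_degrees[ent2id[ent]] = len(kept)
        for i, (rel_id, tail_id) in enumerate(kept):
            connections[ent2id[ent], i] = [rel_id, tail_id]
    return connections, e1_rele2, e1_degrees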

get_e1rel_e2()[source]

Get the set of e1rel_e2 from all datasets.

Update:

self.e1rel_e2: The set of e1rel_e2.

get_ontology()[source]

Get the IIC of the entity and UC of the relation.

\[IIC(c) = 1 - \frac{\log(hypo(c) + 1)}{\log(n)}\]

\[UC_{r1}(r) = \sum_{h \in D_r,\, t \in R_r} \left( UC_e(h) + UC_e(t) \right)\]

\[UC_{r2}(r) = |D_r| \times |R_r|\]

Update:

self.ent2ic: The IIC of each entity. self.rel_uc1: The UC_r1 of each relation. self.rel_uc2: The UC_r2 of each relation.
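As a small illustration of these formulas (sketch only; hypo(c) denotes the number of hyponyms of concept c and n the total number of concepts, both assumed from the context above):

import math

def iic(hypo_count, n):
    # IIC(c) = 1 - log(hypo(c) + 1) / log(n)
    return 1.0 - math.log(hypo_count + 1) / math.log(n)

def relation_uc(heads, tails, ent_uc):
    # UC_r1(r) = sum over h in D_r, t in R_r of UC_e(h) + UC_e(t)
    uc_r1 = sum(ent_uc[h] + ent_uc[t] for h in heads for t in tails)
    # UC_r2(r) = |D_r| * |R_r|
    uc_r2 = len(heads) * len(tails)
    return uc_r1, uc_r2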

get_rel2candidates()[source]

Get the set of rel2candidates. A candidate is an entity under a relationship.

Update:

self.known_rels: The set of known_rels. self.type2ents: The set of type2ents. self.rel2candidates: Obtain 1000 entities of the same type as a candidate set.

get_rele1_e2()[source]

Get the set of rele1_e2 from the dev and test datasets.

Update:

self.rele1_e2: The set of rele1_e2.

get_tasks()[source]

Get entity/relation id, and entity/relation number.

Update:

self.ent2id: Entity to id. self.rel2id: Relation to id. self.symbol2id: Entity and relation to id. self.args.num_ent: Entity number. self.args.num_rel: Relation number.

class unKR.data.DataPreprocess.UKGData(args)[source]

Bases: object

Data preprocessing of ukg data.

args

Some pre-set parameters, such as dataset path, etc.

ent2id

Encoding the entity in triples, type: dict.

rel2id

Encoding the relation in triples, type: dict.

id2ent

Decoding the entity in triples, type: dict.

id2rel

Decoding the relation in triples, type: dict.

train_triples

Record the triples for training, type: list.

valid_triples

Record the triples for validation, type: list.

test_triples

Record the triples for testing, type: list.

PSL_triples

Record the triples for softlogic, type: list. (will be used in UKGE_PSL)

pseudo_triples

Record the pseudo-labeled triples, type: list. (will be used in UPGAT)

all_true_triples

Record all triples including train, valid, and test, type: list.

hr2t_train

Record the tail corresponding to the same head and relation, type: defaultdict(class:set).

rt2h_train

Record the head corresponding to the same tail and relation, type: defaultdict(class:set).

h2rt_train

Record the tail, relation corresponding to the same head, type: defaultdict(class:set).

t2rh_train

Record the head, relation corresponding to the same tail, type: defaultdict(class:set).

hr2t_total

Record the tail corresponding to the same head and relation in the whole dataset (train + val + test), type: defaultdict(class:set).

rt2h_total

Record the head corresponding to the same tail and relation in the whole dataset (train + val + test), type: defaultdict(class:set).

RatioOfPSL

Record the ratio of the number of PSL samples to the number of training samples. (will be used in UKGE_PSL)

pseudo_dataiter

Record the data in dataloader. (will be used in UPGAT)

static count_frequency(triples, start=4)[source]

Get frequency of a partial triple like (head, relation) or (relation, tail).

The frequency will be used for subsampling like word2vec.

Parameters:
  • triples – Sampled triples.

  • start – Initial count number.

Returns:

A dict recording the frequency of each (head, relation) and (relation, tail) pair.

Return type:

count
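A hedged sketch of what this helper does (the (t, -r - 1) key for the tail side is an assumption borrowed from similar KGE codebases, not confirmed from unKR itself):

def count_frequency_sketch(triples, start=4):
    # count (head, relation) and (relation, tail) occurrences, starting from a smoothing value
    count = {}
    for h, r, t, *rest in triples:          # rest may carry the confidence score
        count[(h, r)] = count.get((h, r), start) + 1
        count[(t, -r - 1)] = count.get((t, -r - 1), start) + 1
    return count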

get_h2rt_t2hr_from_train()[source]

Get the set of h2rt and t2hr from the train dataset; the data type is numpy.

Update:

self.h2rt_train: The set of h2rt. self.t2rh_train: The set of t2hr.

get_hr2t_rt2h_from_total()[source]

Get the set of hr2t and rt2h from the whole dataset (train + val + test); the data type is numpy.

Update:

self.hr2t_total: The set of hr2t in whole dataset. self.rt2h_total: The set of rt2h in whole dataset.

get_hr2t_rt2h_from_train()[source]

Get the set of hr2t and rt2h from the train dataset; the data type is numpy.

Update:

self.hr2t_train: The set of hr2t. self.rt2h_train: The set of rt2h.

get_hr_train()[source]

Change the batch generation mode by merging triples that share the same head and relation for the 1vsN training mode.

Update:

self.train_triples: The list of (hr, t) tuples for training.
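To make the 1vsN merge concrete, a minimal sketch (the (h, r, t, confidence) triple format is assumed):

from collections import defaultdict

def merge_hr_sketch(train_triples):
    # group every true tail under its (head, relation) pair
    hr2t = defaultdict(set)
    for h, r, t, conf in train_triples:
        hr2t[(h, r)].add(t)
    # one training example per (h, r), carrying all of its true tails
    return [((h, r), sorted(tails)) for (h, r), tails in hr2t.items()]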

get_id()[source]

Get entity/relation id, and entity/relation number.

Update:

self.ent2id: Entity to id. self.rel2id: Relation to id. self.id2ent: id to Entity. self.id2rel: id to Relation. self.args.num_ent: Entity number. self.args.num_rel: Relation number.

class unKR.data.DataPreprocess.UKGEBaseSampler(args)[source]

Bases: UKGData

Data processing for UKG data; the sampling method is consistent with NeuralKG (https://github.com/zjukg/NeuralKG).

corrupt_head(t, r, num_max=1)[source]

Negative sampling of head entities.

Parameters:
  • t – Tail entity in triple.

  • r – Relation in triple.

  • num_max – The maximum number of negative samples to generate.

Returns:

Negative head-entity samples, with the positive head entities filtered out.

Return type:

neg

corrupt_tail(h, r, num_max=1)[source]

Negative sampling of tail entities.

Parameters:
  • h – Head entity in triple.

  • r – Relation in triple.

  • num_max – The maximum number of negative samples to generate.

Returns:

Negative tail-entity samples, with the positive tail entities filtered out.

Return type:

neg
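Both corruptions follow the same filtered-sampling idea; a hedged sketch of the tail-side version (relying on the hr2t_train mapping described above; not the library's exact code):

import numpy as np

def corrupt_tail_sketch(h, r, hr2t_train, num_ent, num_max=1):
    # draw candidate tails uniformly, then drop any tail that forms a true (h, r, t)
    candidates = np.random.randint(num_ent, size=num_max)
    mask = np.isin(candidates, list(hr2t_train[(h, r)]), invert=True)
    return candidates[mask]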

get_PSL()[source]
get_all_true_triples()[source]
get_hr_map()[source]
get_test()[source]
get_train()[source]
get_valid()[source]
head_batch(h, r, t, neg_size=None)[source]

Negative sampling of head entities.

Parameters:
  • h – Head entity in triple.

  • t – Tail entity in triple.

  • r – Relation in triple.

  • neg_size – The size of negative samples.

Returns:

Negative head-entity samples of shape [neg_size].

tail_batch(h, r, t, neg_size=None)[source]

Negative sampling of tail entities.

Parameters:
  • h – Head entity in triple.

  • t – Tail entity in triple.

  • r – Relation in triple.

  • neg_size – The size of negative samples.

Returns:

Negative tail-entity samples of shape [neg_size].

Sampler

class unKR.data.Sampler.GMUCSampler(args)[source]

Bases: GMUCBaseSampler

GMUC sampling. Processes task-based data.

get_meta(left, right)[source]

Get meta data.

get_sampling_keys()[source]
sampling(data)[source]

Filtering out positive samples and selecting some samples randomly as negative samples.

Parameters:

data – A task/relation and all its triples.

Returns:

The training data.

Return type:

batch_data

class unKR.data.Sampler.GMUCTestSampler(sampler)[source]

Bases: object

Sampling triples and recording positive triples for testing.

sampler

The training sampler.

num_ent

The number of entities.

rel2candidates

Record the entities corresponding to the same relation, type: list.

symbol2id

Encoding the entity and relation in triples, type: dict.

ent2id

Encoding the entity in triples, type: dict.

get_sampling_keys()[source]
sampling(data)[source]

Sampling triples and recording positive triples for testing.

Parameters:

data – The triples used to be sampled.

Returns:

The data used to be evaluated.

Return type:

batch_data

class unKR.data.Sampler.KGDataset(triples)[source]

Bases: Dataset

class unKR.data.Sampler.UKGEPSLSampler(args)[source]

Bases: UKGEBaseSampler

Random negative sampling: filtering out positive samples and selecting some samples randomly as negative samples. UKGEPSLSampler is used for UKGE_PSL.

cross_sampling_flag

The flag of cross sampling head and tail negative samples.

get_sampling_keys()[source]
sampling(data)[source]

Filtering out positive samples and selecting some samples randomly as negative samples.

Parameters:

data – The triples used to be sampled.

Returns:

The training data.

Return type:

batch_data

uni_sampling(data)[source]
class unKR.data.Sampler.UKGETestSampler(sampler)[source]

Bases: object

Sampling triples and recording positive triples for testing. We offer two test sampling methods: one processes the test set the same way as NeuralKG, and the other uses only high-confidence samples from the test set for testing.

sampler

The training sampler.

hr2t_all

Record the tail corresponding to the same head and relation.

rt2h_all

Record the head corresponding to the same tail and relation.

hr2t_all_high_score

Record the tail corresponding to the same head and relation (only for high score samples in test).

rt2h_all_high_score

Record the head corresponding to the same tail and relation (only for high score samples in test).

num_ent

The count of entities.

construct_hr_map(data)[source]
get_hr2t_rt2h_from_all()[source]

Get the set of hr2t and rt2h from all datasets (train, valid, and test); the data type is tensor.

Update:

self.hr2t_all: The set of hr2t. self.rt2h_all: The set of rt2h. self.hr2t_all_high_score: The set of hr2t (only for high score samples in test). self.rt2h_all_high_score: The set of rt2h (only for high score samples in test).

get_sampling_keys()[source]
sampling(data)[source]

Sampling triples and recording positive triples for testing.

Parameters:

data – The triples used to be sampled.

Returns:

The data used to be evaluated.

Return type:

batch_data

class unKR.data.Sampler.UKGEUniSampler(args)[source]

Bases: UKGEBaseSampler

Random negative sampling: filtering out positive samples and selecting some samples randomly as negative samples.

cross_sampling_flag

The flag of cross sampling head and tail negative samples.

get_sampling_keys()[source]
sampling(data)[source]

Filtering out positive samples and selecting some samples randomly as negative samples.

Parameters:

data – The triples used to be sampled.

Returns:

The training data.

Return type:

batch_data

uni_sampling(data)[source]
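The cross_sampling_flag above alternates which side of the triple gets corrupted from batch to batch; a minimal sketch of that alternation (self.args.num_neg and the batch_data keys are assumptions, not verified field names):

def sampling_sketch(self, data):
    # data: a list of (h, r, t, confidence) training triples
    batch_data = {"positive_sample": data}
    self.cross_sampling_flag = (self.cross_sampling_flag + 1) % 2
    if self.cross_sampling_flag == 0:
        # corrupt heads in this batch
        batch_data["negative_sample"] = [self.head_batch(h, r, t, self.args.num_neg) for h, r, t, c in data]
        batch_data["mode"] = "head-batch"
    else:
        # corrupt tails in this batch
        batch_data["negative_sample"] = [self.tail_batch(h, r, t, self.args.num_neg) for h, r, t, c in data]
        batch_data["mode"] = "tail-batch"
    return batch_data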
unKR.data.Sampler.normal(loc=0.0, scale=1.0, size=None)

Draw random samples from a normal (Gaussian) distribution.

The probability density function of the normal distribution, first derived by De Moivre and 200 years later by both Gauss and Laplace independently [2], is often called the bell curve because of its characteristic shape (see the example below).

The normal distribution occurs often in nature. For example, it describes the commonly occurring distribution of samples influenced by a large number of tiny, random disturbances, each with its own unique distribution [2].

Note

New code should use the numpy.random.Generator.normal method of a numpy.random.Generator instance instead; please see the NumPy random quick start.

Parameters:
  • loc (float or array_like of floats) – Mean (“centre”) of the distribution.

  • scale (float or array_like of floats) – Standard deviation (spread or “width”) of the distribution. Must be non-negative.

  • size (int or tuple of ints, optional) – Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. If size is None (default), a single value is returned if loc and scale are both scalars. Otherwise, np.broadcast(loc, scale).size samples are drawn.

Returns:

out – Drawn samples from the parameterized normal distribution.

Return type:

ndarray or scalar

See also

scipy.stats.norm

probability density function, distribution or cumulative density function, etc.

random.Generator.normal

which should be used for new code.

Notes

The probability density for the Gaussian distribution is

\[p(x) = \frac{1}{\sqrt{ 2 \pi \sigma^2 }} e^{ - \frac{ (x - \mu)^2 } {2 \sigma^2} },\]

where \(\mu\) is the mean and \(\sigma\) the standard deviation. The square of the standard deviation, \(\sigma^2\), is called the variance.

The function has its peak at the mean, and its “spread” increases with the standard deviation (the function reaches 0.607 times its maximum at \(x + \sigma\) and \(x - \sigma\) [2]). This implies that normal is more likely to return samples lying close to the mean, rather than those far away.

References

Examples

Draw samples from the distribution:

>>> mu, sigma = 0, 0.1 # mean and standard deviation
>>> s = np.random.normal(mu, sigma, 1000)

Verify the mean and the variance:

>>> abs(mu - np.mean(s))
0.0  # may vary
>>> abs(sigma - np.std(s, ddof=1))
0.1  # may vary

Display the histogram of the samples, along with the probability density function:

>>> import matplotlib.pyplot as plt
>>> count, bins, ignored = plt.hist(s, 30, density=True)
>>> plt.plot(bins, 1/(sigma * np.sqrt(2 * np.pi)) *
...                np.exp( - (bins - mu)**2 / (2 * sigma**2) ),
...          linewidth=2, color='r')
>>> plt.show()

Two-by-four array of samples from the normal distribution with mean 3 and standard deviation 2.5:

>>> np.random.normal(3, 2.5, size=(2, 4))
array([[-4.49401501,  4.00950034, -1.81814867,  7.29718677],   # random
       [ 0.39924804,  4.68456316,  4.99394529,  4.84057254]])  # random

KGDataModule

Base DataModule class.

class unKR.data.KGDataModule.GMUCDataModule(*args: Any, **kwargs: Any)[source]

Bases: BaseDataModule

GMUC Base DataModule. Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html

get_data_config()[source]

Return important settings of the dataset, which will be passed to instantiate models.

prepare_data()[source]

Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don’t set state self.x = y).

setup(stage=None)[source]

Split into train, val, test, and set dims. Should assign torch Dataset objects to self.data_train, self.data_val, and optionally self.data_test.

test_dataloader()[source]

Implement one or multiple PyTorch DataLoaders for testing.

The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

For data processing use the following pattern:

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example:

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.

train_dataloader()[source]

Implement one or more PyTorch DataLoaders for training.

Returns:

A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see this page.

The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

For data processing use the following pattern:

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example:

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloader, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
val_dataloader()[source]

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.

class unKR.data.KGDataModule.KGDataModule(*args: Any, **kwargs: Any)[source]

Bases: BaseDataModule

Base DataModule. Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html

get_data_config()[source]

Return important settings of the dataset, which will be passed to instantiate models.

get_train_bs()[source]

Get batch size for training.

If num_batches isn't zero, the size of data_train is divided by num_batches to get the batch size. If the user doesn't give a batch size and num_batches is 0, a ValueError is raised.

Returns:

The batch size for training.

Return type:

self.args.train_bs
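A hedged sketch of this logic (attribute names such as self.args.num_batches follow the description above; the exact implementation may differ):

def get_train_bs_sketch(self):
    if self.args.num_batches != 0:
        # derive the batch size from the requested number of batches
        self.args.train_bs = len(self.data_train) // self.args.num_batches
    elif not self.args.train_bs:
        raise ValueError("Either train_bs or num_batches must be given.")
    return self.args.train_bs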

prepare_data()[source]

Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don’t set state self.x = y).

setup(stage=None)[source]

Split into train, val, test, and set dims. Should assign torch Dataset objects to self.data_train, self.data_val, and optionally self.data_test.

test_dataloader()[source]

Implement one or multiple PyTorch DataLoaders for testing.

The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

For data processing use the following pattern:

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example:

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.

train_dataloader()[source]

Implement one or more PyTorch DataLoaders for training.

Returns:

A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see this page.

The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

For data processing use the following pattern:

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example:

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloader, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
val_dataloader()[source]

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.