🔬 Exploring Reinforcement Learning for Molecule Generation with GPT-Based Models: Loss Fluctuations

Hey everyone! :wave: I’m diving into the world of Reinforcement Learning (RL) to improve a GPT-based model that generates valid SMILES strings representing molecules. However, as a newcomer to both RL and GPT, I’m uncertain about my approach.

:mag: Current Approach Overview:

  • Model: GPT-based model, fine-tuned for molecule generation.
  • RL Technique: Implementing policy gradient RL with a custom reward system.
  • Reward System: Assigns non-negative rewards (zero or positive) based on the generated molecules’ validity and quality.
  • Optimizer: Using PyTorch’s AdamW optimizer.
  • Hyperparameters: Batch size, learning rate scheduler, and gamma value for discounted rewards (a rough setup is sketched right after this list).
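
Concretely, the optimizer and scheduler setup looks roughly like this; the checkpoint name and schedule lengths are placeholders, and the linear warmup schedule is just one example of the schedulers I’ve been trying:

import torch
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup

# Placeholder checkpoint name for my fine-tuned molecule model
model = AutoModelForCausalLM.from_pretrained('my fine-tuned checkpoint')

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,       # placeholder values
    num_training_steps=10_000,
)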

:question: Questions for Discussion:

  1. Valid RL Technique? Is policy gradient RL suitable for molecule generation tasks like mine?
  2. Optimizers and Schedulers: Should I explore different optimizers or learning rate schedulers to enhance training stability and convergence speed?
  3. Reward Normalization: Is it necessary to normalize rewards to ensure effective RL training? (I sketch what I mean right after this list.)
  4. Batch Size Adjustment: Would changing the batch size impact training efficiency or quality of generated molecules?
  5. Training Duration: How many epochs should I run to evaluate if my RL approach is effective or needs adjustments?
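
To make question 3 concrete, this is the kind of reward normalization I have in mind: standardizing each batch of rewards before using them in the policy-gradient loss (a minimal, self-contained sketch; the raw rewards below are made-up example values):

import torch

def normalize_rewards(rewards, eps=1e-8):
    # Standardize rewards to zero mean and unit variance so the scale of the
    # policy-gradient loss stays comparable from batch to batch.
    rewards = torch.as_tensor(rewards, dtype=torch.float32)
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Made-up example rewards for one batch of generated molecules
raw_rewards = [0.2, 1.5, 0.7, 3.0]
print(normalize_rewards(raw_rewards))  # roughly zero mean, unit variance

My understanding is that subtracting the batch mean acts like a simple baseline and should reduce the variance of the gradient estimate, but I’d love confirmation that this is appropriate here.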

:speech_balloon: Code and Context: I’ve shared my current implementation of the RL training loop using PyTorch. Would appreciate insights, suggestions, or tips from experienced practitioners in RL or GPT-based modeling!

import random

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer


class Reinforcement(object):
    def __init__(self, model, generator, get_reward, optimizer):
        super(Reinforcement, self).__init__()
        self.model = model
        self.generator = generator
        self.get_reward = get_reward
        self.optimizer = optimizer

    def policy_gradient(self, n_batch, gamma=0.97, grad_clipping=None):
        rl_loss = torch.tensor(0.0, requires_grad=True)
        total_reward = 0
        self.optimizer.zero_grad()

        # Load the tokenizer once per call instead of once per trajectory
        tokenizer = AutoTokenizer.from_pretrained('my tokenizer')

        for _ in range(n_batch):
            # Sample candidate trajectories until one earns a non-zero reward
            reward = 0
            number_of_attempts = 0

            input_ = 'proper input'
            trajectory = self.generator(input_, num_generated=10, batch_generated_size=10)
            while reward == 0:
                random_number = random.randint(0, 9)
                number_of_attempts += 1
                trajectory_ = [str(trajectory[random_number])]
                reward = self.get_reward(trajectory_)

            discounted_reward = reward
            # Credit this trajectory's reward, averaged over the attempts it took
            total_reward += reward / number_of_attempts

            # Convert the sampled string into token ids
            trajectory_input = tokenizer(trajectory_[0], return_tensors='pt')['input_ids']

            # One forward pass over the full trajectory; gradients must flow
            # through the model here, so no torch.no_grad()
            outputs = self.model(input_ids=trajectory_input)

            # Apply softmax and compute log probabilities
            log_probs = F.log_softmax(outputs.logits, dim=-1)

            for p in range(trajectory_input.shape[1] - 1):
                # Log probability of the actual next token, weighted by the
                # discounted reward; /100 just scales the loss down
                top_i = trajectory_input[:, p + 1]
                rl_loss = rl_loss - (log_probs[:, p, top_i] * discounted_reward) / 100
                discounted_reward = discounted_reward * gamma

        # Average over the batch, then do the backward pass and parameter update
        rl_loss = rl_loss / n_batch
        total_reward = total_reward / n_batch
        rl_loss.backward()
        if grad_clipping is not None:
            # Clip gradients after backward() and before the optimizer step
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clipping)
        self.optimizer.step()

        return total_reward, rl_loss.item()
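
For reference, this is roughly how I drive the class; my_generator and my_reward_fn stand in for my actual sampling and scoring callables, and the checkpoint name is a placeholder:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('my fine-tuned checkpoint')  # placeholder
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# my_generator samples candidate SMILES strings; my_reward_fn scores them
rl = Reinforcement(model, my_generator, my_reward_fn, optimizer)

for step in range(1000):
    mean_reward, loss = rl.policy_gradient(n_batch=16, gamma=0.97, grad_clipping=1.0)
    if step % 50 == 0:
        print(f"step {step}: mean reward {mean_reward:.3f}, loss {loss:.3f}")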

Looking forward to your valuable feedback and advice on refining my approach! :rocket:


Forgive my naivety towards RL, but what are you actually training here?

Are you just using GPT for inference, or is the goal to train GPT to give better generations?

BTW, be careful using “input” as a variable name, since it shadows Python’s built-in input() function:
Python input() Function (w3schools.com)

I am wondering if this will be useful:
TRL - Transformer Reinforcement Learning (huggingface.co)
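
Very roughly, the workflow there looks like the sketch below. This is based on the older PPOTrainer API (newer TRL releases restructured it, so check the docs for your version); "gpt2" stands in for your fine-tuned checkpoint and compute_reward is a placeholder for your molecule scorer:

import torch
from transformers import AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead

config = PPOConfig(model_name="gpt2", learning_rate=1e-5, batch_size=1, mini_batch_size=1)
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)

# One PPO step: sample a completion for a prompt, score it, update the policy
query = tokenizer("C", return_tensors="pt").input_ids
response = model.generate(query, max_new_tokens=60, do_sample=True)
response_only = response[0, query.shape[1]:]
reward = torch.tensor(float(compute_reward(tokenizer.decode(response_only))))  # placeholder scorer
stats = ppo_trainer.step([query[0]], [response_only], [reward])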


Thanks for your response! I’m actually utilizing Reinforcement Learning (RL) to fine-tune the GPT-based model for generating better molecule representations.

Indeed, using “input” as a variable name can be confusing since it shadows a Python built-in. I’ll make sure to adjust that for clarity in my code.

Thanks again for your insights and suggestions! :pray: