cayley
October 31, 2024, 1:16am
There are two transformers in the vision encoder. One is called global_transformer and the other transformer.
I see is_gated is different. What is "global" about the global_transformer?
self.patch_embedding = nn.Conv2d(
    in_channels=config.num_channels,
    out_channels=self.hidden_size,
    kernel_size=self.patch_size,
    stride=self.patch_size,
    padding="valid",
    bias=False,
)

self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config)

self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)

# layer norms
self.layernorm_pre = nn.LayerNorm(self.hidden_size)
self.layernorm_post = nn.LayerNorm(self.hidden_size)

# encoders
self.transformer = MllamaVisionEncoder(config, config.num_hidden_layers, is_gated=False)
self.global_transformer = MllamaVisionEncoder(config, config.num_global_layers, is_gated=True)
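For reference, the is_gated flag seems to amount to multiplying each residual branch by a learnable tanh gate before adding it back. Here is a simplified, standalone sketch of that pattern (not the actual MllamaVisionEncoderLayer code; the class name is made up and attention/MLP are stubbed with nn.Identity):

```python
import torch
import torch.nn as nn

class GatedResidualLayerSketch(nn.Module):
    """Toy illustration of a tanh-gated residual connection, as used when is_gated=True."""

    def __init__(self, hidden_size, is_gated=True):
        super().__init__()
        self.is_gated = is_gated
        self.norm = nn.LayerNorm(hidden_size)
        self.block = nn.Identity()  # stands in for self-attention or the MLP
        if is_gated:
            # learnable scalar gate; tanh keeps its contribution in (-1, 1)
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_state):
        residual = hidden_state
        out = self.block(self.norm(hidden_state))
        if self.is_gated:
            # gate starts near 0, so the gated layer initially behaves close to identity
            out = self.gate.tanh() * out
        return residual + out
```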
Thanks.
It seems to be turned on by output_hidden_states and/or output_attentions.
intermediate_hidden_states = intermediate_hidden_states.reshape(
    batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1
)
intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
intermediate_hidden_states = intermediate_hidden_states.reshape(
    batch_size, num_concurrent_media, num_tiles, num_patches, -1
)

# Concatenate final hidden state and intermediate hidden states
hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)

if output_hidden_states:
    hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1])
else:
    hidden_states = None

if output_attentions:
    # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range
    global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1])
    attentions = tuple(output[2]) + global_attn
else:
    attentions = None
if not return_dict:
    return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None)

return BaseModelOutput(
    last_hidden_state=hidden_state,
    hidden_states=hidden_states,
    attentions=attentions,
)
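To see why the index shifts between global_output[1] and global_output[2]: the encoder returns a tuple in the order (last_hidden_state, hidden_states, attentions), where the optional entries only appear when requested, so the attention weights land at index 2 only if hidden states were also requested. A minimal standalone sketch of that tuple layout (not the actual MllamaVisionEncoder code):

```python
def encoder_tuple(last_hidden_state, all_hidden_states, all_attentions,
                  output_hidden_states, output_attentions):
    """Mimics the (last_hidden_state, hidden_states?, attentions?) tuple layout."""
    out = (last_hidden_state,)
    if output_hidden_states:
        out += (all_hidden_states,)
    if output_attentions:
        out += (all_attentions,)
    return out

# With output_hidden_states=False the attentions sit at index 1, not 2,
# which is why the quoted code picks global_output[2] or global_output[1]
# depending on output_hidden_states.
global_output = encoder_tuple("last", ("h0", "h1"), ("a0", "a1"),
                              output_hidden_states=False, output_attentions=True)
assert global_output[1] == ("a0", "a1")
```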