Answer to: Implementing a LoRA wrapper for Conv2D in Tensorflow
Score: 0 • Accepted
A first approach ensures that gradients always flow through the LoRA matrices and uses the `training` argument to control the behavior at runtime:
def call(self, inputs, training=None):
    """Run the wrapped convolution with the LoRA update added to its kernel.

    Args:
        inputs: Input tensor handed to ``tf.nn.conv2d``.
        training: Truthy when the layer is called in training mode. The
            scaled LoRA update is added only when this is truthy,
            ``self.lora_enabled`` is set and ``self.merged`` is not;
            otherwise a zero tensor of the same shape is added.

    Returns:
        The convolution output, with the wrapped layer's bias and
        activation applied when the wrapped layer has them.
    """
    layer = self.__original_layer

    # The low-rank product is computed on every call, whichever branch
    # is taken below.
    delta_W = tf.reshape(tf.matmul(self.lora_B, self.lora_A), self.kernel_shape)

    if training and self.lora_enabled and not self.merged:
        effective_delta = self.scaling * delta_W
    else:
        effective_delta = tf.zeros_like(delta_W)

    # Forward the wrapped layer's layout and dilation so that a dilated or
    # channels-first convolution is reproduced instead of silently falling
    # back to tf.nn.conv2d's defaults (NHWC, no dilation).
    data_format = "NCHW" if layer.data_format == "channels_first" else "NHWC"

    outputs = tf.nn.conv2d(
        inputs,
        layer.kernel + effective_delta,
        strides=layer.strides,
        padding=layer.padding.upper(),
        data_format=data_format,
        dilations=layer.dilation_rate,
    )
    if layer.use_bias:
        outputs = tf.nn.bias_add(outputs, layer.bias, data_format=data_format)
    if layer.activation is not None:
        outputs = layer.activation(outputs)
    return outputs
Another approach is to keep the LoRA matrices always trainable, but mask their contribution:
def build(self, input_shape):
    """Create the two LoRA factor weights and the boolean activity switch."""
    # ... existing code ...

    # Both factors are registered as trainable unconditionally; switching
    # the adapter on or off is left to ``lora_active`` below.
    init_A = keras.initializers.HeUniform()
    self.lora_A = self.add_weight(
        shape=lora_A_shape,
        initializer=init_A,
        trainable=True,
        name="LoRA_matA",
    )

    # The B factor is initialized to all zeros.
    init_B = keras.initializers.Zeros()
    self.lora_B = self.add_weight(
        shape=lora_B_shape,
        initializer=init_B,
        trainable=True,
        name="LoRA_matB",
    )

    # Non-trainable boolean switch, seeded from ``self.lora_enabled``.
    self.lora_active = tf.Variable(
        dtype=tf.bool,
        trainable=False,
        initial_value=self.lora_enabled,
    )
def call(self, inputs):
    """Build the low-rank kernel update and choose its scale from ``lora_active``.

    The remainder of the method is elided in this snippet.
    """
    delta_W = tf.reshape(
        tf.matmul(self.lora_B, self.lora_A),
        self.kernel_shape,
    )

    # tf.cond picks the scale from the ``lora_active`` variable:
    # ``self.scaling`` when it is true, 0.0 otherwise.
    effective_scaling = tf.cond(
        pred=self.lora_active,
        true_fn=lambda: self.scaling,
        false_fn=lambda: 0.0,
    )
    #...
View Question ↗
Question
Parent Entity
Score: 5 • Views: 213
Site: stackoverflow
SaaS Metrics